import sys
from abc import abstractmethod
from enum import Enum
from typing import Dict, List, Tuple

from PyQt5.QtCore import QPointF, QRectF, Qt, pyqtSignal
from PyQt5.QtGui import QBrush, QFont, QFontInfo, QMouseEvent, QPainter, QPainterPath, QPen
from PyQt5.QtWidgets import QApplication, QGraphicsItem, QGraphicsScene, QGraphicsView, QWidget

from graph.DataType import Arrow, Point
from graph.directed_acyclic_graph.DAGLayout import DAGGraph, DAGOrderHelper


class GraphNodeType(Enum):
    """Kind of graphics item used to render the DAG."""

    ActivePresentNode = 0
    PenetrateNode = 1
    TransitionCurve = 2


class GraphNode:
    """Interface shared by every graphics item of the DAG view.

    NOTE: this class is not built on ``ABCMeta`` (mixing it with the sip
    metaclass of ``QGraphicsItem`` would likely conflict), so the
    ``@abstractmethod`` markers are documentation only; the default bodies
    raise ``NotImplementedError`` instead.
    """

    @abstractmethod
    def node_type(self) -> GraphNodeType:
        raise NotImplementedError("node_type")

    @abstractmethod
    def highlight(self, flag: bool):
        raise NotImplementedError("highlight")

    @abstractmethod
    def is_highlighted(self) -> bool:
        raise NotImplementedError("is_highlighted")

    @abstractmethod
    def boundingRect(self) -> QRectF:
        raise NotImplementedError("boundingRect")

    @abstractmethod
    def pos(self) -> QPointF:
        raise NotImplementedError("pos")


class StoryNodeType(Enum):
    """Role of a story node inside the story graph."""

    StoryStartNode = 0  # node without incoming edges
    FragmentNode = 1  # any other story fragment


class ActivePresentNode(QGraphicsItem, GraphNode):
    """Graphics item showing the name of a story fragment (a real DAG vertex)."""

    def __init__(self, name: str, type: StoryNodeType, font: QFont, parent):
        QGraphicsItem.__init__(self, parent)
        self.__highlight_mark = False
        self.node_type_within_story = type
        self.fragment_name = name
        self.font_bind = font
        self.data_bind: object = None
        # Lookup key assigned by DAGActiveView, e.g. ("node", name); empty until then.
        self.node_key_bind: Tuple = ()

    def node_type(self) -> GraphNodeType:
        return GraphNodeType.ActivePresentNode

    def highlight(self, flag: bool):
        self.__highlight_mark = flag

    def is_highlighted(self) -> bool:
        return self.__highlight_mark

    def boundingRect(self) -> QRectF:
        """Return a box of one glyph cell per character plus 5 px padding on each side."""
        size = self.font_bind.pixelSize()
        if size <= 0:
            # pixelSize() is -1 for fonts specified in points; use the resolved
            # pixel size instead of producing a negative rectangle.
            size = QFontInfo(self.font_bind).pixelSize()
        return QRectF(0, 0, size * len(self.fragment_name) + 10, size + 10)

    def paint(self, painter, option, widget=None):
        """Fill the box (gray for fragments, green for start nodes), outline it in
        red when highlighted, and draw the fragment name inside the padding."""
        outline = self.boundingRect()
        text_rect = outline.adjusted(5, 5, -5, -5)
        painter.save()
        fill = Qt.gray if self.node_type_within_story == StoryNodeType.FragmentNode else Qt.green
        painter.fillRect(outline, fill)
        if self.__highlight_mark:
            # The red pen stays active for the text as well.
            painter.setPen(Qt.red)
            painter.drawRect(outline)
        painter.setFont(self.font_bind)
        painter.drawText(text_rect, self.fragment_name)
        painter.restore()

    # Custom helpers: attach an arbitrary payload (the layout helper) to the item.
    def attach_data(self, data_inst: object):
        self.data_bind = data_inst

    def get_data(self) -> object:
        return self.data_bind


class PenetrateNode(QGraphicsItem, GraphNode):
    """Horizontal bar standing in for an edge that passes through a layer."""

    def __init__(self, width: float, parent):
        QGraphicsItem.__init__(self, parent)
        self.width_store = width
        self.__highlight_mark = False
        self.data_bind: object = None
        # Lookup key assigned by DAGActiveView, e.g. ("plac", start, end, layer).
        self.node_key_bind: Tuple = ()

    def node_type(self) -> GraphNodeType:
        return GraphNodeType.PenetrateNode

    def highlight(self, flag: bool):
        self.__highlight_mark = flag

    def is_highlighted(self) -> bool:
        return self.__highlight_mark

    def resize_width(self, width: float):
        """Stretch the bar to *width* (used to match the widest node of its layer)."""
        # boundingRect() depends on width_store, so Qt must be notified first
        # to keep the scene's spatial index consistent.
        self.prepareGeometryChange()
        self.width_store = width
        self.update()

    def boundingRect(self):
        return QRectF(0, 0, self.width_store, 8)

    def paint(self, painter, option, widget=None):
        """Draw a 4 px bar vertically centred in the 8 px tall box."""
        outline = self.boundingRect()
        painter.save()
        color = Qt.red if self.__highlight_mark else Qt.black
        painter.fillRect(QRectF(0, 2, outline.width(), 4), color)
        painter.restore()

    # Custom helpers: attach an arbitrary payload (the layout helper) to the item.
    def attach_data(self, data_inst: object):
        self.data_bind = data_inst

    def get_data(self) -> object:
        return self.data_bind


class TransitionCurve(QGraphicsItem, GraphNode):
    """Connector drawn from the right edge of *start* to the left edge of *end*:
    a straight run to the end of the start node's layer, then a cubic curve."""

    def __init__(self, start: GraphNode, end: GraphNode, prev_layrer_width: float, parent):
        QGraphicsItem.__init__(self, parent)
        self.__highlight_mark = False
        self.start_node = start
        self.end_node = end
        self.prev_layer_w = prev_layrer_width
        self.outline = QRectF()  # set by layout_refresh()
        self.data_bind: object = None
        # Lookup key assigned by DAGActiveView, e.g. ("curv", start, end, layer).
        self.node_key_bind: Tuple = ()

    def node_type(self) -> GraphNodeType:
        return GraphNodeType.TransitionCurve

    def highlight(self, flag: bool):
        self.__highlight_mark = flag

    def is_highlighted(self) -> bool:
        return self.__highlight_mark

    def layout_refresh(self):
        """Recompute position and bounds from the current positions of both endpoints."""
        start_rect = self.start_node.boundingRect()
        end_rect = self.end_node.boundingRect()
        start_pos = self.start_node.pos()
        end_pos = self.end_node.pos()
        left = start_pos.x() + start_rect.width()
        top = min(start_pos.y(), end_pos.y())
        bottom = max(start_pos.y() + start_rect.height(), end_pos.y() + end_rect.height())
        # The bounding rect is about to change; Qt requires this notification first.
        self.prepareGeometryChange()
        self.setPos(left, top)
        self.outline = QRectF(0, 0, end_pos.x() - left, bottom - top)
        self.update()

    def boundingRect(self):
        return self.outline

    def paint(self, painter, option, widget=None):
        outline = self.outline
        start_rect = self.start_node.boundingRect()
        end_rect = self.end_node.boundingRect()
        # Anchors in scene coordinates: right-middle of start, left-middle of end.
        anchor_start = self.start_node.pos() + QPointF(start_rect.width(), start_rect.height() / 2)
        anchor_end = self.end_node.pos() + QPointF(0, end_rect.height() / 2)
        # Straight run that carries the line to the right edge of the start layer.
        line_span = self.prev_layer_w - start_rect.width()

        # Convert anchors into this item's local coordinates.
        start_pos = anchor_start - self.pos()
        end_pos = anchor_end - self.pos()
        line_epos = start_pos + QPointF(line_span, 0)
        # Horizontal control points at 1/3 of the remaining gap give a smooth S-curve.
        ctrl_dx = (outline.width() - line_span) / 3
        control_pos0 = line_epos + QPointF(ctrl_dx, 0)
        control_pos1 = end_pos - QPointF(ctrl_dx, 0)

        pen = QPen(Qt.red if self.__highlight_mark else Qt.black)
        pen.setWidthF(4)

        painter.save()
        painter.setRenderHint(QPainter.RenderHint.Antialiasing)
        painter.setPen(pen)
        painter.drawLine(start_pos, line_epos)
        path = QPainterPath()
        path.moveTo(line_epos)
        path.cubicTo(control_pos0, control_pos1, end_pos)
        painter.drawPath(path)
        painter.restore()

    # Custom helpers: attach an arbitrary payload (the layout helper) to the item.
    def attach_data(self, data_inst: object):
        self.data_bind = data_inst

    def get_data(self) -> object:
        return self.data_bind


class Direction(Enum):
    RankLR = 0
    RankTB = 1


class DAGActiveView(QGraphicsView):
    """View that lays out a DAG (via ``DAGGraph``) as left-to-right columns, one
    column per layer, supports highlighting a path, and reports clicks on
    fragment nodes.

    Items are registered under string keys:
      * ``node@{name}``                     -> ActivePresentNode
      * ``plac${start}::{end}-{layer}``     -> PenetrateNode
      * ``curv&{start}::{end}-{layer}``     -> TransitionCurve
    """

    # (view x, view y, [key]) on left click; the list holds at most one
    # ("node", name) tuple for a fragment node under the cursor.
    nodes_clicked = pyqtSignal(int, int, list)

    def __init__(self, parent):
        QGraphicsView.__init__(self, parent)
        self.setViewportUpdateMode(QGraphicsView.ViewportUpdateMode.FullViewportUpdate)
        font = QFont()
        font.setPixelSize(20)
        self.setFont(font)
        self.layer_span = 200  # horizontal gap between adjacent layers
        self.node_span = 20  # vertical gap between nodes of the same layer
        self.scene_bind = QGraphicsScene(self)
        self.setScene(self.scene_bind)
        self.__highlight_nodelist: List[GraphNode] = []
        self.__total_graph_nodes: Dict[str, GraphNode] = {}

    def update_with_edges(self, arrows: List[Arrow]) -> None:
        """Lay out the DAG described by *arrows* and rebuild the rendering.

        Items created by a previous call are removed first; foreign items that
        callers added to ``scene_bind`` are left alone.
        """
        # Without this, a second call stacks the new graph on top of the old one
        # and leaves stale entries in the lookup tables.
        self.__highlight_nodelist.clear()
        for old_item in self.__total_graph_nodes.values():
            self.scene_bind.removeItem(old_item)
        self.__total_graph_nodes.clear()

        tools = DAGGraph()
        tools.rebuild_from_edges(arrows)
        tools.graph_layout()
        total_nodes = tools.nodes_with_layout

        layer_xpos = 0
        previous_items: List[GraphNode] = []
        # Walk the layers left to right.
        for layer_idx in range(tools.max_layer_count):
            layer_nodes: List[DAGOrderHelper] = [n for n in total_nodes if n.layer_number == layer_idx]
            layer_nodes.sort(key=lambda n: n.sort_number)

            current_items = self._create_layer_items(layer_nodes, layer_xpos)

            # Stretch pass-through bars to the widest item of this layer.
            layer_width = max((item.boundingRect().width() for item in current_items), default=0)
            for item in current_items:
                if isinstance(item, PenetrateNode):
                    item.resize_width(layer_width)
            layer_xpos += layer_width + self.layer_span

            if previous_items:
                prev_layer_width = max((item.boundingRect().width() for item in previous_items), default=0)
                self._connect_layers(previous_items, current_items, prev_layer_width)

            previous_items = current_items

    def _create_layer_items(self, layer_nodes: List[DAGOrderHelper], xpos: float) -> List[GraphNode]:
        """Create, position (stacked top-down at *xpos*) and register one layer's items."""
        items: List[GraphNode] = []
        ypos = 0
        for node in layer_nodes:
            if node.is_fake_node():
                # Placeholder for an edge crossing this layer.
                item = PenetrateNode(20, None)
                fragm_start = node.relate_bind.bind_point().point_name
                # NOTE(review): elsewhere the edge end is read via
                # ``towards_to.bind_point()``; confirm ``bind_node`` is equivalent.
                fragm_end = node.towards_to.bind_node.point_name
                node_key = f"plac${fragm_start}::{fragm_end}-{node.layer_number}"
                item.node_key_bind = "plac", fragm_start, fragm_end, node.layer_number
            else:
                point_name = node.layer_bind.bind_point().point_name
                story_type = (StoryNodeType.StoryStartNode if node.layer_bind.input_count == 0
                              else StoryNodeType.FragmentNode)
                item = ActivePresentNode(point_name, story_type, self.font(), None)
                node_key = f"node@{point_name}"
                item.node_key_bind = "node", point_name
            item.attach_data(node)
            item.setPos(xpos, ypos)
            ypos += item.boundingRect().height() + self.node_span
            items.append(item)
            self.scene_bind.addItem(item)
            self.__total_graph_nodes[node_key] = item
        return items

    def _connect_layers(self, previous_items: List[GraphNode], current_items: List[GraphNode],
                        prev_layer_width: float) -> None:
        """Add a TransitionCurve for every upstream relation between two adjacent layers."""
        for curr_item in current_items:
            sort_helper: DAGOrderHelper = curr_item.get_data()
            for prev_item in previous_items:
                if prev_item.get_data() not in sort_helper.get_upstream_nodes():
                    continue
                curve = TransitionCurve(prev_item, curr_item, prev_layer_width, None)
                self.scene_bind.addItem(curve)
                curve.layout_refresh()

                # Name of the edge's source vertex.
                relate_node_name = ""
                if prev_item.node_type() == GraphNodeType.ActivePresentNode:
                    relate_node_name = prev_item.get_data().layer_bind.bind_point().point_name
                elif prev_item.node_type() == GraphNodeType.PenetrateNode:
                    relate_node_name = prev_item.get_data().relate_bind.bind_point().point_name

                # Name of the edge's target vertex.
                towards_node_name = ""
                if curr_item.node_type() == GraphNodeType.ActivePresentNode:
                    towards_node_name = sort_helper.layer_bind.bind_point().point_name
                elif curr_item.node_type() == GraphNodeType.PenetrateNode:
                    towards_node_name = sort_helper.towards_to.bind_point().point_name

                node_key = f"curv&{relate_node_name}::{towards_node_name}-{sort_helper.layer_number}"
                curve.node_key_bind = "curv", relate_node_name, towards_node_name, sort_helper.layer_number
                self.__total_graph_nodes[node_key] = curve

    def highlight_graph_link(self, highlight_path: List[str]):
        """Highlight the nodes of *highlight_path* and every bar/curve of the
        edges between consecutive names; previous highlights are cleared.

        An empty path just clears. Raises KeyError for an unknown node name.
        """
        for n in self.__highlight_nodelist:
            n.highlight(False)
        self.__highlight_nodelist.clear()

        if highlight_path:
            for name in highlight_path:
                node = self.__total_graph_nodes[f"node@{name}"]
                node.highlight(True)
                self.__highlight_nodelist.append(node)

            # Compare (start, end) exactly: a string-prefix test on the key would
            # also match edges such as a->bc when highlighting a->b.
            hops = set(zip(highlight_path, highlight_path[1:]))
            for item in self.__total_graph_nodes.values():
                key = item.node_key_bind
                if key[0] in ("plac", "curv") and (key[1], key[2]) in hops:
                    item.highlight(True)
                    self.__highlight_nodelist.append(item)

        self.scene_bind.update()
        self.update()

    def mousePressEvent(self, event: QMouseEvent):
        QGraphicsView.mousePressEvent(self, event)
        if event.button() != Qt.MouseButton.LeftButton:
            return
        node_keys = []
        for item in self.items(event.pos()):
            # Foreign items in the scene carry no key; skip them.
            key = getattr(item, "node_key_bind", ())
            if key and key[0] == "node":
                node_keys.append(key)
        self.nodes_clicked.emit(event.pos().x(), event.pos().y(), node_keys[0:1])


if __name__ == "__main__":
    app = QApplication(sys.argv)
    view = DAGActiveView(None)
    view.show()
    arrows = [
        Arrow(Point('a'), Point('b')),
        Arrow(Point('a'), Point('c')),
        Arrow(Point('c'), Point('d')),
        Arrow(Point('a'), Point('d')),
        Arrow(Point('c'), Point('e')),
        Arrow(Point('c'), Point('f')),
    ]
    view.update_with_edges(arrows)
    sys.exit(app.exec())