454 lines
15 KiB
Python
454 lines
15 KiB
Python
import sys
from abc import abstractmethod
from enum import Enum
from typing import List, Dict

from PyQt5.QtCore import QRectF, Qt, QPointF, pyqtSignal
from PyQt5.QtGui import QFont, QFontInfo, QBrush, QPen, QPainter, QPainterPath, QMouseEvent
from PyQt5.QtWidgets import QWidget, QApplication, QGraphicsView, QGraphicsScene, QGraphicsItem

from graph.DataType import Arrow, Point
from graph.directed_acyclic_graph.DAGLayout import DAGGraph, DAGOrderHelper
|
|
|
|
|
|
class GraphNodeType(Enum):
    """Kind of graphics item placed in the DAG view (see GraphNode.node_type())."""

    # Plain integers on purpose: writing `X = 0,` would make the value the tuple (0,).
    ActivePresentNode = 0
    PenetrateNode = 1
    TransitionCurve = 2
|
|
|
|
|
|
class GraphNode:
    """
    Interface shared by every item the DAG view puts into its scene.

    Used as a mixin next to QGraphicsItem; the concrete subclasses list
    QGraphicsItem first, so e.g. ``pos()`` resolves to QGraphicsItem's version.
    Each stub here only raises NotImplementedError.

    NOTE(review): ``@abstractmethod`` has no effect without an ``ABCMeta``
    metaclass, so a missing override is only noticed when it is called.
    Adding ``abc.ABC`` would likely clash with the PyQt/sip metaclass of
    QGraphicsItem -- confirm before changing.
    """

    @abstractmethod
    def node_type(self) -> GraphNodeType:
        """Return the GraphNodeType identifying this item."""
        raise NotImplementedError("node_type")

    @abstractmethod
    def highlight(self, flag: bool):
        """Turn the highlighted state on (True) or off (False)."""
        raise NotImplementedError("highlight")

    @abstractmethod
    def is_highlighted(self) -> bool:
        """Return the current highlighted state."""
        raise NotImplementedError("is_highlighted")

    @abstractmethod
    def boundingRect(self) -> QRectF:
        """Return the item-local rectangle enclosing the item."""
        raise NotImplementedError("boundingRect")

    @abstractmethod
    def pos(self) -> QPointF:
        """Return the item position in its parent's coordinates."""
        raise NotImplementedError("pos")
|
|
|
|
|
|
class StoryNodeType(Enum):
    """
    Story node type.

    The view assigns StoryStartNode to layout nodes with no inputs and
    FragmentNode to all others.
    """

    # Plain integers on purpose: writing `X = 0,` would make the value the tuple (0,).
    StoryStartNode = 0
    FragmentNode = 1
|
|
|
|
|
|
class ActivePresentNode(QGraphicsItem, GraphNode):
    """
    Story fragment name node: a filled box with the fragment name inside.

    FragmentNode items are filled gray, any other story type (StoryStartNode)
    green; a highlighted node draws its outline with a red pen.
    """

    def __init__(self, name: str, type: StoryNodeType, font: QFont, parent):
        # `type` shadows the builtin but is kept: callers may pass it by keyword.
        QGraphicsItem.__init__(self, parent)

        self.__highlight_mark = False
        self.node_type_within_story = type
        self.fragment_name = name
        self.font_bind = font
        # Arbitrary payload attached by the owner (the view stores its layout node here).
        self.data_bind: object = None

    def node_type(self) -> GraphNodeType:
        return GraphNodeType.ActivePresentNode

    def highlight(self, flag: bool):
        """Set the highlighted state; the caller is responsible for triggering a repaint."""
        self.__highlight_mark = flag

    def is_highlighted(self) -> bool:
        return self.__highlight_mark

    def _glyph_pixel_size(self) -> int:
        """Return the font size in pixels.

        QFont.pixelSize() returns -1 when the size was set in points, which
        would produce a degenerate rectangle; fall back to the pixel size of
        the font actually matched by Qt.
        """
        size = self.font_bind.pixelSize()
        if size <= 0:
            size = QFontInfo(self.font_bind).pixelSize()
        return size

    def boundingRect(self) -> QRectF:
        # Width estimate: one square glyph cell per character plus a 5px margin
        # on each side. Both story node kinds share the same geometry.
        size = self._glyph_pixel_size()
        return QRectF(0, 0, size * len(self.fragment_name) + 10, size + 10)

    def paint(self, painter, option, widget=None):
        outline = self.boundingRect()
        text_rect = outline.adjusted(5, 5, -5, -5)

        painter.save()

        if self.node_type_within_story == StoryNodeType.FragmentNode:
            painter.fillRect(outline, Qt.gray)
        else:
            painter.fillRect(outline, Qt.green)

        # The border is always drawn; only its pen colour depends on the highlight.
        if self.__highlight_mark:
            painter.setPen(Qt.red)

        painter.drawRect(outline)
        painter.setFont(self.font_bind)
        painter.drawText(text_rect, self.fragment_name)
        painter.restore()

    # Custom helpers: attach / retrieve an arbitrary payload object.
    def attach_data(self, data_inst: object):
        self.data_bind = data_inst

    def get_data(self) -> object:
        return self.data_bind
|
|
|
|
|
|
class PenetrateNode(QGraphicsItem, GraphNode):
    """
    Pass-through link node.

    The view creates one for each layout node whose ``is_fake_node()`` is true.
    It is drawn as a 4px horizontal bar (black, red when highlighted) centred
    in an 8px-high bounding box whose width the view stretches to the layer width.
    """

    def __init__(self, width: float, parent):
        QGraphicsItem.__init__(self, parent)
        self.width_store = width
        self.__highlight_mark = False
        # Arbitrary payload attached by the owner (the view stores its layout node here).
        self.data_bind: object = None

    def node_type(self) -> GraphNodeType:
        return GraphNodeType.PenetrateNode

    def highlight(self, flag: bool):
        """Set the highlighted state; the caller is responsible for triggering a repaint."""
        self.__highlight_mark = flag

    def is_highlighted(self) -> bool:
        return self.__highlight_mark

    def resize_width(self, width: float):
        """Change the bar width and schedule a repaint."""
        # boundingRect() depends on width_store: Qt requires prepareGeometryChange()
        # before it changes, otherwise the scene's index keeps the stale rectangle
        # (hit-testing and repaint of the old area go wrong).
        self.prepareGeometryChange()
        self.width_store = width
        self.update()

    def boundingRect(self):
        return QRectF(0, 0, self.width_store, 8)

    def paint(self, painter, option, widget=None):
        outline = self.boundingRect()
        bar_colour = Qt.red if self.__highlight_mark else Qt.black

        painter.save()
        painter.fillRect(QRectF(0, 2, outline.width(), 4), bar_colour)
        painter.restore()

    # Custom helpers: attach / retrieve an arbitrary payload object.
    def attach_data(self, data_inst: object):
        self.data_bind = data_inst

    def get_data(self) -> object:
        return self.data_bind
|
|
|
|
|
|
class TransitionCurve(QGraphicsItem, GraphNode):
    """
    Connector from a node in one layer to a node in the next layer.

    Drawn (4px, black or red when highlighted) as a straight segment leaving the
    right-middle of the start node until ``prev_layer_w`` from the start node's
    left edge, then a cubic Bezier curve into the left-middle of the end node.
    The item spans horizontally from the start node's right edge to the end
    node's left edge and vertically over both nodes.
    """

    def __init__(self, start: GraphNode, end: GraphNode, prev_layrer_width: float, parent):
        # The misspelt parameter name is kept for keyword callers.
        QGraphicsItem.__init__(self, parent)
        self.__highlight_mark = False
        self.start_node = start
        self.end_node = end
        # Width of the start node's layer (the view passes the widest item of that layer).
        self.prev_layer_w = prev_layrer_width
        # Item-local bounding rect; stays empty until layout_refresh() is called.
        self.outline = QRectF()
        # Arbitrary payload attached by the owner.
        self.data_bind: object = None

    def node_type(self) -> GraphNodeType:
        return GraphNodeType.TransitionCurve

    def highlight(self, flag: bool):
        """Set the highlighted state; the caller is responsible for triggering a repaint."""
        self.__highlight_mark = flag

    def is_highlighted(self) -> bool:
        return self.__highlight_mark

    def layout_refresh(self):
        """Recompute position and bounding rect from the current geometry of both nodes."""
        start_pos = self.start_node.pos()
        end_pos = self.end_node.pos()
        start_rect = self.start_node.boundingRect()
        end_rect = self.end_node.boundingRect()

        left = start_pos.x() + start_rect.width()
        top = min(start_pos.y(), end_pos.y())
        bottom = max(start_pos.y() + start_rect.height(), end_pos.y() + end_rect.height())

        self.setPos(left, top)
        # boundingRect() is about to change: Qt requires prepareGeometryChange()
        # first so the scene index is refreshed.
        self.prepareGeometryChange()
        self.outline = QRectF(0, 0, end_pos.x() - left, bottom - top)
        self.update()

    def boundingRect(self):
        return self.outline

    def paint(self, painter, option, widget=None):
        span_width = self.outline.width()
        start_rect = self.start_node.boundingRect()
        end_rect = self.end_node.boundingRect()

        # Anchors in scene coordinates: right-middle of start, left-middle of end.
        anchor_start = self.start_node.pos() + QPointF(start_rect.width(), start_rect.height() / 2)
        anchor_end = self.end_node.pos() + QPointF(0, end_rect.height() / 2)
        # Straight run that carries the line out to the edge of the start layer.
        line_span = self.prev_layer_w - start_rect.width()

        # Convert to item-local coordinates.
        start_pos = anchor_start - self.pos()
        end_pos = anchor_end - self.pos()
        line_epos = start_pos + QPointF(line_span, 0)

        # Horizontal control handles one third of the curved span in from each
        # end give an S-shaped transition.
        handle = (span_width - line_span) / 3
        control_pos0 = line_epos + QPointF(handle, 0)
        control_pos1 = end_pos - QPointF(handle, 0)

        pen = QPen(Qt.red if self.__highlight_mark else Qt.black)
        pen.setWidthF(4)

        painter.save()
        painter.setRenderHint(QPainter.RenderHint.Antialiasing)
        painter.setPen(pen)
        painter.drawLine(start_pos, line_epos)

        curve = QPainterPath()
        curve.moveTo(line_epos)
        curve.cubicTo(control_pos0, control_pos1, end_pos)
        painter.drawPath(curve)
        painter.restore()

    # Custom helpers: attach / retrieve an arbitrary payload object.
    def attach_data(self, data_inst: object):
        self.data_bind = data_inst

    def get_data(self) -> object:
        return self.data_bind
|
|
|
|
|
|
class Direction(Enum):
    """Presumably the rank (layer) direction: left-to-right or top-to-bottom.

    NOTE(review): not referenced by the visible code of this file.
    """

    # Plain integers on purpose: writing `X = 0,` would make the value the tuple (0,).
    RankLR = 0
    RankTB = 1
|
|
|
|
|
|
class DAGActiveView(QGraphicsView):
    """
    Interactive view that lays out a DAG of story fragments layer by layer.

    Layers are placed left to right; inside a layer items are stacked top to
    bottom in ``sort_number`` order. Real layout nodes become ActivePresentNode
    items, fake layout nodes become PenetrateNode bars, and links between
    adjacent layers become TransitionCurve items.

    Every created item carries a ``node_key_bind`` tuple:
      ("node", name)                  story node
      ("plac", start, end, layer)     pass-through bar
      ("curv", start, end, layer)     connecting curve
    """

    # Emitted on a left click: (view x, view y, [node_key_bind of the first story
    # node under the cursor] or []).
    nodes_clicked = pyqtSignal(int, int, list)

    def __init__(self, parent):
        QGraphicsView.__init__(self, parent)
        self.setViewportUpdateMode(QGraphicsView.ViewportUpdateMode.FullViewportUpdate)

        font = QFont()
        font.setPixelSize(20)
        self.setFont(font)

        # Horizontal gap between layers / vertical gap between items of a layer (px).
        self.layer_span = 200
        self.node_span = 20

        self.scene_bind = QGraphicsScene(self)
        self.setScene(self.scene_bind)

        # Items highlighted by the last highlight_graph_link() call.
        self.__highlight_nodelist: List[GraphNode] = []
        # Every item created by update_with_edges(), keyed by a string form of node_key_bind.
        self.__total_graph_nodes: Dict[str, GraphNode] = {}

    def update_with_edges(self, arrows: List[Arrow]) -> None:
        """Lay out *arrows* and rebuild the scene, replacing any graph shown before."""
        tools = DAGGraph()
        tools.rebuild_from_edges(arrows)
        tools.graph_layout()

        # Drop the previous graph. Without this a second call stacks a new layout
        # on top of the old items and leaves stale entries in the lookup table.
        self.scene_bind.clear()
        self.__total_graph_nodes.clear()
        self.__highlight_nodelist.clear()

        # Group layout nodes by layer once instead of re-filtering the list per layer.
        nodes_by_layer: Dict[int, List[DAGOrderHelper]] = {}
        for order_node in tools.nodes_with_layout:
            nodes_by_layer.setdefault(order_node.layer_number, []).append(order_node)

        layer_x = 0
        previous_items: List[GraphNode] = []
        previous_width = 0
        # Iterate over the layers to render them.
        for layer_idx in range(tools.max_layer_count):
            layer_nodes = sorted(nodes_by_layer.get(layer_idx, []), key=lambda n: n.sort_number)

            current_items = self._create_layer_items(layer_nodes, layer_x)
            layer_width = self._unify_layer_width(current_items)
            layer_x += layer_width + self.layer_span

            if previous_items:
                self._connect_layers(previous_items, current_items, previous_width)

            previous_items = current_items
            previous_width = layer_width

    def _create_layer_items(self, layer_nodes: List[DAGOrderHelper], xpos: float) -> List[GraphNode]:
        """Create, position and register the items for one layer, stacked top-down at *xpos*."""
        items: List[GraphNode] = []
        ypos = 0
        for node in layer_nodes:
            if node.is_fake_node():
                item = PenetrateNode(20, None)
                fragm_start = node.relate_bind.bind_point().point_name
                # NOTE(review): curves read towards_to.bind_point().point_name, this reads
                # towards_to.bind_node.point_name -- presumably the same point; confirm.
                fragm_end = node.towards_to.bind_node.point_name
                node_key = f"plac${fragm_start}::{fragm_end}-{node.layer_number}"
                item.node_key_bind = "plac", fragm_start, fragm_end, node.layer_number
            else:
                point_name = node.layer_bind.bind_point().point_name
                # Layout nodes without inputs start a story.
                if node.layer_bind.input_count == 0:
                    story_type = StoryNodeType.StoryStartNode
                else:
                    story_type = StoryNodeType.FragmentNode
                item = ActivePresentNode(point_name, story_type, self.font(), None)
                node_key = f"node@{point_name}"
                item.node_key_bind = "node", point_name

            item.attach_data(node)
            item.setPos(xpos, ypos)
            ypos += item.boundingRect().height()
            ypos += self.node_span
            items.append(item)
            self.scene_bind.addItem(item)
            self.__total_graph_nodes[node_key] = item
        return items

    @staticmethod
    def _unify_layer_width(items: List[GraphNode]) -> float:
        """Stretch resizable items to the widest item of the layer; return that width (0 if empty)."""
        layer_width = max((item.boundingRect().width() for item in items), default=0)
        for item in items:
            if hasattr(item, "resize_width"):
                item.resize_width(layer_width)
        return layer_width

    def _connect_layers(self, previous_items: List[GraphNode], current_items: List[GraphNode],
                        previous_width: float) -> None:
        """Add a TransitionCurve from each previous-layer item to every current item it feeds."""
        for curr_gnode in current_items:
            sort_helper: DAGOrderHelper = curr_gnode.get_data()
            for prev_gnode in previous_items:
                if prev_gnode.get_data() not in sort_helper.get_upstream_nodes():
                    continue

                line_cmbn = TransitionCurve(prev_gnode, curr_gnode, previous_width, None)
                self.scene_bind.addItem(line_cmbn)
                line_cmbn.layout_refresh()

                # Name the curve after the story nodes it links, looking through
                # pass-through bars to the real start / end of the edge.
                relate_node_name = ""
                if prev_gnode.node_type() == GraphNodeType.ActivePresentNode:
                    relate_node_name = prev_gnode.get_data().layer_bind.bind_point().point_name
                elif prev_gnode.node_type() == GraphNodeType.PenetrateNode:
                    relate_node_name = prev_gnode.get_data().relate_bind.bind_point().point_name

                towards_node_name = ""
                if curr_gnode.node_type() == GraphNodeType.ActivePresentNode:
                    towards_node_name = sort_helper.layer_bind.bind_point().point_name
                elif curr_gnode.node_type() == GraphNodeType.PenetrateNode:
                    towards_node_name = sort_helper.towards_to.bind_point().point_name

                node_key = f"curv&{relate_node_name}::{towards_node_name}-{sort_helper.layer_number}"
                line_cmbn.node_key_bind = "curv", relate_node_name, towards_node_name, sort_helper.layer_number
                self.__total_graph_nodes[node_key] = line_cmbn

    def _mark_highlighted(self, item: GraphNode) -> None:
        """Highlight *item* and remember it so the next call can reset it."""
        item.highlight(True)
        self.__highlight_nodelist.append(item)

    def highlight_graph_link(self, highlight_path: List[str]):
        """Highlight the story nodes named in *highlight_path* and the links between consecutive names.

        The previous highlight is cleared first; an empty path only clears it.

        Raises:
            KeyError: a name in the path has no story node in the current graph.
        """
        for item in self.__highlight_nodelist:
            item.highlight(False)
        self.__highlight_nodelist.clear()

        if highlight_path:
            self._mark_highlighted(self.__total_graph_nodes[f"node@{highlight_path[0]}"])

        for start_name, end_name in zip(highlight_path, highlight_path[1:]):
            self._mark_highlighted(self.__total_graph_nodes[f"node@{end_name}"])
            for item in self.__total_graph_nodes.values():
                key = item.node_key_bind
                # Exact (kind, start, end) match; a string-prefix test such as
                # "plac$a::b" would also select the link a -> "bc".
                if key[0] in ("plac", "curv") and key[1] == start_name and key[2] == end_name:
                    self._mark_highlighted(item)

        self.scene_bind.update()
        self.update()

    def mousePressEvent(self, event: QMouseEvent):
        """Let Qt handle the press, then on a left click emit nodes_clicked.

        The emitted list holds the node_key_bind of the first story node among
        ``self.items()`` at the click position, or is empty.
        """
        QGraphicsView.mousePressEvent(self, event)

        if event.button() != Qt.MouseButton.LeftButton:
            return

        click_pos = event.pos()
        story_keys = []
        for item in self.items(click_pos):
            # Items not created by update_with_edges() have no key; skip them.
            key = getattr(item, "node_key_bind", None)
            if key is not None and key[0].startswith("node"):
                story_keys.append(key)

        self.nodes_clicked.emit(click_pos.x(), click_pos.y(), story_keys[0:1])
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo: show a small story graph (a -> b, c, d; c -> d, e, f).
    app = QApplication(sys.argv)
    view = DAGActiveView(None)
    view.show()

    edge_pairs = [
        ("a", "b"),
        ("a", "c"),
        ("c", "d"),
        ("a", "d"),
        ("c", "e"),
        ("c", "f"),
    ]
    view.update_with_edges([Arrow(Point(src), Point(dst)) for src, dst in edge_pairs])

    app.exec()
|