# StoryCheckTools/graph/directed_acyclic_graph/DAGPresent.py
#
# (source listing: 454 lines, 15 KiB, Python)

from graph.directed_acyclic_graph.DAGLayout import DAGGraph, DAGOrderHelper
from graph.DataType import Arrow, Point
from typing import List, Dict
from PyQt5.QtWidgets import QWidget, QApplication, QGraphicsView, QGraphicsScene, QGraphicsItem
from PyQt5.QtCore import QRectF, Qt, QPointF, pyqtSignal
from PyQt5.QtGui import QFont, QBrush, QPen, QPainter, QPainterPath, QMouseEvent
import sys
from enum import Enum
from abc import abstractmethod
class GraphNodeType(Enum):
    """Kind of graphics item drawn in the DAG scene (returned by ``GraphNode.node_type``)."""
    # Plain ints: the original trailing commas turned every value into a 1-tuple.
    ActivePresentNode = 0
    PenetrateNode = 1
    TransitionCurve = 2
class GraphNode:
    """
    Informal interface shared by every item placed in the DAG scene.

    This class is mixed into ``QGraphicsItem`` subclasses and does not use
    ``abc.ABCMeta``, so ``@abstractmethod`` is not enforced at instantiation
    time; a method that a subclass fails to provide raises
    ``NotImplementedError`` (carrying the method name) when it is called.
    """

    @abstractmethod
    def node_type(self) -> GraphNodeType:
        """Return which kind of graph item this is."""
        raise NotImplementedError("node_type")

    @abstractmethod
    def highlight(self, flag: bool):
        """Switch the item's highlight state on (``True``) or off (``False``)."""
        raise NotImplementedError("highlight")

    @abstractmethod
    def is_highlighted(self) -> bool:
        """Report the current highlight state."""
        raise NotImplementedError("is_highlighted")

    @abstractmethod
    def boundingRect(self) -> QRectF:
        """Return the item's bounding rectangle in local coordinates."""
        raise NotImplementedError("boundingRect")

    @abstractmethod
    def pos(self) -> QPointF:
        """Return the item's position in its parent/scene coordinates."""
        raise NotImplementedError("pos")
class StoryNodeType(Enum):
    """
    Story node type.

    ``StoryStartNode`` marks a fragment without incoming edges (drawn green);
    ``FragmentNode`` marks every other fragment (drawn gray).
    """
    # Plain ints: the original trailing commas turned every value into a 1-tuple.
    StoryStartNode = 0
    FragmentNode = 1
class ActivePresentNode(QGraphicsItem, GraphNode):
    """
    Box item showing the name of one story fragment.

    Fragment nodes are filled gray, any other node type green. A highlighted
    node additionally gets a red outline, and its text is drawn in red.
    """

    def __init__(self, name: str, type: StoryNodeType, font: QFont, parent):
        QGraphicsItem.__init__(self, parent)
        self.__highlight_mark = False
        self.node_type_within_story = type
        self.fragment_name = name
        self.font_bind = font
        # arbitrary payload attached by the view (a DAGOrderHelper there)
        self.data_bind: object = None

    def node_type(self) -> GraphNodeType:
        return GraphNodeType.ActivePresentNode

    def highlight(self, flag: bool):
        """Set the highlight state and schedule a repaint of this item."""
        self.__highlight_mark = flag
        self.update()

    def is_highlighted(self) -> bool:
        return self.__highlight_mark

    def boundingRect(self) -> QRectF:
        """
        Local rectangle: one font pixel-size per character of the name plus a
        5 px margin on each side.

        NOTE: ``QFont.pixelSize()`` returns -1 when the font size was set in
        points, so the font passed in is expected to have a pixel size.
        """
        size = self.font_bind.pixelSize()
        return QRectF(0, 0, size * len(self.fragment_name) + 10, size + 10)

    def paint(self, painter, option, widget=None):
        outline = self.boundingRect()
        # text area = outline shrunk by the 5 px margin on every side
        text_rect = outline.adjusted(5, 5, -5, -5)
        fill = Qt.gray if self.node_type_within_story == StoryNodeType.FragmentNode else Qt.green

        painter.save()
        painter.fillRect(outline, fill)
        if self.__highlight_mark:
            # the red pen is also used for the text below
            painter.setPen(Qt.red)
            painter.drawRect(outline)
        painter.setFont(self.font_bind)
        painter.drawText(text_rect, self.fragment_name)
        painter.restore()

    # custom helper functions
    def attach_data(self, data_inst: object):
        """Store an arbitrary payload object on this item."""
        self.data_bind = data_inst

    def get_data(self) -> object:
        """Return the payload stored by ``attach_data`` (``None`` if unset)."""
        return self.data_bind
class PenetrateNode(QGraphicsItem, GraphNode):
    """
    Pass-through connector node.

    A thin horizontal bar (4 px tall inside an 8 px high item) drawn for a
    fake layout node, i.e. where an edge passes through a layer without a real
    story node. Its width is stretched to the layer column via ``resize_width``.
    """

    def __init__(self, width: float, parent):
        QGraphicsItem.__init__(self, parent)
        self.width_store = width
        self.__highlight_mark = False
        # arbitrary payload attached by the view (a DAGOrderHelper there)
        self.data_bind: object = None

    def node_type(self) -> GraphNodeType:
        return GraphNodeType.PenetrateNode

    def highlight(self, flag: bool):
        """Set the highlight state and schedule a repaint of this item."""
        self.__highlight_mark = flag
        self.update()

    def is_highlighted(self) -> bool:
        return self.__highlight_mark

    def resize_width(self, width: float):
        """Change the bar width, notifying the scene before the geometry changes."""
        # Required by Qt before boundingRect() changes, so the scene index stays valid.
        self.prepareGeometryChange()
        self.width_store = width
        self.update()

    def boundingRect(self):
        return QRectF(0, 0, self.width_store, 8)

    def paint(self, painter, option, widget=None):
        color = Qt.red if self.__highlight_mark else Qt.black
        painter.save()
        # 4 px bar vertically centered in the 8 px item
        painter.fillRect(QRectF(0, 2, self.boundingRect().width(), 4), color)
        painter.restore()

    # custom helper functions
    def attach_data(self, data_inst: object):
        """Store an arbitrary payload object on this item."""
        self.data_bind = data_inst

    def get_data(self) -> object:
        """Return the payload stored by ``attach_data`` (``None`` if unset)."""
        return self.data_bind
class TransitionCurve(QGraphicsItem, GraphNode):
    """
    Edge drawn from an item in one layer to an item in the next layer.

    The stroke starts at the right-middle of ``start``. It runs straight to the
    right edge of the start layer's column (``prev_layrer_width`` is that
    column's width) and then follows a cubic Bezier to the left-middle of
    ``end``. Call ``layout_refresh`` after both end items have been positioned.
    """

    # stroke width; boundingRect() is padded by half of it
    PEN_WIDTH = 4.0

    def __init__(self, start: GraphNode, end: GraphNode, prev_layrer_width: float, parent):
        QGraphicsItem.__init__(self, parent)
        self.__highlight_mark = False
        self.start_node = start
        self.end_node = end
        self.prev_layer_w = prev_layrer_width
        # drawing area in local coordinates, set by layout_refresh()
        self.outline = QRectF()
        # arbitrary payload (not used by the view in this file)
        self.data_bind: object = None

    def node_type(self) -> GraphNodeType:
        return GraphNodeType.TransitionCurve

    def highlight(self, flag: bool):
        """Set the highlight state and schedule a repaint of this item."""
        self.__highlight_mark = flag
        self.update()

    def is_highlighted(self) -> bool:
        return self.__highlight_mark

    def layout_refresh(self):
        """
        Recompute position and extent from the current end-item positions.

        The item spans horizontally from the start item's right edge to the end
        item's left edge, and vertically over both items.
        """
        start_pos = self.start_node.pos()
        end_pos = self.end_node.pos()
        start_rect = self.start_node.boundingRect()
        end_rect = self.end_node.boundingRect()
        left = start_pos.x() + start_rect.width()
        top = min(start_pos.y(), end_pos.y())
        bottom = max(start_pos.y() + start_rect.height(), end_pos.y() + end_rect.height())
        # Required by Qt before boundingRect() changes, so the scene index stays valid.
        self.prepareGeometryChange()
        self.setPos(left, top)
        self.outline = QRectF(0, 0, end_pos.x() - left, bottom - top)
        self.update()

    def boundingRect(self):
        # The stroke is centred on the path, so half the pen width lies outside
        # self.outline at the start (x=0) and end (x=width) points.
        half = self.PEN_WIDTH / 2
        return self.outline.adjusted(-half, -half, half, half)

    def paint(self, painter, option, widget=None):
        outline = self.outline
        start_rect = self.start_node.boundingRect()
        end_rect = self.end_node.boundingRect()
        # end points converted to this item's local coordinates
        start_pos = self.start_node.pos() + QPointF(start_rect.width(), start_rect.height() / 2) - self.pos()
        end_pos = self.end_node.pos() + QPointF(0, end_rect.height() / 2) - self.pos()
        # straight run up to the right edge of the start layer's column
        line_span = self.prev_layer_w - start_rect.width()
        line_epos = start_pos + QPointF(line_span, 0)
        # control points: horizontal offsets of one third of the remaining span
        ctrl_dx = (outline.width() - line_span) / 3
        control_pos0 = line_epos + QPointF(ctrl_dx, 0)
        control_pos1 = end_pos - QPointF(ctrl_dx, 0)

        pen = QPen(Qt.red if self.__highlight_mark else Qt.black)
        pen.setWidthF(self.PEN_WIDTH)

        curve = QPainterPath()
        curve.moveTo(line_epos)
        curve.cubicTo(control_pos0, control_pos1, end_pos)

        painter.save()
        painter.setRenderHint(QPainter.RenderHint.Antialiasing)
        painter.setPen(pen)
        painter.drawLine(start_pos, line_epos)
        painter.drawPath(curve)
        painter.restore()

    # custom helper functions
    def attach_data(self, data_inst: object):
        """Store an arbitrary payload object on this item."""
        self.data_bind = data_inst

    def get_data(self) -> object:
        """Return the payload stored by ``attach_data`` (``None`` if unset)."""
        return self.data_bind
class Direction(Enum):
    """
    Rank direction, presumably left-to-right / top-to-bottom (Graphviz-style
    naming). NOTE(review): DAGActiveView lays layers out left-to-right
    regardless of this value.
    """
    # Plain ints: the original trailing commas turned every value into a 1-tuple.
    RankLR = 0
    RankTB = 1
class DAGActiveView(QGraphicsView):
    """
    Graphics view that renders a layered DAG of story fragments.

    Each layer from ``DAGGraph.graph_layout`` becomes a column, and columns are
    placed left to right. Real layout nodes are drawn as ``ActivePresentNode``
    boxes and fake (pass-through) nodes as ``PenetrateNode`` bars. Links between
    adjacent columns are drawn as ``TransitionCurve`` items.

    Every item gets a ``node_key_bind`` tuple and is registered under a string
    key:

    * ``"node@<name>"`` -> ``("node", name)``
    * ``"plac$<start>::<end>-<layer>"`` -> ``("plac", start, end, layer)``
    * ``"curv&<start>::<end>-<layer>"`` -> ``("curv", start, end, layer)``
    """

    # Emitted on every left click: (view x, view y, hits). ``hits`` holds at
    # most one ("node", name) tuple: the first story node under the cursor.
    nodes_clicked = pyqtSignal(int, int, list)

    def __init__(self, parent):
        QGraphicsView.__init__(self, parent)
        self.setViewportUpdateMode(QGraphicsView.ViewportUpdateMode.FullViewportUpdate)
        font = QFont()
        font.setPixelSize(20)
        self.setFont(font)
        # horizontal gap between layer columns
        self.layer_span = 200
        # vertical gap between items of the same column
        self.node_span = 20
        self.scene_bind = QGraphicsScene(self)
        self.setScene(self.scene_bind)
        self.__highlight_nodelist: List[GraphNode] = []
        self.__total_graph_nodes: Dict[str, GraphNode] = {}

    def update_with_edges(self, arrows: List[Arrow]) -> None:
        """
        Lay out the DAG described by ``arrows`` and rebuild the scene from it.

        Any previously rendered graph and highlight state is discarded first,
        so this method can be called repeatedly.
        """
        # Without this, a second call would stack a new graph on top of the old
        # one and leave stale entries in the key table.
        self.__highlight_nodelist.clear()
        self.__total_graph_nodes.clear()
        self.scene_bind.clear()

        tools = DAGGraph()
        tools.rebuild_from_edges(arrows)
        tools.graph_layout()
        total_nodes = tools.nodes_with_layout

        layer_xpos = 0
        previous_items: list = []
        # iterate over the layers to render
        for layer_idx in range(tools.max_layer_count):
            layer_nodes: List[DAGOrderHelper] = sorted(
                (n for n in total_nodes if n.layer_number == layer_idx),
                key=lambda n: n.sort_number)
            current_items = self._build_layer(layer_nodes, layer_xpos)
            layer_width = self._align_layer_width(current_items)
            layer_xpos += layer_width + self.layer_span
            if previous_items:
                self._connect_layers(previous_items, current_items)
            previous_items = current_items

    def _build_layer(self, layer_nodes: List[DAGOrderHelper], xpos: float) -> list:
        """Create, position, add and register the items of one layer column."""
        ypos = 0
        items = []
        for node in layer_nodes:
            if node.is_fake_node():
                gnode = PenetrateNode(20, None)
                gnode.setPos(xpos, ypos)
                gnode.attach_data(node)
                fragm_start = node.relate_bind.bind_point().point_name
                # NOTE(review): this reads `towards_to.bind_node` while
                # _curve_target_name reads `towards_to.bind_point()`; both are kept
                # as in the original. Confirm they name the same point.
                fragm_end = node.towards_to.bind_node.point_name
                node_key = f"plac${fragm_start}::{fragm_end}-{node.layer_number}"
                gnode.node_key_bind = "plac", fragm_start, fragm_end, node.layer_number
            else:
                name = node.layer_bind.bind_point().point_name
                # a node without inputs starts a story line
                story_type = (StoryNodeType.StoryStartNode if node.layer_bind.input_count == 0
                              else StoryNodeType.FragmentNode)
                gnode = ActivePresentNode(name, story_type, self.font(), None)
                gnode.attach_data(node)
                gnode.setPos(xpos, ypos)
                node_key = f"node@{name}"
                gnode.node_key_bind = "node", name
            ypos += gnode.boundingRect().height() + self.node_span
            items.append(gnode)
            self.scene_bind.addItem(gnode)
            self.__total_graph_nodes[node_key] = gnode
        return items

    @staticmethod
    def _align_layer_width(items: list) -> float:
        """Stretch pass-through bars to the widest item of the column; return that width."""
        width = max((n.boundingRect().width() for n in items), default=0)
        for n in items:
            if isinstance(n, PenetrateNode):
                n.resize_width(width)
        return width

    def _connect_layers(self, previous_items: list, current_items: list) -> None:
        """Add a TransitionCurve for each upstream relation between two adjacent columns."""
        prev_layer_width = max((n.boundingRect().width() for n in previous_items), default=0)
        for curr_gnode in current_items:
            sort_helper: DAGOrderHelper = curr_gnode.get_data()
            for prev_gnode in previous_items:
                # Not hoisted: get_upstream_nodes()'s return type is not visible
                # here (it could be a one-shot iterator).
                if prev_gnode.get_data() not in sort_helper.get_upstream_nodes():
                    continue
                curve = TransitionCurve(prev_gnode, curr_gnode, prev_layer_width, None)
                self.scene_bind.addItem(curve)
                curve.layout_refresh()
                relate_node_name = self._curve_source_name(prev_gnode)
                towards_node_name = self._curve_target_name(curr_gnode)
                node_key = f"curv&{relate_node_name}::{towards_node_name}-{sort_helper.layer_number}"
                curve.node_key_bind = "curv", relate_node_name, towards_node_name, sort_helper.layer_number
                self.__total_graph_nodes[node_key] = curve

    @staticmethod
    def _curve_source_name(gnode) -> str:
        """Story point name a curve leaving ``gnode`` is keyed by ("" for other item kinds)."""
        if gnode.node_type() == GraphNodeType.ActivePresentNode:
            return gnode.get_data().layer_bind.bind_point().point_name
        if gnode.node_type() == GraphNodeType.PenetrateNode:
            return gnode.get_data().relate_bind.bind_point().point_name
        return ""

    @staticmethod
    def _curve_target_name(gnode) -> str:
        """Story point name a curve entering ``gnode`` is keyed by ("" for other item kinds)."""
        helper = gnode.get_data()
        if gnode.node_type() == GraphNodeType.ActivePresentNode:
            return helper.layer_bind.bind_point().point_name
        if gnode.node_type() == GraphNodeType.PenetrateNode:
            return helper.towards_to.bind_point().point_name
        return ""

    def _mark_highlighted(self, item) -> None:
        """Highlight ``item`` and remember it so the next call can reset it."""
        item.highlight(True)
        self.__highlight_nodelist.append(item)

    def highlight_graph_link(self, highlight_path: List[str]):
        """
        Highlight the story nodes named in ``highlight_path`` plus the bars and
        curves linking each consecutive pair, after clearing the previous highlight.

        An empty path only clears the highlight. Raises KeyError when a name in
        the path has no rendered node.
        """
        for item in self.__highlight_nodelist:
            item.highlight(False)
        self.__highlight_nodelist.clear()

        if highlight_path:
            self._mark_highlighted(self.__total_graph_nodes[f"node@{highlight_path[0]}"])
            for start_name, end_name in zip(highlight_path, highlight_path[1:]):
                self._mark_highlighted(self.__total_graph_nodes[f"node@{end_name}"])
                link_keys = (f"plac${start_name}::{end_name}", f"curv&{start_name}::{end_name}")
                for key, item in self.__total_graph_nodes.items():
                    # Keys end in "-<layer>". Strip it and compare exactly: a plain
                    # prefix test would let a->d also match a->dx.
                    if key.rpartition("-")[0] in link_keys:
                        self._mark_highlighted(item)

        self.scene_bind.update()
        self.update()

    def mousePressEvent(self, event: QMouseEvent):
        """Let Qt handle the press, then on a left click emit ``nodes_clicked``."""
        QGraphicsView.mousePressEvent(self, event)
        if event.button() != Qt.MouseButton.LeftButton:
            return
        noderef_names = []
        for gnode in self.items(event.pos()):
            key_bind = getattr(gnode, "node_key_bind", None)
            if key_bind and key_bind[0].startswith("node"):
                noderef_names.append(key_bind)
                break  # only the first hit is reported
        self.nodes_clicked.emit(event.pos().x(), event.pos().y(), noderef_names)
if __name__ == "__main__":
    # Small demo graph; a->d skips a layer and is drawn through a PenetrateNode.
    app = QApplication(sys.argv)
    view = DAGActiveView(None)
    view.show()
    arrows = [
        Arrow(Point('a'), Point('b')),
        Arrow(Point('a'), Point('c')),
        Arrow(Point('c'), Point('d')),
        Arrow(Point('a'), Point('d')),
        Arrow(Point('c'), Point('e')),
        Arrow(Point('c'), Point('f')),
    ]
    view.update_with_edges(arrows)
    # propagate the event loop's exit code to the process
    sys.exit(app.exec())