# StoryCheckTools/graph/undirected_graph/UDGPresent.py
#
# 285 lines
# 8.9 KiB
# Python

import sys
from abc import abstractmethod
from enum import Enum
from typing import List, Dict, Tuple
import networkx as nx
from PyQt5.QtCore import QPointF, QRectF, Qt, pyqtSignal
from PyQt5.QtGui import QFont, QPainterPath, QPen, QPainter, QMouseEvent
from PyQt5.QtWidgets import QGraphicsItem, QGraphicsView, QApplication, QGraphicsScene
from graph.DataType import Point, Line
class PresentNodeType(Enum):
    """Kind of item that UDGPresent places in its scene."""

    # No trailing commas: `X = 0,` would make the member value the tuple (0,).
    PresentNode = 0
    ConnectionNode = 1
class GraphNode:
    """Interface shared by the node and edge items shown by UDGPresent.

    The ``@abstractmethod`` markers are documentary only. This class does not
    use ``ABCMeta``, so instantiation is not blocked, and every default below
    just raises ``NotImplementedError``. (Presumably ABCMeta is avoided because
    subclasses also derive from QGraphicsItem, whose metaclass would conflict;
    confirm before changing.)
    """

    @abstractmethod
    def node_type(self) -> PresentNodeType:
        """Return which kind of scene item this is."""
        raise NotImplementedError("node_type")

    @abstractmethod
    def highlight(self, flag: bool):
        """Turn the highlighted rendering on (True) or off (False)."""
        raise NotImplementedError("highlight")

    @abstractmethod
    def is_highlighted(self) -> bool:
        """Report whether the item is currently highlighted."""
        raise NotImplementedError("is_highlighted")

    @abstractmethod
    def boundingRect(self) -> QRectF:
        """Return the item's extent in its own coordinates."""
        raise NotImplementedError("boundingRect")

    @abstractmethod
    def pos(self) -> QPointF:
        """Return the item's position in parent/scene coordinates."""
        raise NotImplementedError("pos")
class PresentNode(QGraphicsItem, GraphNode):
    """Scene item that draws one named vertex as a framed label.

    The frame size is derived from the bound font's pixel size and the length
    of the name. A small arrow-shaped icon occupies the left square of the
    frame, and the name is drawn to its right. Neighbouring vertices are kept
    in an insertion-ordered, duplicate-free sibling list.
    """

    def __init__(self, name: str, font: QFont, parent):
        """Create the item.

        :param name: label text; UDGPresent also uses it as the vertex key.
        :param font: label font; its pixelSize() drives the frame geometry.
        :param parent: parent QGraphicsItem, or None for a top-level item.
        """
        QGraphicsItem.__init__(self, parent)
        self.node_name = name
        self.__is_highlight_mark = False
        self.__font_bind = font
        self.__sibling_list: List['PresentNode'] = []
        # Draw above edge items (ConnectionNode uses z = 1).
        self.setZValue(10)

    def sibling_append(self, node: 'PresentNode'):
        """Record ``node`` as a neighbour unless it is already recorded."""
        if node not in self.__sibling_list:
            self.__sibling_list.append(node)

    def sibling_nodes(self) -> List['PresentNode']:
        """Return the neighbour list (the internal list itself, not a copy)."""
        return self.__sibling_list

    def node_type(self) -> PresentNodeType:
        return PresentNodeType.PresentNode

    def highlight(self, flag: bool):
        """Set the highlight state and schedule a repaint.

        Without update() the new colour would only show up on some later,
        unrelated repaint.
        """
        self.__is_highlight_mark = flag
        self.update()

    def is_highlighted(self) -> bool:
        return self.__is_highlight_mark

    def boundingRect(self) -> QRectF:
        """Return the frame: one character cell per name char plus one for the
        icon, with 10 units of padding on each axis.

        NOTE(review): QFont.pixelSize() reports -1 for fonts sized in points;
        this geometry assumes a pixel-sized font, as UDGPresent configures.
        """
        cell = self.__font_bind.pixelSize()
        width_x = cell * (len(self.node_name) + 1)
        height_x = cell
        return QRectF(0, 0, width_x + 10, height_x + 10)

    def paint(self, painter, option, widget=None):
        """Draw the frame, the icon and the label; red when highlighted."""
        outline = self.boundingRect()
        side = outline.height()

        # Arrow-like icon inside the left square of the frame. The path has no
        # explicit moveTo, so it begins at the path's default origin (0, 0).
        icon = QPainterPath()
        icon.lineTo(side / 2 - 5, side - 5)
        icon.lineTo(side / 2, side / 2)
        icon.lineTo(side - 5, side / 2 - 5)
        icon.lineTo(0, 0)

        color = Qt.red if self.__is_highlight_mark else Qt.black

        painter.save()
        painter.setRenderHint(QPainter.Antialiasing)
        # The frame is drawn before the pen is changed, so it keeps the
        # painter's incoming pen regardless of the highlight state.
        painter.drawRect(outline)
        painter.setPen(color)
        painter.fillPath(icon, color)
        # Shift the label past the icon square and down by 5 units.
        painter.translate(side, 5)
        painter.drawText(outline, self.node_name)
        painter.restore()
class ConnectionNode(QGraphicsItem, GraphNode):
    """Scene item that draws a straight edge between two graph nodes.

    The line joins the ``pos()`` points of the two endpoints. Call
    :meth:`relayout_exec` after the endpoints move, so that this item's
    position and extent follow them.
    """

    # Stroke width of the edge line, in item units.
    _PEN_WIDTH = 3.0

    def __init__(self, p0: GraphNode, p1: GraphNode, parent):
        """Create an edge between ``p0`` and ``p1``.

        :param parent: parent QGraphicsItem, or None for a top-level item.
        """
        QGraphicsItem.__init__(self, parent)
        self.__highlight_mark = False
        self.__point0 = p0
        self.__point1 = p1
        # Axis-aligned box spanned by the two endpoints, in item coordinates.
        self.__outline = QRectF()
        # Draw below node items (PresentNode uses z = 10).
        self.setZValue(1)

    def node_type(self) -> PresentNodeType:
        return PresentNodeType.ConnectionNode

    def highlight(self, flag: bool):
        """Set the highlight state and schedule a repaint."""
        self.__highlight_mark = flag
        self.update()

    def is_highlighted(self) -> bool:
        return self.__highlight_mark

    def relayout_exec(self):
        """Move and resize this item to the box spanned by both endpoints."""
        start_pos = self.__point0.pos()
        end_pos = self.__point1.pos()
        left = min(start_pos.x(), end_pos.x())
        top = min(start_pos.y(), end_pos.y())
        right = max(start_pos.x(), end_pos.x())
        bottom = max(start_pos.y(), end_pos.y())
        # Qt requires this before the bounding rect changes; otherwise the
        # scene index keeps the stale geometry.
        self.prepareGeometryChange()
        self.setPos(QPointF(left, top))
        self.__outline = QRectF(0, 0, right - left, bottom - top)

    def boundingRect(self) -> QRectF:
        """Return the endpoint box grown by half the pen width.

        The stroke is centred on the line, so half of it lies outside the
        endpoint box. A horizontal or vertical edge would otherwise have an
        empty rect, which the scene never repaints.
        """
        margin = self._PEN_WIDTH / 2
        return self.__outline.adjusted(-margin, -margin, margin, margin)

    def paint(self, painter, option, widget=None):
        """Draw the edge line; red when highlighted, light grey otherwise."""
        start_pos = self.__point0.pos()
        end_pos = self.__point1.pos()
        box = self.__outline

        # Use the corners of the box on each endpoint's side, based on where
        # the endpoints currently are relative to each other.
        if start_pos.x() < end_pos.x():
            x0, x1 = box.left(), box.right()
        else:
            x0, x1 = box.right(), box.left()
        if start_pos.y() < end_pos.y():
            y0, y1 = box.top(), box.bottom()
        else:
            y0, y1 = box.bottom(), box.top()

        pen = QPen(Qt.red if self.__highlight_mark else Qt.lightGray)
        pen.setWidthF(self._PEN_WIDTH)

        painter.save()
        painter.setRenderHint(QPainter.Antialiasing)
        painter.setPen(pen)
        painter.drawLine(QPointF(x0, y0), QPointF(x1, y1))
        painter.restore()
class UDGPresent(QGraphicsView):
    """Graphics view that lays out and displays an undirected graph.

    Emits ``node_clicked`` with a node's name when a node item is
    left-clicked.
    """

    node_clicked = pyqtSignal(str)

    def __init__(self, parent):
        QGraphicsView.__init__(self, parent)
        self.__highlight_nodes: List[GraphNode] = []
        # Vertex name -> its PresentNode item.
        self.node_set: Dict[str, GraphNode] = {}
        self.__layout_graph = nx.Graph()
        self.__scene_bind = QGraphicsScene(self)
        self.setScene(self.__scene_bind)
        font = QFont()
        # Node geometry is derived from QFont.pixelSize(), so the size is set
        # in pixels rather than points.
        font.setPixelSize(20)
        self.setFont(font)

    def rebuild_from_edges(self, line_set: List[Line]):
        """Replace the displayed graph with the one described by ``line_set``.

        Each line contributes its two end points as vertices and an edge
        between them. Lines whose two end points share a name add only the
        vertex. Vertices are placed with networkx's spring layout, scaled so
        that adjacent node frames do not overlap.
        """
        # Drop everything from a previous call. Otherwise stale vertices stay
        # in the layout graph with no node_set entry (KeyError below), and
        # old items pile up in the scene.
        self.__scene_bind.clear()
        self.node_set.clear()
        self.__layout_graph.clear()
        self.__highlight_nodes.clear()

        self._collect_nodes(line_set)
        pos_map = nx.spring_layout(self.__layout_graph)
        scale = self._layout_scale(pos_map)
        self._populate_scene(pos_map, scale)

    def _node_for(self, name: str) -> PresentNode:
        """Return the item for ``name``, creating it and its vertex if new."""
        node = self.node_set.get(name)
        if node is None:
            node = PresentNode(name, self.font(), None)
            self.node_set[name] = node
        self.__layout_graph.add_node(name)
        return node

    def _collect_nodes(self, line_set: List[Line]):
        """Build items, layout vertices/edges and sibling links from lines."""
        for line in line_set:
            points = line.points()
            start_point, end_point = points[0], points[1]
            start_item = self._node_for(start_point.point_name)
            if start_point.point_name == end_point.point_name:
                # Self-loops are not drawn; only the vertex is kept.
                continue
            end_item = self._node_for(end_point.point_name)
            self.__layout_graph.add_edge(start_point.point_name, end_point.point_name)
            start_item.sibling_append(end_item)
            end_item.sibling_append(start_item)

    def _layout_scale(self, pos_map: Dict) -> float:
        """Return the factor mapping layout coordinates to scene units.

        Items are anchored at their top-left corner. Two adjacent frames
        therefore do not overlap when the left frame's width fits in the
        horizontal gap, or the upper frame's height fits in the vertical gap.
        For each edge, take the smallest factor that satisfies either axis;
        the result is the maximum of those over all edges. Axes with no
        separation are skipped instead of dividing by zero. Returns 0.0 when
        the graph has no edges.
        """
        scale = 0.0
        for name_a, name_b in self.__layout_graph.edges():
            ax, ay = pos_map[name_a][0], pos_map[name_a][1]
            bx, by = pos_map[name_b][0], pos_map[name_b][1]
            rect_a = self.node_set[name_a].boundingRect()
            rect_b = self.node_set[name_b].boundingRect()
            candidates = []
            if ax != bx:
                left_width = rect_a.width() if ax < bx else rect_b.width()
                candidates.append(left_width / abs(ax - bx))
            if ay != by:
                top_height = rect_a.height() if ay < by else rect_b.height()
                candidates.append(top_height / abs(ay - by))
            if candidates:
                scale = max(scale, min(candidates))
        return scale

    def _populate_scene(self, pos_map: Dict, scale: float):
        """Place every node item, then add one edge item per graph edge."""
        for name in pos_map:
            layout_pos = pos_map[name]
            node = self.node_set[name]
            node.setPos(layout_pos[0] * scale, layout_pos[1] * scale)
            self.__scene_bind.addItem(node)
        for name_a, name_b in self.__layout_graph.edges():
            connection = ConnectionNode(self.node_set[name_a], self.node_set[name_b], None)
            self.__scene_bind.addItem(connection)
            # Must run after the endpoints have their final positions.
            connection.relayout_exec()

    def mousePressEvent(self, event: QMouseEvent):
        """Forward the press to Qt, then report a left-clicked node by name."""
        QGraphicsView.mousePressEvent(self, event)
        if event.button() != Qt.MouseButton.LeftButton:
            return
        item = self.itemAt(event.pos())
        # itemAt() yields None over empty space, so test the item type
        # instead of calling node_type() on whatever came back.
        if isinstance(item, PresentNode):
            self.node_clicked.emit(item.node_name)
            print(item.node_name)
if __name__ == "__main__":
    # Manual demo: a star graph centred on a vertex with a mixed-script name.
    app = QApplication(sys.argv)
    view = UDGPresent(None)
    view.show()
    list_in = [
        Line(Point('a中古'), Point('b')),
        Line(Point('a中古'), Point('c')),
        Line(Point('a中古'), Point('d')),
    ]
    view.rebuild_from_edges(list_in)
    # Propagate the event loop's exit status to the calling process.
    sys.exit(app.exec())