# StoryCheckTools/graph/directed_acyclic_graph/DAGLayout.py
#
# Source-viewer metadata: 341 lines, 10 KiB, Python (Raw / Blame / History views).
#
# Viewer warning: this file contains ambiguous Unicode characters, i.e. characters
# that might be confused with other characters. If this is intentional, the warning
# can safely be ignored; the viewer's Escape button reveals them.

from math import floor,ceil
from graph.DataType import Point, Arrow
from typing import List, Dict, Tuple
class DAGLayerHelper:
    """Mutable graph node used while computing the layered layout.

    Wraps one ``Point`` and records its outgoing edges (``next_points``),
    its current in-degree (``input_count``) and its assigned layer
    (``layer_v``).
    """

    def __init__(self, bind: 'Point'):
        self.bind_node = bind
        # Number of incoming edges; next_append() on a parent increments it.
        self.input_count: int = 0
        # Direct successors, in the order the edges were added.
        self.next_points: List[DAGLayerHelper] = []
        # Layer index assigned by the layout pass (starts at 0).
        self.layer_v: int = 0

    def bind_point(self) -> 'Point':
        """Return the wrapped point."""
        return self.bind_node

    def next_append(self, inst: 'DAGLayerHelper') -> None:
        """Add the edge ``self -> inst`` and increment ``inst``'s in-degree."""
        self.next_points.append(inst)
        inst.input_count += 1

    def next_nodes(self) -> List['DAGLayerHelper']:
        """Return the live list of direct successors (not a copy)."""
        return self.next_points

    def make_copy(self, memo: 'Dict[int, DAGLayerHelper] | None' = None) -> 'DAGLayerHelper':
        """Deep-copy this node and everything reachable from it.

        Each reachable node is copied exactly once, so a node shared by
        several parents (a diamond) stays shared in the copy instead of being
        duplicated once per path. ``input_count`` and ``layer_v`` are copied
        verbatim from the originals. The bound point is copied through its own
        ``make_copy()``.

        :param memo: map from ``id(original)`` to its copy, shared across the
            recursion. Callers normally omit it.
        :return: the copy of ``self``.
        """
        if memo is None:
            memo = {}
        existing = memo.get(id(self))
        if existing is not None:
            return existing
        ins = DAGLayerHelper(self.bind_node.make_copy())
        # Register before recursing so that every later path to this node
        # reuses the same copy.
        memo[id(self)] = ins
        ins.input_count = self.input_count
        ins.layer_v = self.layer_v
        ins.next_points = [n.make_copy(memo) for n in self.next_points]
        return ins
class DAGOrderHelper:
    """One slot of the final layered drawing.

    A slot is either a real node (``layer_bind`` is the node's helper) or a
    fake routing node (``layer_bind`` is None). DAGGraph inserts fake nodes
    on the intermediate layers of an edge, with ``relate_bind`` / ``towards_to``
    set to that edge's source and target helpers.
    """

    def __init__(self, relate: DAGLayerHelper | None, towards: DAGLayerHelper | None, bind: DAGLayerHelper | None):
        self.layer_bind = bind
        self.relate_bind = relate
        self.towards_to = towards
        # Real slots inherit the node's layer; fake slots start at 0 and have
        # their layer assigned by whoever creates them.
        self.layer_number = 0 if bind is None else bind.layer_v
        # Position inside the layer, filled in by the layout ordering pass.
        self.sort_number: float = 0
        # Slots on the previous layer that link into this one.
        self.__prev_layer_nodes: List['DAGOrderHelper'] = []

    def is_fake_node(self) -> bool:
        """True when this slot does not wrap a real graph node."""
        return self.layer_bind is None

    def get_upstream_nodes(self):
        """Return the live list of upstream slots (not a copy)."""
        return self.__prev_layer_nodes

    def append_upstream_node(self, node: 'DAGOrderHelper'):
        """Record ``node`` as linking into this slot from the previous layer."""
        self.__prev_layer_nodes.append(node)
class DAGGraph:
    """Compute a layered layout for a directed acyclic graph.

    Usage: ``rebuild_from_edges(arrows)`` then ``graph_layout()``. The result
    is stored in ``nodes_with_layout``. It holds the real nodes plus fake
    routing nodes inserted on edges that span more than one layer. Each entry
    carries a ``layer_number`` and an in-layer ``sort_number``.
    """

    def __init__(self):
        # Node name -> helper. graph_layout() re-fills it in topological order.
        self.graph_inst: Dict[str, DAGLayerHelper] = {}
        # Real and fake order helpers produced by the last graph_layout() call.
        self.nodes_with_layout: List[DAGOrderHelper] = []
        # Number of layers (largest layer index + 1) found by the last graph_layout().
        self.max_layer_count = 0

    def rebuild_from_edges(self, arrow_list: List[Arrow]) -> None:
        """Build the directed graph from a sequence of edges.

        Nodes are identified by ``point_name``. The first ``Point`` seen for a
        name is the one kept. Edges are added on top of any nodes already present.

        :param arrow_list: directed edges, each exposing ``start_point()`` and ``end_point()``.
        """
        for arr in arrow_list:
            start_helper = self.__helper_for(arr.start_point())
            end_helper = self.__helper_for(arr.end_point())
            start_helper.next_append(end_helper)

    def __helper_for(self, point: Point) -> DAGLayerHelper:
        """Return the helper registered under ``point.point_name``, creating it if missing."""
        helper = self.graph_inst.get(point.point_name)
        if helper is None:
            helper = DAGLayerHelper(point)
            self.graph_inst[point.point_name] = helper
        return helper

    def __spawns_peak(self, ref_set: List[DAGLayerHelper]) -> Tuple[DAGLayerHelper, List[DAGLayerHelper]] | None:
        """Run one step of a Kahn-style topological sort.

        The step proceeds as follows:

        * Pick a node whose remaining in-degree is 0. The frontier ``ref_set``
          is searched first, then the whole remaining graph.
        * Remove it from ``ref_set`` and ``graph_inst``.
        * Decrement its successors' in-degrees and append them to ``ref_set``.

        :param ref_set: frontier of candidate nodes; mutated in place.
        :return: ``(picked_node, ref_set)``, or ``None`` once the graph is empty.
        :raises RuntimeError: if nodes remain but none has in-degree 0 (a cycle).
        """
        inst = next((n for n in ref_set if n.input_count == 0), None)
        if inst is None:
            inst = next((n for n in self.graph_inst.values() if n.input_count == 0), None)
        if inst is None:
            if len(self.graph_inst) > 0:
                raise RuntimeError("有向无环图中发现环形结构!")
            return None
        if inst in ref_set:
            ref_set.remove(inst)
        for it_nxt in inst.next_nodes():
            it_nxt.input_count -= 1
            if it_nxt not in ref_set:
                ref_set.append(it_nxt)
        self.graph_inst.pop(inst.bind_point().point_name)
        return inst, ref_set

    def __graph_recovery(self, sort_seqs: List[DAGLayerHelper]) -> None:
        """Rebuild ``graph_inst`` and every node's in-degree from ``sort_seqs``.

        :param sort_seqs: all nodes of the graph; ``graph_inst`` is re-filled in this order.
        """
        for it in sort_seqs:
            it.input_count = 0
        for it in sort_seqs:
            for nxt in it.next_nodes():
                nxt.input_count += 1
        self.graph_inst.clear()
        for it in sort_seqs:
            self.graph_inst[it.bind_point().point_name] = it

    def __assign_layers(self, sort_seqs: List[DAGLayerHelper]) -> int:
        """Set each node's ``layer_v`` to its longest-path distance from a root.

        ``sort_seqs`` must be in topological order. Every predecessor's layer is
        then final before its successors are relaxed, giving one O(V + E) pass
        rather than a DFS that re-walks shared sub-graphs once per path.

        :return: the number of layers (largest layer index + 1); 0 for an empty graph.
        """
        for item in sort_seqs:
            item.layer_v = 0
        for item in sort_seqs:
            for child in item.next_nodes():
                if child.layer_v < item.layer_v + 1:
                    child.layer_v = item.layer_v + 1
        return max((item.layer_v for item in sort_seqs), default=-1) + 1

    def __node_layering_adj(self, inst: DAGLayerHelper) -> int:
        """Pull ``inst`` down so it sits directly above its nearest child.

        Leaves and nodes with more than one parent keep their layer, and their
        descendants are not visited. Any other node is first recursed into. It
        is then moved to (smallest value returned by its children), i.e. one
        layer above its shallowest adjusted child. Under longest-path layering a
        child is always deeper than its parent, so this never moves a node up.

        :return: ``inst.layer_v - 1``, the layer directly above ``inst``.
        """
        children = inst.next_nodes()
        if inst.input_count > 1 or not children:
            return inst.layer_v - 1
        inst.layer_v = min(self.__node_layering_adj(child) for child in children)
        return inst.layer_v - 1

    def __tidy_graph_nodes(self) -> List[DAGOrderHelper]:
        """Wrap every node in a DAGOrderHelper and chain every edge layer by layer.

        An edge spanning several layers gets one fake helper (``bind=None``) per
        intermediate layer. Each element of an edge's chain registers its
        predecessor in the chain via ``append_upstream_node``.

        :return: real helpers (in ``graph_inst`` order) followed by all fake helpers.
        """
        real_nodes: Dict[str, DAGOrderHelper] = {}
        for node in self.graph_inst.values():
            real_nodes[node.bind_point().point_name] = DAGOrderHelper(bind=node, relate=None, towards=None)
        temp_array: List[DAGOrderHelper] = list(real_nodes.values())
        for node in self.graph_inst.values():
            for child in node.next_nodes():
                chain = [real_nodes[node.bind_point().point_name]]
                for layer_index in range(node.layer_v + 1, child.layer_v):
                    fake = DAGOrderHelper(relate=node, towards=child, bind=None)
                    fake.layer_number = layer_index
                    chain.append(fake)
                chain.append(real_nodes[child.bind_point().point_name])
                # Link consecutive layers along the chain.
                for upper, lower in zip(chain, chain[1:]):
                    lower.append_upstream_node(upper)
                temp_array.extend(chain[1:-1])
        return temp_array

    def __graph_layer_nodes_sort(self, layer_index: int, nodes: List[DAGOrderHelper]) -> None:
        """Assign in-layer ``sort_number`` values, layer by layer from ``layer_index``.

        On layer 0 the nodes keep their list order and are numbered 1..k.

        On each later layer:

        * A node with upstream links gets the mean ``sort_number`` of those
          upstream nodes (barycenter heuristic).
        * Nodes without upstream links keep their current value.
        * The layer is then stably sorted by that value and renumbered 1..k.

        Processing stops at the first layer that has no nodes.
        """
        layers: Dict[int, List[DAGOrderHelper]] = {}
        for n in nodes:
            layers.setdefault(n.layer_number, []).append(n)
        while layer_index in layers:
            layer_nodes = layers[layer_index]
            if layer_index > 0:
                for target_node in layer_nodes:
                    upstream = target_node.get_upstream_nodes()
                    if upstream:
                        target_node.sort_number = sum(u.sort_number for u in upstream) / len(upstream)
                layer_nodes.sort(key=lambda n: n.sort_number)
            for rank, target_node in enumerate(layer_nodes, start=1):
                target_node.sort_number = rank
            layer_index += 1

    def graph_layout(self) -> None:
        """Compute ``nodes_with_layout`` and ``max_layer_count`` for the current graph.

        The steps are: topological sort, graph/in-degree restoration,
        longest-path layering, pulling sources and single-parent chains down
        towards their children, fake-node insertion on long edges, and
        barycenter ordering inside each layer. Safe to call repeatedly.

        :raises RuntimeError: if the graph contains a cycle; ``graph_inst`` is restored first.
        """
        # 1. Topological sort (temporarily consumes graph_inst and the in-degrees).
        sort_seqs: List[DAGLayerHelper] = []
        refs: List[DAGLayerHelper] = []
        try:
            while (peaks_result := self.__spawns_peak(refs)) is not None:
                head, refs = peaks_result
                sort_seqs.append(head)
        except RuntimeError:
            # Put every node (sorted prefix + unsorted remainder) back so the
            # instance is not left half-consumed, then propagate the error.
            self.__graph_recovery(sort_seqs + list(self.graph_inst.values()))
            raise
        # 2. Restore graph_inst (now in topological order) and the in-degrees.
        self.__graph_recovery(sort_seqs)
        # 3. Longest-path layering.
        self.max_layer_count = self.__assign_layers(sort_seqs)
        # 4. Pull sources and single-parent chains down towards their children.
        for item in sort_seqs:
            if item.input_count == 0:
                self.__node_layering_adj(item)
        # 5. Insert fake nodes on long edges, then order every layer.
        rich_nodes = self.__tidy_graph_nodes()
        self.__graph_layer_nodes_sort(0, rich_nodes)
        self.nodes_with_layout = rich_nodes

    def visible_nodes(self) -> List[DAGOrderHelper]:
        """Return the real (non-fake) helpers of the last layout, in layout order."""
        return [n for n in self.nodes_with_layout if not n.is_fake_node()]
if __name__ == "__main__":
    # Small demo: lay out a six-node DAG and print every slot, fake ones included.
    demo_graph = DAGGraph()
    edge_pairs = [
        ('a', 'b'),
        ('a', 'c'),
        ('c', 'd'),
        ('a', 'd'),
        ('c', 'e'),
        ('c', 'f'),
    ]
    demo_graph.rebuild_from_edges([Arrow(Point(src), Point(dst)) for src, dst in edge_pairs])
    demo_graph.graph_layout()
    for slot in demo_graph.nodes_with_layout:
        line = (
            f"'{slot.relate_bind.bind_point().point_name}',level{slot.layer_number},sort{slot.sort_number}"
            if slot.is_fake_node()
            else f"{slot.layer_bind.bind_point().point_name},level{slot.layer_number},sort{slot.sort_number}"
        )
        print(line)