diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 2044a2aa145083b279064ad16dd59915dd8bdfd5..9e0c33f332b679972d4ed1bad30dcd4f1ebbd16c 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -48,6 +48,16 @@ from b_asic.special_operations import Delay, Input, Output DelayQueue = List[Tuple[str, ResultKey, OutputPort]] +_OPERATION_SHAPE: DefaultDict[TypeName, str] = defaultdict(lambda: "ellipse") +_OPERATION_SHAPE.update( + { + Input.type_name(): "cds", + Output.type_name(): "cds", + Delay.type_name(): "square", + } +) + + class GraphIDGenerator: """Generates Graph IDs for objects.""" @@ -813,7 +823,7 @@ class SFG(AbstractOperation): for i in range(len(p_list)): ports = p_list[i] with pg.subgraph(name=f"cluster_{i}") as sub: - sub.attr(label=f"N{i+1}") + sub.attr(label=f"N{i}") for port in ports: portstr = f"{port.operation.graph_id}.{port.index}" if port.operation.output_count > 1: @@ -821,7 +831,10 @@ class SFG(AbstractOperation): else: sub.node( portstr, + shape='rectangle', label=port.operation.graph_id, + height="0.1", + width="0.1", ) # Creates edges for each output port and creates nodes for each operation # and edges for them as well @@ -830,14 +843,20 @@ class SFG(AbstractOperation): for port in ports: for signal in port.signals: destination = cast(InputPort, signal.destination) - if destination.operation.type_name() == Delay.type_name(): + if isinstance(destination.operation, Delay): dest_node = destination.operation.graph_id + "In" else: dest_node = destination.operation.graph_id dest_label = destination.operation.graph_id node_node = f"{port.operation.graph_id}.{port.index}" pg.edge(node_node, dest_node) - pg.node(dest_node, label=dest_label, shape="square") + pg.node( + dest_node, + label=dest_label, + shape=_OPERATION_SHAPE[ + destination.operation.type_name() + ], + ) if port.operation.type_name() == Delay.type_name(): source_node = port.operation.graph_id + "Out" else: @@ -845,7 +864,11 @@ class SFG(AbstractOperation): source_label = port.operation.graph_id node_node = f"{port.operation.graph_id}.{port.index}" pg.edge(source_node, node_node) - pg.node(source_node, label=source_label, shape="square") + pg.node( + source_node, + label=source_label, + shape=_OPERATION_SHAPE[port.operation.type_name()], + ) return pg @@ -1417,12 +1440,7 @@ class SFG(AbstractOperation): destination.operation.graph_id, ) else: - if isinstance(op, Delay): - dg.node(op.graph_id, shape="square") - elif isinstance(op, (Input, Output)): - dg.node(op.graph_id, shape="cds") - else: - dg.node(op.graph_id) + dg.node(op.graph_id, shape=_OPERATION_SHAPE[op.type_name()]) return dg def _repr_mimebundle_(self, include=None, exclude=None): diff --git a/test/test_sfg.py b/test/test_sfg.py index 4dcb1d97c6c8a2ec82d6675e36621e6c49ca3fc4..433c7b175ac1f92a873d874ae9ae1eebfaf15a88 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -1376,21 +1376,23 @@ class TestGetComponentsOfType: class TestPrecedenceGraph: def test_precedence_graph(self, sfg_simple_filter): res = ( - "digraph {\n\trankdir=LR\n\tsubgraph cluster_0" - " {\n\t\tlabel=N1\n\t\t\"in1.0\" [label=in1]\n\t\t\"t1.0\"" - " [label=t1]\n\t}\n\tsubgraph cluster_1" - " {\n\t\tlabel=N2\n\t\t\"cmul1.0\" [label=cmul1]\n\t}\n\tsubgraph" - " cluster_2 {\n\t\tlabel=N3\n\t\t\"add1.0\"" - " [label=add1]\n\t}\n\t\"in1.0\" -> add1\n\tadd1 [label=add1" - " shape=square]\n\tin1 -> \"in1.0\"\n\tin1 [label=in1" - " shape=square]\n\t\"t1.0\" -> cmul1\n\tcmul1 [label=cmul1" - " shape=square]\n\t\"t1.0\" -> out1\n\tout1 [label=out1" - " shape=square]\n\tt1Out -> \"t1.0\"\n\tt1Out [label=t1" - " shape=square]\n\t\"cmul1.0\" -> add1\n\tadd1 [label=add1" - " shape=square]\n\tcmul1 -> \"cmul1.0\"\n\tcmul1 [label=cmul1" - " shape=square]\n\t\"add1.0\" -> t1In\n\tt1In [label=t1" - " shape=square]\n\tadd1 -> \"add1.0\"\n\tadd1 [label=add1" - " shape=square]\n}" + 'digraph {\n\trankdir=LR\n\tsubgraph cluster_0' + ' {\n\t\tlabel=N0\n\t\t"in1.0" [label=in1 height=0.1' + ' shape=rectangle width=0.1]\n\t\t"t1.0" [label=t1 height=0.1' + ' shape=rectangle width=0.1]\n\t}\n\tsubgraph cluster_1' + ' {\n\t\tlabel=N1\n\t\t"cmul1.0" [label=cmul1 height=0.1' + ' shape=rectangle width=0.1]\n\t}\n\tsubgraph cluster_2' + ' {\n\t\tlabel=N2\n\t\t"add1.0" [label=add1 height=0.1' + ' shape=rectangle width=0.1]\n\t}\n\t"in1.0" -> add1\n\tadd1' + ' [label=add1 shape=ellipse]\n\tin1 -> "in1.0"\n\tin1 [label=in1' + ' shape=cds]\n\t"t1.0" -> cmul1\n\tcmul1 [label=cmul1' + ' shape=ellipse]\n\t"t1.0" -> out1\n\tout1 [label=out1' + ' shape=cds]\n\tt1Out -> "t1.0"\n\tt1Out [label=t1' + ' shape=square]\n\t"cmul1.0" -> add1\n\tadd1 [label=add1' + ' shape=ellipse]\n\tcmul1 -> "cmul1.0"\n\tcmul1 [label=cmul1' + ' shape=ellipse]\n\t"add1.0" -> t1In\n\tt1In [label=t1' + ' shape=square]\n\tadd1 -> "add1.0"\n\tadd1 [label=add1' + ' shape=ellipse]\n}' ) assert sfg_simple_filter.precedence_graph().source in (res, res + "\n") @@ -1399,19 +1401,20 @@ class TestPrecedenceGraph: class TestSFGGraph: def test_sfg(self, sfg_simple_filter): res = ( - "digraph {\n\trankdir=LR\n\tin1 [shape=cds]\n\tin1 -> " - "add1\n\tout1 [shape=cds]\n\tt1 -> out1\n\tadd1\n\tcmul1 -> " - "add1\n\tcmul1\n\tadd1 -> t1\n\tt1 [shape=square]\n\tt1 " - "-> cmul1\n}" + 'digraph {\n\trankdir=LR\n\tin1 [shape=cds]\n\tin1 -> add1\n\tout1' + ' [shape=cds]\n\tt1 -> out1\n\tadd1 [shape=ellipse]\n\tcmul1 ->' + ' add1\n\tcmul1 [shape=ellipse]\n\tadd1 -> t1\n\tt1' + ' [shape=square]\n\tt1 -> cmul1\n}' ) assert sfg_simple_filter.sfg_digraph().source in (res, res + "\n") def test_sfg_show_id(self, sfg_simple_filter): res = ( - "digraph {\n\trankdir=LR\n\tin1 [shape=cds]\n\tin1 -> add1 " - "[label=s1]\n\tout1 [shape=cds]\n\tt1 -> out1 [label=s2]\n\tadd1" - "\n\tcmul1 -> add1 [label=s3]\n\tcmul1\n\tadd1 -> t1 " - "[label=s4]\n\tt1 [shape=square]\n\tt1 -> cmul1 [label=s5]\n}" + 'digraph {\n\trankdir=LR\n\tin1 [shape=cds]\n\tin1 -> add1' + ' [label=s1]\n\tout1 [shape=cds]\n\tt1 -> out1 [label=s2]\n\tadd1' + ' [shape=ellipse]\n\tcmul1 -> add1 [label=s3]\n\tcmul1' + ' [shape=ellipse]\n\tadd1 -> t1 [label=s4]\n\tt1' + ' [shape=square]\n\tt1 -> cmul1 [label=s5]\n}' ) assert sfg_simple_filter.sfg_digraph(show_id=True).source in (