From 5eec2b7f59f159824e2c43cab9a87d64681b345f Mon Sep 17 00:00:00 2001
From: Oscar Gustafsson <oscar.gustafsson@gmail.com>
Date: Thu, 16 Feb 2023 17:03:20 +0100
Subject: [PATCH] Unify operation shapes in SFG and precendence graph

---
 b_asic/signal_flow_graph.py | 38 ++++++++++++++++++++--------
 test/test_sfg.py            | 49 ++++++++++++++++++++-----------------
 2 files changed, 54 insertions(+), 33 deletions(-)

diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py
index 2044a2aa..9e0c33f3 100644
--- a/b_asic/signal_flow_graph.py
+++ b/b_asic/signal_flow_graph.py
@@ -48,6 +48,16 @@ from b_asic.special_operations import Delay, Input, Output
 DelayQueue = List[Tuple[str, ResultKey, OutputPort]]
 
 
+_OPERATION_SHAPE: DefaultDict[TypeName, str] = defaultdict(lambda: "ellipse")
+_OPERATION_SHAPE.update(
+    {
+        Input.type_name(): "cds",
+        Output.type_name(): "cds",
+        Delay.type_name(): "square",
+    }
+)
+
+
 class GraphIDGenerator:
     """Generates Graph IDs for objects."""
 
@@ -813,7 +823,7 @@ class SFG(AbstractOperation):
         for i in range(len(p_list)):
             ports = p_list[i]
             with pg.subgraph(name=f"cluster_{i}") as sub:
-                sub.attr(label=f"N{i+1}")
+                sub.attr(label=f"N{i}")
                 for port in ports:
                     portstr = f"{port.operation.graph_id}.{port.index}"
                     if port.operation.output_count > 1:
@@ -821,7 +831,10 @@ class SFG(AbstractOperation):
                     else:
                         sub.node(
                             portstr,
+                            shape='rectangle',
                             label=port.operation.graph_id,
+                            height="0.1",
+                            width="0.1",
                         )
         # Creates edges for each output port and creates nodes for each operation
         # and edges for them as well
@@ -830,14 +843,20 @@ class SFG(AbstractOperation):
             for port in ports:
                 for signal in port.signals:
                     destination = cast(InputPort, signal.destination)
-                    if destination.operation.type_name() == Delay.type_name():
+                    if isinstance(destination.operation, Delay):
                         dest_node = destination.operation.graph_id + "In"
                     else:
                         dest_node = destination.operation.graph_id
                     dest_label = destination.operation.graph_id
                     node_node = f"{port.operation.graph_id}.{port.index}"
                     pg.edge(node_node, dest_node)
-                    pg.node(dest_node, label=dest_label, shape="square")
+                    pg.node(
+                        dest_node,
+                        label=dest_label,
+                        shape=_OPERATION_SHAPE[
+                            destination.operation.type_name()
+                        ],
+                    )
                 if port.operation.type_name() == Delay.type_name():
                     source_node = port.operation.graph_id + "Out"
                 else:
@@ -845,7 +864,11 @@ class SFG(AbstractOperation):
                 source_label = port.operation.graph_id
                 node_node = f"{port.operation.graph_id}.{port.index}"
                 pg.edge(source_node, node_node)
-                pg.node(source_node, label=source_label, shape="square")
+                pg.node(
+                    source_node,
+                    label=source_label,
+                    shape=_OPERATION_SHAPE[port.operation.type_name()],
+                )
 
         return pg
 
@@ -1417,12 +1440,7 @@ class SFG(AbstractOperation):
                         destination.operation.graph_id,
                     )
             else:
-                if isinstance(op, Delay):
-                    dg.node(op.graph_id, shape="square")
-                elif isinstance(op, (Input, Output)):
-                    dg.node(op.graph_id, shape="cds")
-                else:
-                    dg.node(op.graph_id)
+                dg.node(op.graph_id, shape=_OPERATION_SHAPE[op.type_name()])
         return dg
 
     def _repr_mimebundle_(self, include=None, exclude=None):
diff --git a/test/test_sfg.py b/test/test_sfg.py
index 4dcb1d97..433c7b17 100644
--- a/test/test_sfg.py
+++ b/test/test_sfg.py
@@ -1376,21 +1376,23 @@ class TestGetComponentsOfType:
 class TestPrecedenceGraph:
     def test_precedence_graph(self, sfg_simple_filter):
         res = (
-            "digraph {\n\trankdir=LR\n\tsubgraph cluster_0"
-            " {\n\t\tlabel=N1\n\t\t\"in1.0\" [label=in1]\n\t\t\"t1.0\""
-            " [label=t1]\n\t}\n\tsubgraph cluster_1"
-            " {\n\t\tlabel=N2\n\t\t\"cmul1.0\" [label=cmul1]\n\t}\n\tsubgraph"
-            " cluster_2 {\n\t\tlabel=N3\n\t\t\"add1.0\""
-            " [label=add1]\n\t}\n\t\"in1.0\" -> add1\n\tadd1 [label=add1"
-            " shape=square]\n\tin1 -> \"in1.0\"\n\tin1 [label=in1"
-            " shape=square]\n\t\"t1.0\" -> cmul1\n\tcmul1 [label=cmul1"
-            " shape=square]\n\t\"t1.0\" -> out1\n\tout1 [label=out1"
-            " shape=square]\n\tt1Out -> \"t1.0\"\n\tt1Out [label=t1"
-            " shape=square]\n\t\"cmul1.0\" -> add1\n\tadd1 [label=add1"
-            " shape=square]\n\tcmul1 -> \"cmul1.0\"\n\tcmul1 [label=cmul1"
-            " shape=square]\n\t\"add1.0\" -> t1In\n\tt1In [label=t1"
-            " shape=square]\n\tadd1 -> \"add1.0\"\n\tadd1 [label=add1"
-            " shape=square]\n}"
+            'digraph {\n\trankdir=LR\n\tsubgraph cluster_0'
+            ' {\n\t\tlabel=N0\n\t\t"in1.0" [label=in1 height=0.1'
+            ' shape=rectangle width=0.1]\n\t\t"t1.0" [label=t1 height=0.1'
+            ' shape=rectangle width=0.1]\n\t}\n\tsubgraph cluster_1'
+            ' {\n\t\tlabel=N1\n\t\t"cmul1.0" [label=cmul1 height=0.1'
+            ' shape=rectangle width=0.1]\n\t}\n\tsubgraph cluster_2'
+            ' {\n\t\tlabel=N2\n\t\t"add1.0" [label=add1 height=0.1'
+            ' shape=rectangle width=0.1]\n\t}\n\t"in1.0" -> add1\n\tadd1'
+            ' [label=add1 shape=ellipse]\n\tin1 -> "in1.0"\n\tin1 [label=in1'
+            ' shape=cds]\n\t"t1.0" -> cmul1\n\tcmul1 [label=cmul1'
+            ' shape=ellipse]\n\t"t1.0" -> out1\n\tout1 [label=out1'
+            ' shape=cds]\n\tt1Out -> "t1.0"\n\tt1Out [label=t1'
+            ' shape=square]\n\t"cmul1.0" -> add1\n\tadd1 [label=add1'
+            ' shape=ellipse]\n\tcmul1 -> "cmul1.0"\n\tcmul1 [label=cmul1'
+            ' shape=ellipse]\n\t"add1.0" -> t1In\n\tt1In [label=t1'
+            ' shape=square]\n\tadd1 -> "add1.0"\n\tadd1 [label=add1'
+            ' shape=ellipse]\n}'
         )
 
         assert sfg_simple_filter.precedence_graph().source in (res, res + "\n")
@@ -1399,19 +1401,20 @@ class TestPrecedenceGraph:
 class TestSFGGraph:
     def test_sfg(self, sfg_simple_filter):
         res = (
-            "digraph {\n\trankdir=LR\n\tin1 [shape=cds]\n\tin1 -> "
-            "add1\n\tout1 [shape=cds]\n\tt1 -> out1\n\tadd1\n\tcmul1 -> "
-            "add1\n\tcmul1\n\tadd1 -> t1\n\tt1 [shape=square]\n\tt1 "
-            "-> cmul1\n}"
+            'digraph {\n\trankdir=LR\n\tin1 [shape=cds]\n\tin1 -> add1\n\tout1'
+            ' [shape=cds]\n\tt1 -> out1\n\tadd1 [shape=ellipse]\n\tcmul1 ->'
+            ' add1\n\tcmul1 [shape=ellipse]\n\tadd1 -> t1\n\tt1'
+            ' [shape=square]\n\tt1 -> cmul1\n}'
         )
         assert sfg_simple_filter.sfg_digraph().source in (res, res + "\n")
 
     def test_sfg_show_id(self, sfg_simple_filter):
         res = (
-            "digraph {\n\trankdir=LR\n\tin1 [shape=cds]\n\tin1 -> add1 "
-            "[label=s1]\n\tout1 [shape=cds]\n\tt1 -> out1 [label=s2]\n\tadd1"
-            "\n\tcmul1 -> add1 [label=s3]\n\tcmul1\n\tadd1 -> t1 "
-            "[label=s4]\n\tt1 [shape=square]\n\tt1 -> cmul1 [label=s5]\n}"
+            'digraph {\n\trankdir=LR\n\tin1 [shape=cds]\n\tin1 -> add1'
+            ' [label=s1]\n\tout1 [shape=cds]\n\tt1 -> out1 [label=s2]\n\tadd1'
+            ' [shape=ellipse]\n\tcmul1 -> add1 [label=s3]\n\tcmul1'
+            ' [shape=ellipse]\n\tadd1 -> t1 [label=s4]\n\tt1'
+            ' [shape=square]\n\tt1 -> cmul1 [label=s5]\n}'
         )
 
         assert sfg_simple_filter.sfg_digraph(show_id=True).source in (
-- 
GitLab