From 5aaf4c2802372fdc730cb663ce27f4efdbcb8812 Mon Sep 17 00:00:00 2001
From: Oscar Gustafsson <oscar.gustafsson@gmail.com>
Date: Tue, 16 May 2023 12:47:01 +0200
Subject: [PATCH] Add multiplexers to Architecture Digraph

---
 b_asic/architecture.py | 36 ++++++++++++++++++++++++++++++++----
 1 file changed, 32 insertions(+), 4 deletions(-)

diff --git a/b_asic/architecture.py b/b_asic/architecture.py
index c2c6dde3..08026487 100644
--- a/b_asic/architecture.py
+++ b/b_asic/architecture.py
@@ -742,6 +742,7 @@ of :class:`~b_asic.architecture.ProcessingElement`
         cluster: bool = True,
         splines: str = "spline",
         io_cluster: bool = True,
+        show_multiplexers: bool = True,
     ) -> Digraph:
         dg = Digraph(node_attr={'shape': 'record'})
         dg.attr(splines=splines)
@@ -782,32 +783,59 @@ of :class:`~b_asic.architecture.ProcessingElement`
 
         # Create list of interconnects
         edges: DefaultDict[str, Set[Tuple[str, str]]] = defaultdict(set)
+        destination_edges: DefaultDict[str, Set[str]] = defaultdict(set)
         for pe in self._processing_elements:
             inputs, outputs = self.get_interconnects_for_pe(pe)
             for i, inp in enumerate(inputs):
                 for (source, port), cnt in inp.items():
-                    edges[f"{source.entity_name}:out{port}"].add(
+                    source_str = f"{source.entity_name}:out{port}"
+                    destination_str = f"{pe.entity_name}:in{i}"
+                    edges[source_str].add(
                         (
-                            f"{pe.entity_name}:in{i}",
+                            destination_str,
                             f"{cnt}",
                         )
                     )
+                    destination_edges[destination_str].add(source_str)
             for o, output in enumerate(outputs):
                 for (destination, port), cnt in output.items():
-                    edges[f"{pe.entity_name}:out{o}"].add(
+                    source_str = f"{pe.entity_name}:out{o}"
+                    destination_str = f"{destination.entity_name}:in{port}"
+                    edges[source_str].add(
                         (
-                            f"{destination.entity_name}:in{port}",
+                            destination_str,
                             f"{cnt}",
                         )
                     )
+                    destination_edges[destination_str].add(source_str)
+
+        destination_list = {k: list(v) for k, v in destination_edges.items()}
+        if show_multiplexers:
+            for destination, source_list in destination_list.items():
+                if len(source_list) > 1:
+                    # Create GraphViz struct for multiplexer
+                    inputs = [f"in{i}" for i in range(len(source_list))]
+                    ret = ""
+                    in_strs = [f"<{in_str}> {in_str}" for in_str in inputs]
+                    ret += f"{{{'|'.join(in_strs)}}}|"
+                    name = f"{destination.replace(':', '_')}_mux"
+                    ret += name
+                    ret += "|<out> out"
+                    dg.node(name, "{" + ret + "}")
+                    dg.edge(f"{name}:out", destination)
+
         # Add edges to graph
         for src_str, destination_counts in edges.items():
+            original_src_str = src_str
             if len(destination_counts) > 1 and branch_node:
                 branch = f"{src_str}_branch".replace(":", "")
                 dg.node(branch, shape='point')
                 dg.edge(src_str, branch, arrowhead='none')
                 src_str = branch
             for destination_str, cnt_str in destination_counts:
+                if show_multiplexers and len(destination_list[destination_str]) > 1:
+                    idx = destination_list[destination_str].index(original_src_str)
+                    destination_str = f"{destination_str.replace(':', '_')}_mux:in{idx}"
                 dg.edge(src_str, destination_str, label=cnt_str)
         return dg
 
-- 
GitLab