diff --git a/b_asic/architecture.py b/b_asic/architecture.py index c2c6dde38cf2b74760e5e2224193287fddf9e2c2..0802648720355f81086fdda0b2bfab87083ff4d1 100644 --- a/b_asic/architecture.py +++ b/b_asic/architecture.py @@ -742,6 +742,7 @@ of :class:`~b_asic.architecture.ProcessingElement` cluster: bool = True, splines: str = "spline", io_cluster: bool = True, + show_multiplexers: bool = True, ) -> Digraph: dg = Digraph(node_attr={'shape': 'record'}) dg.attr(splines=splines) @@ -782,32 +783,59 @@ of :class:`~b_asic.architecture.ProcessingElement` # Create list of interconnects edges: DefaultDict[str, Set[Tuple[str, str]]] = defaultdict(set) + destination_edges: DefaultDict[str, Set[str]] = defaultdict(set) for pe in self._processing_elements: inputs, outputs = self.get_interconnects_for_pe(pe) for i, inp in enumerate(inputs): for (source, port), cnt in inp.items(): - edges[f"{source.entity_name}:out{port}"].add( + source_str = f"{source.entity_name}:out{port}" + destination_str = f"{pe.entity_name}:in{i}" + edges[source_str].add( ( - f"{pe.entity_name}:in{i}", + destination_str, f"{cnt}", ) ) + destination_edges[destination_str].add(source_str) for o, output in enumerate(outputs): for (destination, port), cnt in output.items(): - edges[f"{pe.entity_name}:out{o}"].add( + source_str = f"{pe.entity_name}:out{o}" + destination_str = f"{destination.entity_name}:in{port}" + edges[source_str].add( ( - f"{destination.entity_name}:in{port}", + destination_str, f"{cnt}", ) ) + destination_edges[destination_str].add(source_str) + + destination_list = {k: list(v) for k, v in destination_edges.items()} + if show_multiplexers: + for destination, source_list in destination_list.items(): + if len(source_list) > 1: + # Create GraphViz struct for multiplexer + inputs = [f"in{i}" for i in range(len(source_list))] + ret = "" + in_strs = [f"<{in_str}> {in_str}" for in_str in inputs] + ret += f"{{{'|'.join(in_strs)}}}|" + name = f"{destination.replace(':', '_')}_mux" + ret += name + ret += "|<out> out" + dg.node(name, "{" + ret + "}") + dg.edge(f"{name}:out", destination) + # Add edges to graph for src_str, destination_counts in edges.items(): + original_src_str = src_str if len(destination_counts) > 1 and branch_node: branch = f"{src_str}_branch".replace(":", "") dg.node(branch, shape='point') dg.edge(src_str, branch, arrowhead='none') src_str = branch for destination_str, cnt_str in destination_counts: + if show_multiplexers and len(destination_list[destination_str]) > 1: + idx = destination_list[destination_str].index(original_src_str) + destination_str = f"{destination_str.replace(':', '_')}_mux:in{idx}" dg.edge(src_str, destination_str, label=cnt_str) return dg