diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 82aeb9c82ca3b3e691d7bb9e37328d7980334cb9..cc197d8e3200f0bdeb3b95421d1fcee6f3d0355d 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -1376,17 +1376,32 @@ class SFG(AbstractOperation): results[key] = value return value - def sfg_digraph(self, show_id=False, engine=None) -> Digraph: + def sfg_digraph( + self, + show_id: bool = False, + engine: str = None, + branch_node: bool = False, + port_numbering: bool = True, + splines: str = "spline", + ) -> Digraph: """ - Returns a Digraph of the SFG. Can be directly displayed in IPython. + Returns a Digraph of the SFG. + + Can be directly displayed in IPython. Parameters ---------- - show_id : Boolean, optional - If True, the graph_id:s of signals are shown. The default is False. + show_id : bool, default: False + If True, the graph_id:s of signals are shown. engine : string, optional Graphviz layout engine to be used, see https://graphviz.org/documentation/. Most common are "dot" and "neato". Default is None leading to dot. + branch_node : bool, default: False + Add a branch node in case the fan-out of a signal is two or more. + port_numbering : bool, default: True + Show the port number in case the number of ports (input or output) is two or more. + splines : {"spline", "line", "ortho", "polyline", "curved"}, default: "spline" + Spline style, see https://graphviz.org/docs/attrs/splines/ for more info. Returns ------- @@ -1395,23 +1410,56 @@ class SFG(AbstractOperation): """ dg = Digraph() - dg.attr(rankdir="LR") + dg.attr(rankdir="LR", splines=splines) + branch_nodes = set() if engine is not None: dg.engine = engine for op in self._components_by_id.values(): if isinstance(op, Signal): source = cast(OutputPort, op.source) destination = cast(InputPort, op.destination) - if show_id: - dg.edge( - source.operation.graph_id, - destination.operation.graph_id, - label=op.graph_id, + source_name = ( + source.name + if branch_node and source.signal_count > 1 + else source.operation.graph_id + ) + label = op.graph_id if show_id else None + taillabel = ( + str(source.index) + if source.operation.output_count > 1 + and (not branch_node or source.signal_count == 1) + and port_numbering + else None + ) + headlabel = ( + str(destination.index) + if destination.operation.input_count > 1 and port_numbering + else None + ) + dg.edge( + source_name, + destination.operation.graph_id, + label=label, + taillabel=taillabel, + headlabel=headlabel, + ) + if ( + branch_node + and source.signal_count > 1 + and source_name not in branch_nodes + ): + branch_nodes.add(source_name) + dg.node(source_name, shape='point') + taillabel = ( + str(source.index) + if source.operation.output_count > 1 and port_numbering + else None ) - else: dg.edge( source.operation.graph_id, - destination.operation.graph_id, + source_name, + arrowhead='none', + taillabel=taillabel, ) else: dg.node(op.graph_id, shape=_OPERATION_SHAPE[op.type_name()])