diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 3dd6c2791898e45d22323ed5be32433fe60cc6ea..ef41e1a5456c6e1aa886773fa892ad105e0dd990 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -10,6 +10,7 @@ from io import StringIO from queue import PriorityQueue import itertools as it from graphviz import Digraph +from graphviz.backend import FORMATS as GRAPHVIZ_FORMATS, ENGINES as GRAPHVIZ_ENGINES from b_asic.port import SignalSourceProvider, OutputPort from b_asic.operation import Operation, AbstractOperation, ResultKey, MutableResultMap, MutableDelayMap @@ -847,44 +848,44 @@ class SFG(AbstractOperation): results[key] = value return value - def get_sfg(self,format=None, show_id=False) -> Digraph: + def sfg(self, show_id=False, engine=None) -> Digraph: """ Returns a Digraph of the SFG. Can be directly displayed in IPython. Parameters ---------- - format : string, optional - File format of the generated graph. Output formats can be found at https://www.graphviz.org/doc/info/output.html - Most common are "pdf", "eps", "png", and "svg". Default is None which leads to PDF. - show_id : Boolean, optional If True, the graph_id:s of signals are shown. The default is False. + engine: string, optional + Graphviz layout engine to be used, see https://graphviz.org/documentation/. + Most common are "dot" and "neato". Default is None leading to dot. + Returns ------- Digraph Digraph of the SFG. """ - if format is not None: - pg = Digraph(format=format) - else: - pg = Digraph() - pg.attr(rankdir='LR') + dg = Digraph() + dg.attr(rankdir='LR') + if engine: + assert engine in GRAPHVIZ_ENGINES, "Unknown layout engine" + dg.engine = engine for op in self._components_by_id.values(): if isinstance(op, Signal): if show_id: - pg.edge(op.source.operation.graph_id, op.destination.operation.graph_id, label=op.graph_id) + dg.edge(op.source.operation.graph_id, op.destination.operation.graph_id, label=op.graph_id) else: - pg.edge(op.source.operation.graph_id, op.destination.operation.graph_id) + dg.edge(op.source.operation.graph_id, op.destination.operation.graph_id) else: if op.type_name() == Delay.type_name(): - pg.node(op.graph_id, shape='square') + dg.node(op.graph_id, shape='square') else: - pg.node(op.graph_id) - return pg + dg.node(op.graph_id) + return dg - def show_sfg(self, format=None, show_id=False) -> None: + def show_sfg(self, format=None, show_id=False, engine=None) -> None: """ Shows a visual representation of the SFG using the default system viewer. @@ -898,6 +899,16 @@ class SFG(AbstractOperation): show_id : Boolean, optional If True, the graph_id:s of signals are shown. The default is False. + engine: string, optional + Graphviz layout engine to be used, see https://graphviz.org/documentation/. + Most common are "dot" and "neato". Default is None leading to dot. """ - self.get_sfg(format=format, show_id=show_id).view() + dg = self.sfg(show_id=show_id) + if format: + assert format in GRAPHVIZ_FORMATS, "Unknown file format" + dg.format = format + if engine: + assert engine in GRAPHVIZ_ENGINES, "Unknown layout engine" + dg.engine = engine + dg.view() diff --git a/test/test_sfg.py b/test/test_sfg.py index 00e9931246d7d050ff5af641a4b89a96e9eb30f2..0a5fe96ed0f24607efdcacefdb1a5275cf229d83 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -1039,10 +1039,26 @@ class TestPrecedenceGraph: class TestSFGGraph: - def test_get_sfg(self, sfg_simple_filter): + def test_sfg(self, sfg_simple_filter): res = 'digraph {\n\trankdir=LR\n\tin1\n\tin1 -> ' \ 'add1\n\tout1\n\tt1 -> out1\n\tadd1\n\tcmul1 -> ' \ 'add1\n\tcmul1\n\tadd1 -> t1\n\tt1 [shape=square]\n\tt1 ' \ '-> cmul1\n}' - assert sfg_simple_filter.get_sfg().source == res + assert sfg_simple_filter.sfg().source == res + + def test_sfg_show_id(self, sfg_simple_filter): + res = 'digraph {\n\trankdir=LR\n\tin1\n\tin1 -> add1 ' \ + '[label=s1]\n\tout1\n\tt1 -> out1 [label=s2]\n\tadd1' \ + '\n\tcmul1 -> add1 [label=s3]\n\tcmul1\n\tadd1 -> t1 ' \ + '[label=s4]\n\tt1 [shape=square]\n\tt1 -> cmul1 [label=s5]\n}' + + assert sfg_simple_filter.sfg(show_id=True).source == res + + def test_show_sfg_invalid_format(self, sfg_simple_filter): + with pytest.raises(AssertionError): + sfg_simple_filter.show_sfg(format="ppddff") + + def test_show_sfg_invalid_engine(self, sfg_simple_filter): + with pytest.raises(AssertionError): + sfg_simple_filter.show_sfg(engine="ppddff")