Skip to content
Snippets Groups Projects
Commit ae11979e authored by Oscar Gustafsson's avatar Oscar Gustafsson :bicyclist:
Browse files

Added sfg and show_sfg methods

parent 48178a0c
No related branches found
No related tags found
No related merge requests found
......@@ -10,9 +10,10 @@ from io import StringIO
from queue import PriorityQueue
import itertools as it
from graphviz import Digraph
from graphviz.backend import FORMATS as GRAPHVIZ_FORMATS, ENGINES as GRAPHVIZ_ENGINES
from b_asic.port import SignalSourceProvider, OutputPort
from b_asic.operation import Operation, AbstractOperation, ResultKey, DelayMap, MutableResultMap, MutableDelayMap
from b_asic.operation import Operation, AbstractOperation, ResultKey, MutableResultMap, MutableDelayMap
from b_asic.signal import Signal
from b_asic.graph_component import GraphID, GraphIDNumber, GraphComponent, Name, TypeName
from b_asic.special_operations import Input, Output, Delay
......@@ -846,3 +847,71 @@ class SFG(AbstractOperation):
src.index, input_values, results, delays, key_base, bits_override, truncate)
results[key] = value
return value
def sfg(self, show_id=False, engine=None) -> Digraph:
"""
Returns a Digraph of the SFG. Can be directly displayed in IPython.
Parameters
----------
show_id : Boolean, optional
If True, the graph_id:s of signals are shown. The default is False.
engine: string, optional
Graphviz layout engine to be used, see https://graphviz.org/documentation/.
Most common are "dot" and "neato". Default is None leading to dot.
Returns
-------
Digraph
Digraph of the SFG.
"""
dg = Digraph()
dg.attr(rankdir='LR')
if engine:
assert engine in GRAPHVIZ_ENGINES, "Unknown layout engine"
dg.engine = engine
for op in self._components_by_id.values():
if isinstance(op, Signal):
if show_id:
dg.edge(op.source.operation.graph_id, op.destination.operation.graph_id, label=op.graph_id)
else:
dg.edge(op.source.operation.graph_id, op.destination.operation.graph_id)
else:
if op.type_name() == Delay.type_name():
dg.node(op.graph_id, shape='square')
else:
dg.node(op.graph_id)
return dg
def _repr_svg_(self):
return self.sfg()._repr_svg_()
def show_sfg(self, format=None, show_id=False, engine=None) -> None:
"""
Shows a visual representation of the SFG using the default system viewer.
Parameters
----------
format : string, optional
File format of the generated graph. Output formats can be found at https://www.graphviz.org/doc/info/output.html
Most common are "pdf", "eps", "png", and "svg". Default is None which leads to PDF.
show_id : Boolean, optional
If True, the graph_id:s of signals are shown. The default is False.
engine: string, optional
Graphviz layout engine to be used, see https://graphviz.org/documentation/.
Most common are "dot" and "neato". Default is None leading to dot.
"""
dg = self.sfg(show_id=show_id)
if format:
assert format in GRAPHVIZ_FORMATS, "Unknown file format"
dg.format = format
if engine:
assert engine in GRAPHVIZ_ENGINES, "Unknown layout engine"
dg.engine = engine
dg.view()
......@@ -1016,3 +1016,49 @@ class TestGetComponentsOfType:
assert [op.name for op in sfg_two_inputs_two_outputs.find_by_type_name(Output.type_name())] \
== ["OUT1", "OUT2"]
class TestPrecedenceGraph:
def test_precedence_graph(self, sfg_simple_filter):
res = 'digraph {\n\trankdir=LR\n\tsubgraph cluster_0 ' \
'{\n\t\tlabel=N1\n\t\t"in1.0" [label=in1]\n\t\t"t1.0" [label=t1]' \
'\n\t}\n\tsubgraph cluster_1 {\n\t\tlabel=N2\n\t\t"cmul1.0" ' \
'[label=cmul1]\n\t}\n\tsubgraph cluster_2 ' \
'{\n\t\tlabel=N3\n\t\t"add1.0" [label=add1]\n\t}\n\t"in1.0" ' \
'-> add1\n\tadd1 [label=add1 shape=square]\n\tin1 -> "in1.0"' \
'\n\tin1 [label=in1 shape=square]\n\t"t1.0" -> cmul1\n\tcmul1 ' \
'[label=cmul1 shape=square]\n\t"t1.0" -> out1\n\tout1 ' \
'[label=out1 shape=square]\n\tt1Out -> "t1.0"\n\tt1Out ' \
'[label=t1 shape=square]\n\t"cmul1.0" -> add1\n\tadd1 ' \
'[label=add1 shape=square]\n\tcmul1 -> "cmul1.0"\n\tcmul1 ' \
'[label=cmul1 shape=square]\n\t"add1.0" -> t1In\n\tt1In ' \
'[label=t1 shape=square]\n\tadd1 -> "add1.0"\n\tadd1 ' \
'[label=add1 shape=square]\n}'
assert sfg_simple_filter.precedence_graph().source == res
class TestSFGGraph:
def test_sfg(self, sfg_simple_filter):
res = 'digraph {\n\trankdir=LR\n\tin1\n\tin1 -> ' \
'add1\n\tout1\n\tt1 -> out1\n\tadd1\n\tcmul1 -> ' \
'add1\n\tcmul1\n\tadd1 -> t1\n\tt1 [shape=square]\n\tt1 ' \
'-> cmul1\n}'
assert sfg_simple_filter.sfg().source == res
def test_sfg_show_id(self, sfg_simple_filter):
res = 'digraph {\n\trankdir=LR\n\tin1\n\tin1 -> add1 ' \
'[label=s1]\n\tout1\n\tt1 -> out1 [label=s2]\n\tadd1' \
'\n\tcmul1 -> add1 [label=s3]\n\tcmul1\n\tadd1 -> t1 ' \
'[label=s4]\n\tt1 [shape=square]\n\tt1 -> cmul1 [label=s5]\n}'
assert sfg_simple_filter.sfg(show_id=True).source == res
def test_show_sfg_invalid_format(self, sfg_simple_filter):
with pytest.raises(AssertionError):
sfg_simple_filter.show_sfg(format="ppddff")
def test_show_sfg_invalid_engine(self, sfg_simple_filter):
with pytest.raises(AssertionError):
sfg_simple_filter.show_sfg(engine="ppddff")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment