From c2a7ec512f233524e291f680b5a8c4ec55a38911 Mon Sep 17 00:00:00 2001
From: Oscar Gustafsson <oscar.gustafsson@gmail.com>
Date: Mon, 8 Jun 2020 12:03:27 +0200
Subject: [PATCH] Better tests and methods

---
 b_asic/signal_flow_graph.py | 45 +++++++++++++++++++++++--------------
 test/test_sfg.py            | 20 +++++++++++++++--
 2 files changed, 46 insertions(+), 19 deletions(-)

diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py
index 3dd6c279..ef41e1a5 100644
--- a/b_asic/signal_flow_graph.py
+++ b/b_asic/signal_flow_graph.py
@@ -10,6 +10,7 @@ from io import StringIO
 from queue import PriorityQueue
 import itertools as it
 from graphviz import Digraph
+from graphviz.backend import FORMATS as GRAPHVIZ_FORMATS, ENGINES as GRAPHVIZ_ENGINES
 
 from b_asic.port import SignalSourceProvider, OutputPort
 from b_asic.operation import Operation, AbstractOperation, ResultKey, MutableResultMap, MutableDelayMap
@@ -847,44 +848,44 @@ class SFG(AbstractOperation):
         results[key] = value
         return value
 
-    def get_sfg(self,format=None, show_id=False) -> Digraph:
+    def sfg(self, show_id=False, engine=None) -> Digraph:
         """
         Returns a Digraph of the SFG. Can be directly displayed in IPython.
 
         Parameters
         ----------
-        format : string, optional
-            File format of the generated graph. Output formats can be found at https://www.graphviz.org/doc/info/output.html
-            Most common are "pdf", "eps", "png", and "svg". Default is None which leads to PDF.
-
         show_id : Boolean, optional
             If True, the graph_id:s of signals are shown. The default is False.
 
+        engine: string, optional
+            Graphviz layout engine to be used, see https://graphviz.org/documentation/.
+            Most common are "dot" and "neato". Default is None leading to dot.
+
         Returns
         -------
         Digraph
             Digraph of the SFG.
 
         """
-        if format is not None:
-            pg = Digraph(format=format)
-        else:
-            pg = Digraph()
-        pg.attr(rankdir='LR')
+        dg = Digraph()
+        dg.attr(rankdir='LR')
+        if engine:
+            assert engine in GRAPHVIZ_ENGINES, "Unknown layout engine"
+            dg.engine = engine
         for op in self._components_by_id.values():
             if isinstance(op, Signal):
                 if show_id:
-                    pg.edge(op.source.operation.graph_id, op.destination.operation.graph_id, label=op.graph_id)
+                    dg.edge(op.source.operation.graph_id, op.destination.operation.graph_id, label=op.graph_id)
                 else:
-                    pg.edge(op.source.operation.graph_id, op.destination.operation.graph_id)
+                    dg.edge(op.source.operation.graph_id, op.destination.operation.graph_id)
             else:
                 if op.type_name() == Delay.type_name():
-                    pg.node(op.graph_id, shape='square')
+                    dg.node(op.graph_id, shape='square')
                 else:
-                    pg.node(op.graph_id)
-        return pg
+                    dg.node(op.graph_id)
+        return dg
 
-    def show_sfg(self, format=None, show_id=False) -> None:
+    def show_sfg(self, format=None, show_id=False, engine=None) -> None:
         """
         Shows a visual representation of the SFG using the default system viewer.
 
@@ -898,6 +899,16 @@ class SFG(AbstractOperation):
         show_id : Boolean, optional
             If True, the graph_id:s of signals are shown. The default is False.
 
+        engine: string, optional
+            Graphviz layout engine to be used, see https://graphviz.org/documentation/.
+            Most common are "dot" and "neato". Default is None leading to dot.
         """
 
-        self.get_sfg(format=format, show_id=show_id).view()
+        dg = self.sfg(show_id=show_id)
+        if format:
+            assert format in GRAPHVIZ_FORMATS, "Unknown file format"
+            dg.format = format
+        if engine:
+            assert engine in GRAPHVIZ_ENGINES, "Unknown layout engine"
+            dg.engine = engine
+        dg.view()
diff --git a/test/test_sfg.py b/test/test_sfg.py
index 00e99312..0a5fe96e 100644
--- a/test/test_sfg.py
+++ b/test/test_sfg.py
@@ -1039,10 +1039,26 @@ class TestPrecedenceGraph:
 
 
 class TestSFGGraph:
-    def test_get_sfg(self, sfg_simple_filter):
+    def test_sfg(self, sfg_simple_filter):
         res = 'digraph {\n\trankdir=LR\n\tin1\n\tin1 -> ' \
             'add1\n\tout1\n\tt1 -> out1\n\tadd1\n\tcmul1 -> ' \
             'add1\n\tcmul1\n\tadd1 -> t1\n\tt1 [shape=square]\n\tt1 ' \
             '-> cmul1\n}'
 
-        assert sfg_simple_filter.get_sfg().source == res
+        assert sfg_simple_filter.sfg().source == res
+
+    def test_sfg_show_id(self, sfg_simple_filter):
+        res = 'digraph {\n\trankdir=LR\n\tin1\n\tin1 -> add1 ' \
+            '[label=s1]\n\tout1\n\tt1 -> out1 [label=s2]\n\tadd1' \
+            '\n\tcmul1 -> add1 [label=s3]\n\tcmul1\n\tadd1 -> t1 ' \
+            '[label=s4]\n\tt1 [shape=square]\n\tt1 -> cmul1 [label=s5]\n}'
+
+        assert sfg_simple_filter.sfg(show_id=True).source == res
+
+    def test_show_sfg_invalid_format(self, sfg_simple_filter):
+        with pytest.raises(AssertionError):
+            sfg_simple_filter.show_sfg(format="ppddff")
+
+    def test_show_sfg_invalid_engine(self, sfg_simple_filter):
+        with pytest.raises(AssertionError):
+            sfg_simple_filter.show_sfg(engine="ppddff")
-- 
GitLab