From 8f9d3fc19ace612619086d0c5def7d30d74e9bd0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ivar=20H=C3=A4rnqvist?= <ivarhar@outlook.com>
Date: Thu, 9 Apr 2020 14:51:46 +0200
Subject: [PATCH] add code to check for direct feedback loops during evaluation

---
 b_asic/operation.py                | 45 ++++++++++++++++++------------
 b_asic/signal_flow_graph.py        | 18 ++++++------
 b_asic/special_operations.py       |  2 +-
 test/fixtures/operation_tree.py    |  5 ++--
 test/fixtures/signal_flow_graph.py |  4 ++-
 test/test_sfg.py                   | 11 ++++++--
 6 files changed, 49 insertions(+), 36 deletions(-)

diff --git a/b_asic/operation.py b/b_asic/operation.py
index a8dc7a96..45ce01a0 100644
--- a/b_asic/operation.py
+++ b/b_asic/operation.py
@@ -93,7 +93,7 @@ class Operation(GraphComponent, SignalSourceProvider):
         raise NotImplementedError
 
     @abstractmethod
-    def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[str, Number]] = None, registers: Optional[MutableMapping[str, Number]] = None, prefix: str = "") -> Number:
+    def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[str, Optional[Number]]] = None, registers: Optional[MutableMapping[str, Number]] = None, prefix: str = "") -> Number:
         """Evaluate the output at the given index of this operation with the given input values.
         The results parameter will be used to store any intermediate results for caching.
         The registers parameter will be used to get the current value of any intermediate registers that are encountered, and be updated with their new values.
@@ -169,6 +169,17 @@ class AbstractOperation(Operation, AbstractGraphComponent):
         """
         raise NotImplementedError
 
+    def _find_result(self, prefix: str, index: int, results: MutableMapping[str, Optional[Number]]) -> Optional[Number]:
+        key = results_key(self.output_count, prefix, index)
+        if key in results:
+            value = results[key]
+            if value is None:
+                raise RuntimeError(f"Direct feedback loop detected when evaluating operation.")
+            return value
+
+        results[key] = None
+        return None
+
     def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Addition, ConstantAddition]":
         # Import here to avoid circular imports.
         from b_asic.core_operations import Addition, ConstantAddition
@@ -223,36 +234,34 @@ class AbstractOperation(Operation, AbstractGraphComponent):
     def output(self, i: int) -> OutputPort:
         return self._output_ports[i]
 
-    def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[str, Number]] = None, registers: Optional[MutableMapping[str, Number]] = None, prefix: str = "") -> Number:
+    def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[str, Optional[Number]]] = None, registers: Optional[MutableMapping[str, Number]] = None, prefix: str = "") -> Number:
         if index < 0 or index >= self.output_count:
             raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {index})")
         if results is None:
             results = {}
         if registers is None:
             registers = {}
-        
-        key = results_key(self.output_count, prefix, index)
-        if key in results:
-            return results[key]
 
-        result = self.evaluate(*input_values)
-        if isinstance(result, collections.Sequence):
-            if len(result) != self.output_count:
-                raise RuntimeError(f"Operation evaluated to incorrect number of outputs (expected {self.output_count}, got {len(result)})")
-        elif isinstance(result, Number):
+        result = self._find_result(prefix, index, results)
+        if result is not None:
+            return result
+        values = self.evaluate(*input_values)
+        if isinstance(values, collections.Sequence):
+            if len(values) != self.output_count:
+                raise RuntimeError(f"Operation evaluated to incorrect number of outputs (expected {self.output_count}, got {len(values)})")
+        elif isinstance(values, Number):
             if self.output_count != 1:
                 raise RuntimeError(f"Operation evaluated to incorrect number of outputs (expected {self.output_count}, got 1)")
-            result = (result,)
+            values = (values,)
         else:
-            raise RuntimeError(f"Operation evaluated to invalid type (expected Sequence/Number, got {result.__class__.__name__})")
+            raise RuntimeError(f"Operation evaluated to invalid type (expected Sequence/Number, got {values.__class__.__name__})")
 
         if self.output_count == 1:
-            results[key] = result[index]
+            results[results_key(self.output_count, prefix, index)] = values[index]
         else:
-            for i, value in enumerate(result):
-                results[results_key(self.output_count, prefix, i)] = value
-        return result[index]
-
+            for i in range(self.output_count):
+                results[results_key(self.output_count, prefix, i)] = values[i]
+        return values[index]
 
     def evaluate_outputs(self, input_values: Sequence[Number], results: MutableMapping[str, Number], registers: MutableMapping[str, Number], prefix: str = "") -> Sequence[Number]:
         return [self.evaluate_output(i, input_values, results, registers, prefix) for i in range(self.output_count)]
diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py
index b796f160..baa0c004 100644
--- a/b_asic/signal_flow_graph.py
+++ b/b_asic/signal_flow_graph.py
@@ -102,7 +102,7 @@ class SFG(AbstractOperation):
         n = len(result)
         return None if n == 0 else result[0] if n == 1 else result
 
-    def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[str, Number]] = None, registers: Optional[MutableMapping[str, Number]] = None, prefix: str = "") -> Number:
+    def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[str, Optional[Number]]] = None, registers: Optional[MutableMapping[str, Number]] = None, prefix: str = "") -> Number:
         if index < 0 or index >= self.output_count:
             raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {index})")
         if len(input_values) != self.input_count:
@@ -111,16 +111,17 @@ class SFG(AbstractOperation):
             results = {}
         if registers is None:
             registers = {}
+
+        result = self._find_result(prefix, index, results)
+        if result is not None:
+            return result
         
         # Set the values of our input operations to the given input values.
         for op, arg in zip(self._input_operations, input_values):
             op.value = arg
-
-        key = results_key(self.output_count, prefix, index)
-        if key in results:
-            return results[key]
+        
         value = self._evaluate_source(self._output_operations[index].input(0).signals[0].source, results, registers, prefix)
-        results[key] = value
+        results[results_key(self.output_count, prefix, index)] = value
         return value
 
     def split(self) -> Iterable[Operation]:
@@ -206,10 +207,7 @@ class SFG(AbstractOperation):
         if op_prefix:
             op_prefix += "."
         op_prefix += src.operation.graph_id
-        key = results_key(src.operation.output_count, op_prefix, src.index)
-        if key in results:
-            return results[key]
         input_values = [self._evaluate_source(input_port.signals[0].source, results, registers, prefix) for input_port in src.operation.inputs]
         value = src.operation.evaluate_output(src.index, input_values, results, registers, op_prefix)
-        results[key] = value
+        results[results_key(src.operation.output_count, op_prefix, src.index)] = value
         return value
\ No newline at end of file
diff --git a/b_asic/special_operations.py b/b_asic/special_operations.py
index 36ebf2d8..5a43ded1 100644
--- a/b_asic/special_operations.py
+++ b/b_asic/special_operations.py
@@ -70,7 +70,7 @@ class Register(AbstractOperation):
     def evaluate(self, a):
         return self.param("initial_value")
 
-    def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[str, Number]] = None, registers: Optional[MutableMapping[str, Number]] = None, prefix: str = ""):
+    def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[str, Optional[Number]]] = None, registers: Optional[MutableMapping[str, Number]] = None, prefix: str = ""):
         if index != 0:
             raise IndexError(f"Output index out of range (expected 0-0, got {index})")
         if len(input_values) != 1:
diff --git a/test/fixtures/operation_tree.py b/test/fixtures/operation_tree.py
index 32758675..3ac35110 100644
--- a/test/fixtures/operation_tree.py
+++ b/test/fixtures/operation_tree.py
@@ -46,7 +46,6 @@ def operation_graph_with_cycle():
                     |
                     6
     """
-    c1 = Constant(7)
-    add1 = Addition(None, c1)
+    add1 = Addition(None, Constant(7))
     add1.input(0).connect(add1)
-    return Addition(add1, c1)
+    return Addition(add1, Constant(6))
diff --git a/test/fixtures/signal_flow_graph.py b/test/fixtures/signal_flow_graph.py
index c8e1dc9a..06cfed56 100644
--- a/test/fixtures/signal_flow_graph.py
+++ b/test/fixtures/signal_flow_graph.py
@@ -25,4 +25,6 @@ def sfg_two_inputs_two_outputs():
     add2=Addition(add1, in2)
     out1=Output(add1)
     out2=Output(add2)
-    return SFG(inputs = [in1, in2], outputs = [out1, out2])
\ No newline at end of file
+    return SFG(inputs = [in1, in2], outputs = [out1, out2])
+
+# TODO: Testa nestad sfg
\ No newline at end of file
diff --git a/test/test_sfg.py b/test/test_sfg.py
index 76c5178e..216e4fd9 100644
--- a/test/test_sfg.py
+++ b/test/test_sfg.py
@@ -25,9 +25,14 @@ class TestConstructor:
 
 class TestEvaluation:
     def test_evaluate_output(self, operation_tree):
-        sfg = SFG(outputs=[Output(operation_tree)])
+        sfg = SFG(outputs = [Output(operation_tree)])
         assert sfg.evaluate_output(0, []) == 5
 
     def test_evaluate_output_large(self, large_operation_tree):
-        sfg = SFG(outputs=[Output(large_operation_tree)])
-        assert sfg.evaluate_output(0, []) == 14
\ No newline at end of file
+        sfg = SFG(outputs = [Output(large_operation_tree)])
+        assert sfg.evaluate_output(0, []) == 14
+
+    def test_evaluate_output_cycle(self, operation_graph_with_cycle):
+        with pytest.raises(Exception):
+            sfg = SFG(outputs = [Output(operation_graph_with_cycle)])
+            sfg.evaluate_output(0, [])
\ No newline at end of file
-- 
GitLab