diff --git a/b_asic/operation.py b/b_asic/operation.py index a8dc7a96d07d139fe388fdb3b723a9b7c136a925..45ce01a00f976f2ff06739fbd8287990f8ba5bb4 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -93,7 +93,7 @@ class Operation(GraphComponent, SignalSourceProvider): raise NotImplementedError @abstractmethod - def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[str, Number]] = None, registers: Optional[MutableMapping[str, Number]] = None, prefix: str = "") -> Number: + def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[str, Optional[Number]]] = None, registers: Optional[MutableMapping[str, Number]] = None, prefix: str = "") -> Number: """Evaluate the output at the given index of this operation with the given input values. The results parameter will be used to store any intermediate results for caching. The registers parameter will be used to get the current value of any intermediate registers that are encountered, and be updated with their new values. @@ -169,6 +169,17 @@ class AbstractOperation(Operation, AbstractGraphComponent): """ raise NotImplementedError + def _find_result(self, prefix: str, index: int, results: MutableMapping[str, Optional[Number]]) -> Optional[Number]: + key = results_key(self.output_count, prefix, index) + if key in results: + value = results[key] + if value is None: + raise RuntimeError(f"Direct feedback loop detected when evaluating operation.") + return value + + results[key] = None + return None + def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Addition, ConstantAddition]": # Import here to avoid circular imports. from b_asic.core_operations import Addition, ConstantAddition @@ -223,36 +234,34 @@ class AbstractOperation(Operation, AbstractGraphComponent): def output(self, i: int) -> OutputPort: return self._output_ports[i] - def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[str, Number]] = None, registers: Optional[MutableMapping[str, Number]] = None, prefix: str = "") -> Number: + def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[str, Optional[Number]]] = None, registers: Optional[MutableMapping[str, Number]] = None, prefix: str = "") -> Number: if index < 0 or index >= self.output_count: raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {index})") if results is None: results = {} if registers is None: registers = {} - - key = results_key(self.output_count, prefix, index) - if key in results: - return results[key] - result = self.evaluate(*input_values) - if isinstance(result, collections.Sequence): - if len(result) != self.output_count: - raise RuntimeError(f"Operation evaluated to incorrect number of outputs (expected {self.output_count}, got {len(result)})") - elif isinstance(result, Number): + result = self._find_result(prefix, index, results) + if result is not None: + return result + values = self.evaluate(*input_values) + if isinstance(values, collections.Sequence): + if len(values) != self.output_count: + raise RuntimeError(f"Operation evaluated to incorrect number of outputs (expected {self.output_count}, got {len(values)})") + elif isinstance(values, Number): if self.output_count != 1: raise RuntimeError(f"Operation evaluated to incorrect number of outputs (expected {self.output_count}, got 1)") - result = (result,) + values = (values,) else: - raise RuntimeError(f"Operation evaluated to invalid type (expected Sequence/Number, got {result.__class__.__name__})") + raise RuntimeError(f"Operation evaluated to invalid type (expected Sequence/Number, got {values.__class__.__name__})") if self.output_count == 1: - results[key] = result[index] + results[results_key(self.output_count, prefix, index)] = values[index] else: - for i, value in enumerate(result): - results[results_key(self.output_count, prefix, i)] = value - return result[index] - + for i in range(self.output_count): + results[results_key(self.output_count, prefix, i)] = values[i] + return values[index] def evaluate_outputs(self, input_values: Sequence[Number], results: MutableMapping[str, Number], registers: MutableMapping[str, Number], prefix: str = "") -> Sequence[Number]: return [self.evaluate_output(i, input_values, results, registers, prefix) for i in range(self.output_count)] diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index b796f1609bdfacc6008c7979a8b47f3e47b52996..baa0c0040c7490c39bc1937b413590d37be25f98 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -102,7 +102,7 @@ class SFG(AbstractOperation): n = len(result) return None if n == 0 else result[0] if n == 1 else result - def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[str, Number]] = None, registers: Optional[MutableMapping[str, Number]] = None, prefix: str = "") -> Number: + def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[str, Optional[Number]]] = None, registers: Optional[MutableMapping[str, Number]] = None, prefix: str = "") -> Number: if index < 0 or index >= self.output_count: raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {index})") if len(input_values) != self.input_count: @@ -111,16 +111,17 @@ class SFG(AbstractOperation): results = {} if registers is None: registers = {} + + result = self._find_result(prefix, index, results) + if result is not None: + return result # Set the values of our input operations to the given input values. for op, arg in zip(self._input_operations, input_values): op.value = arg - - key = results_key(self.output_count, prefix, index) - if key in results: - return results[key] + value = self._evaluate_source(self._output_operations[index].input(0).signals[0].source, results, registers, prefix) - results[key] = value + results[results_key(self.output_count, prefix, index)] = value return value def split(self) -> Iterable[Operation]: @@ -206,10 +207,7 @@ class SFG(AbstractOperation): if op_prefix: op_prefix += "." op_prefix += src.operation.graph_id - key = results_key(src.operation.output_count, op_prefix, src.index) - if key in results: - return results[key] input_values = [self._evaluate_source(input_port.signals[0].source, results, registers, prefix) for input_port in src.operation.inputs] value = src.operation.evaluate_output(src.index, input_values, results, registers, op_prefix) - results[key] = value + results[results_key(src.operation.output_count, op_prefix, src.index)] = value return value \ No newline at end of file diff --git a/b_asic/special_operations.py b/b_asic/special_operations.py index 36ebf2d82cb8f6dd9646accd7866b39cd5f57a91..5a43ded109d59f8cf87eb08142776febde3868f3 100644 --- a/b_asic/special_operations.py +++ b/b_asic/special_operations.py @@ -70,7 +70,7 @@ class Register(AbstractOperation): def evaluate(self, a): return self.param("initial_value") - def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[str, Number]] = None, registers: Optional[MutableMapping[str, Number]] = None, prefix: str = ""): + def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[str, Optional[Number]]] = None, registers: Optional[MutableMapping[str, Number]] = None, prefix: str = ""): if index != 0: raise IndexError(f"Output index out of range (expected 0-0, got {index})") if len(input_values) != 1: diff --git a/test/fixtures/operation_tree.py b/test/fixtures/operation_tree.py index 327586757ed4ae2127d9fb5f8beebfd38fe93dc7..3ac35110c79b1fe3fc7af86e6e0ad975be1db396 100644 --- a/test/fixtures/operation_tree.py +++ b/test/fixtures/operation_tree.py @@ -46,7 +46,6 @@ def operation_graph_with_cycle(): | 6 """ - c1 = Constant(7) - add1 = Addition(None, c1) + add1 = Addition(None, Constant(7)) add1.input(0).connect(add1) - return Addition(add1, c1) + return Addition(add1, Constant(6)) diff --git a/test/fixtures/signal_flow_graph.py b/test/fixtures/signal_flow_graph.py index c8e1dc9aae987349f0fc830b9ee58952781a44ed..06cfed56f8dc95333b53a10ff012fc4baa3cf7f9 100644 --- a/test/fixtures/signal_flow_graph.py +++ b/test/fixtures/signal_flow_graph.py @@ -25,4 +25,6 @@ def sfg_two_inputs_two_outputs(): add2=Addition(add1, in2) out1=Output(add1) out2=Output(add2) - return SFG(inputs = [in1, in2], outputs = [out1, out2]) \ No newline at end of file + return SFG(inputs = [in1, in2], outputs = [out1, out2]) + +# TODO: Testa nestad sfg \ No newline at end of file diff --git a/test/test_sfg.py b/test/test_sfg.py index 76c5178ec1c1482d63e2d295567f966cdc1a85eb..216e4fd90c75d8e181d7a69828c1e395d7868c2f 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -25,9 +25,14 @@ class TestConstructor: class TestEvaluation: def test_evaluate_output(self, operation_tree): - sfg = SFG(outputs=[Output(operation_tree)]) + sfg = SFG(outputs = [Output(operation_tree)]) assert sfg.evaluate_output(0, []) == 5 def test_evaluate_output_large(self, large_operation_tree): - sfg = SFG(outputs=[Output(large_operation_tree)]) - assert sfg.evaluate_output(0, []) == 14 \ No newline at end of file + sfg = SFG(outputs = [Output(large_operation_tree)]) + assert sfg.evaluate_output(0, []) == 14 + + def test_evaluate_output_cycle(self, operation_graph_with_cycle): + with pytest.raises(Exception): + sfg = SFG(outputs = [Output(operation_graph_with_cycle)]) + sfg.evaluate_output(0, []) \ No newline at end of file