From 44fefc3963afed60f2b09652c03e0f0d981dfc3e Mon Sep 17 00:00:00 2001 From: Oscar Gustafsson <oscar.gustafsson@gmail.com> Date: Wed, 19 Apr 2023 10:07:28 +0200 Subject: [PATCH] Add insert_operation_after --- b_asic/signal_flow_graph.py | 73 +++++++++++++++++++++++++++++++++++++ pyproject.toml | 3 ++ test/test_sfg.py | 52 ++++++++++++++++++++++++++ 3 files changed, 128 insertions(+) diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 49f1372c..78ca679b 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -675,6 +675,79 @@ class SFG(AbstractOperation): # Recreate the newly coupled SFG so that all attributes are correct. return sfg_copy() + def insert_operation_after( + self, + output_comp_id: GraphID, + new_operation: Operation, + ) -> Optional["SFG"]: + """ + Insert an operation in the SFG after a given source operation. + + Then return a new deepcopy of the sfg with the inserted component. + + The graph_id can be an Operation or a Signal. If the operation has multiple + outputs, (copies of) the same operation will be inserted on every port. + To specify a port use ``'graph_id.port_number'``, e.g., ``'sym2p4.1'``. + + Currently, the new operation must have one input and one output. + + Parameters + ---------- + output_comp_id : GraphID + The source operation GraphID to connect from. + new_operation : Operation + The new operation, e.g. Multiplication. + """ + + # Preserve the original SFG by creating a copy. + sfg_copy = self() + if new_operation.output_count != 1 or new_operation.input_count != 1: + raise TypeError( + "Only operations with one input and one output can be inserted." + ) + if "." in output_comp_id: + output_comp_id, port_id = output_comp_id.split(".") + port_id = int(port_id) + else: + port_id = None + + output_comp = sfg_copy.find_by_id(output_comp_id) + if output_comp is None: + raise ValueError(f"Unknown component: {output_comp_id!r}") + if isinstance(output_comp, Operation): + if port_id is None: + sfg_copy._insert_operation_after_operation(output_comp, new_operation) + else: + sfg_copy._insert_operation_after_outputport( + output_comp.output(port_id), new_operation + ) + elif isinstance(output_comp, Signal): + sfg_copy._insert_operation_before_signal(output_comp, new_operation) + + # Recreate the newly coupled SFG so that all attributes are correct. + return sfg_copy() + + def _insert_operation_after_operation( + self, output_operation: Operation, new_operation: Operation + ): + for output in output_operation.outputs: + self._insert_operation_after_outputport(output, new_operation.copy()) + + def _insert_operation_after_outputport( + self, output_port: OutputPort, new_operation: Operation + ): + # Make copy as list will be updated + signal_list = output_port.signals[:] + for signal in signal_list: + signal.set_source(new_operation) + new_operation.input(0).connect(output_port) + + def _insert_operation_before_signal(self, signal: Signal, new_operation: Operation): + output_port = signal.source + output_port.remove_signal(signal) + Signal(output_port, new_operation) + signal.set_source(new_operation) + def remove_operation(self, operation_id: GraphID) -> Union["SFG", None]: """ Returns a version of the SFG where the operation with the specified GraphID diff --git a/pyproject.toml b/pyproject.toml index 3ccf58f7..a4cf1aa6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,3 +73,6 @@ skip = [ packages = ["b_asic", "test"] no_site_packages = true ignore_missing_imports = true + +[tool.coverage.report] +precision = 2 diff --git a/test/test_sfg.py b/test/test_sfg.py index cdb7d16c..2c740f8d 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -1534,3 +1534,55 @@ class TestIsConstant: def test_sfg_nested(self, sfg_nested: SFG): assert not sfg_nested.is_constant + + +class TestInsertComponentAfter: + def test_insert_component_after_in_sfg(self, large_operation_tree_names): + sfg = SFG(outputs=[Output(large_operation_tree_names)]) + sqrt = SquareRoot() + + _sfg = sfg.insert_operation_after( + sfg.find_by_name("constant4")[0].graph_id, sqrt + ) + assert _sfg.evaluate() != sfg.evaluate() + + assert any([isinstance(comp, SquareRoot) for comp in _sfg.operations]) + assert not any([isinstance(comp, SquareRoot) for comp in sfg.operations]) + + assert not isinstance( + sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation, + SquareRoot, + ) + assert isinstance( + _sfg.find_by_name("constant4")[0] + .output(0) + .signals[0] + .destination.operation, + SquareRoot, + ) + + assert sfg.find_by_name("constant4")[0].output(0).signals[ + 0 + ].destination.operation is sfg.find_by_id("add3") + assert _sfg.find_by_name("constant4")[0].output(0).signals[ + 0 + ].destination.operation is not _sfg.find_by_id("add3") + assert _sfg.find_by_id("sqrt1").output(0).signals[ + 0 + ].destination.operation is _sfg.find_by_id("add3") + + def test_insert_component_after_mimo_operation_error( + self, large_operation_tree_names + ): + sfg = SFG(outputs=[Output(large_operation_tree_names)]) + with pytest.raises( + TypeError, match="Only operations with one input and one output" + ): + sfg.insert_operation_after('constant4', SymmetricTwoportAdaptor(0.5)) + + def test_insert_component_after_unknown_component_error( + self, large_operation_tree_names + ): + sfg = SFG(outputs=[Output(large_operation_tree_names)]) + with pytest.raises(ValueError, match="Unknown component:"): + sfg.insert_operation_after('foo', SquareRoot()) -- GitLab