From 44fefc3963afed60f2b09652c03e0f0d981dfc3e Mon Sep 17 00:00:00 2001
From: Oscar Gustafsson <oscar.gustafsson@gmail.com>
Date: Wed, 19 Apr 2023 10:07:28 +0200
Subject: [PATCH] Add insert_operation_after

---
 b_asic/signal_flow_graph.py | 73 +++++++++++++++++++++++++++++++++++++
 pyproject.toml              |  3 ++
 test/test_sfg.py            | 52 ++++++++++++++++++++++++++
 3 files changed, 128 insertions(+)

diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py
index 49f1372c..78ca679b 100644
--- a/b_asic/signal_flow_graph.py
+++ b/b_asic/signal_flow_graph.py
@@ -675,6 +675,79 @@ class SFG(AbstractOperation):
         # Recreate the newly coupled SFG so that all attributes are correct.
         return sfg_copy()
 
+    def insert_operation_after(
+        self,
+        output_comp_id: GraphID,
+        new_operation: Operation,
+    ) -> Optional["SFG"]:
+        """
+        Insert an operation in the SFG after a given source operation.
+
+        Then return a new deepcopy of the sfg with the inserted component.
+
+        The graph_id can be an Operation or a Signal. If the operation has multiple
+        outputs, (copies of) the same operation will be inserted on every port.
+        To specify a port use ``'graph_id.port_number'``, e.g., ``'sym2p4.1'``.
+
+        Currently, the new operation must have one input and one output.
+
+        Parameters
+        ----------
+        output_comp_id : GraphID
+            The source operation GraphID to connect from.
+        new_operation : Operation
+            The new operation, e.g. Multiplication.
+        """
+
+        # Preserve the original SFG by creating a copy.
+        sfg_copy = self()
+        if new_operation.output_count != 1 or new_operation.input_count != 1:
+            raise TypeError(
+                "Only operations with one input and one output can be inserted."
+            )
+        if "." in output_comp_id:
+            output_comp_id, port_id = output_comp_id.split(".")
+            port_id = int(port_id)
+        else:
+            port_id = None
+
+        output_comp = sfg_copy.find_by_id(output_comp_id)
+        if output_comp is None:
+            raise ValueError(f"Unknown component: {output_comp_id!r}")
+        if isinstance(output_comp, Operation):
+            if port_id is None:
+                sfg_copy._insert_operation_after_operation(output_comp, new_operation)
+            else:
+                sfg_copy._insert_operation_after_outputport(
+                    output_comp.output(port_id), new_operation
+                )
+        elif isinstance(output_comp, Signal):
+            sfg_copy._insert_operation_before_signal(output_comp, new_operation)
+
+        # Recreate the newly coupled SFG so that all attributes are correct.
+        return sfg_copy()
+
+    def _insert_operation_after_operation(
+        self, output_operation: Operation, new_operation: Operation
+    ):
+        for output in output_operation.outputs:
+            self._insert_operation_after_outputport(output, new_operation.copy())
+
+    def _insert_operation_after_outputport(
+        self, output_port: OutputPort, new_operation: Operation
+    ):
+        # Make copy as list will be updated
+        signal_list = output_port.signals[:]
+        for signal in signal_list:
+            signal.set_source(new_operation)
+        new_operation.input(0).connect(output_port)
+
+    def _insert_operation_before_signal(self, signal: Signal, new_operation: Operation):
+        output_port = signal.source
+        output_port.remove_signal(signal)
+        Signal(output_port, new_operation)
+        signal.set_source(new_operation)
+
     def remove_operation(self, operation_id: GraphID) -> Union["SFG", None]:
         """
         Returns a version of the SFG where the operation with the specified GraphID
diff --git a/pyproject.toml b/pyproject.toml
index 3ccf58f7..a4cf1aa6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -73,3 +73,6 @@ skip = [
 packages = ["b_asic", "test"]
 no_site_packages = true
 ignore_missing_imports = true
+
+[tool.coverage.report]
+precision = 2
diff --git a/test/test_sfg.py b/test/test_sfg.py
index cdb7d16c..2c740f8d 100644
--- a/test/test_sfg.py
+++ b/test/test_sfg.py
@@ -1534,3 +1534,55 @@ class TestIsConstant:
 
     def test_sfg_nested(self, sfg_nested: SFG):
         assert not sfg_nested.is_constant
+
+
+class TestInsertComponentAfter:
+    def test_insert_component_after_in_sfg(self, large_operation_tree_names):
+        sfg = SFG(outputs=[Output(large_operation_tree_names)])
+        sqrt = SquareRoot()
+
+        _sfg = sfg.insert_operation_after(
+            sfg.find_by_name("constant4")[0].graph_id, sqrt
+        )
+        assert _sfg.evaluate() != sfg.evaluate()
+
+        assert any([isinstance(comp, SquareRoot) for comp in _sfg.operations])
+        assert not any([isinstance(comp, SquareRoot) for comp in sfg.operations])
+
+        assert not isinstance(
+            sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation,
+            SquareRoot,
+        )
+        assert isinstance(
+            _sfg.find_by_name("constant4")[0]
+            .output(0)
+            .signals[0]
+            .destination.operation,
+            SquareRoot,
+        )
+
+        assert sfg.find_by_name("constant4")[0].output(0).signals[
+            0
+        ].destination.operation is sfg.find_by_id("add3")
+        assert _sfg.find_by_name("constant4")[0].output(0).signals[
+            0
+        ].destination.operation is not _sfg.find_by_id("add3")
+        assert _sfg.find_by_id("sqrt1").output(0).signals[
+            0
+        ].destination.operation is _sfg.find_by_id("add3")
+
+    def test_insert_component_after_mimo_operation_error(
+        self, large_operation_tree_names
+    ):
+        sfg = SFG(outputs=[Output(large_operation_tree_names)])
+        with pytest.raises(
+            TypeError, match="Only operations with one input and one output"
+        ):
+            sfg.insert_operation_after('constant4', SymmetricTwoportAdaptor(0.5))
+
+    def test_insert_component_after_unknown_component_error(
+        self, large_operation_tree_names
+    ):
+        sfg = SFG(outputs=[Output(large_operation_tree_names)])
+        with pytest.raises(ValueError, match="Unknown component:"):
+            sfg.insert_operation_after('foo', SquareRoot())
-- 
GitLab