From 729c11beced60deaa90e31163bbfcfc0e7c13769 Mon Sep 17 00:00:00 2001
From: Kevin <Kevin>
Date: Sun, 3 May 2020 14:18:52 +0200
Subject: [PATCH] implemented operation replacement and changed some test since
 we return a deepcopy of the sfg

---
 b_asic/signal_flow_graph.py | 42 ++++++++++++++++++++++---------------
 test/test_sfg.py            | 10 ++++-----
 2 files changed, 30 insertions(+), 22 deletions(-)

diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py
index 19675e33..28e0f66b 100644
--- a/b_asic/signal_flow_graph.py
+++ b/b_asic/signal_flow_graph.py
@@ -406,7 +406,7 @@ class SFG(AbstractOperation):
         return self()
 
     def replace_operations(self, operation_ids: Sequence[GraphID], operation: Operation):
-        """Replace multiple operations in the sfg with a operation of equivalent functionallity with the same number of inputs and outputs.
+        """Replace multiple operations in the sfg with a operation with the same amount of inputs and outputs.
         Then return a new deepcopy of the sfg with the replaced operations.
 
         Arguments:
@@ -414,28 +414,36 @@ class SFG(AbstractOperation):
         operation: The operation used for replacement.
         """
 
-        operations = [self.find_by_id(_id) for _id in operation_ids]
+        inputs = []
+        outputs = []
 
-        assert sum(o.input_count + o.output_count for o in operations) == operation.input_count + operation.output_count, \
-            "The input and output count must match"
-
-        # Create a copy of the sfg for manipulating
-        _sfg = self()
+        for _operation in self.operations:
+            if _operation.graph_id not in operation_ids:
+                continue
+            
+            # Retrive input operations
+            for _signal in _operation.input_signals:
+                if _signal.source.operation.graph_id not in operation_ids:
+                    inputs.append(_signal.source.operation)
 
-        for operation in operations:
-            operation
+            # Retrive output operations
+            for _signal in _operation.output_signals:
+                if _signal.destination.operation.graph_id not in operation_ids:
+                    outputs.append(_signal.destination.operation)
 
+        assert len(inputs) + len(outputs) == \
+            operation.input_count + operation.output_count, "The input and output count must match"
 
-        for _index, _inp in enumerate(inputs):
-            for _signal in _inp.output_signals:
+        for index_in, _input in enumerate(inputs):
+            for _signal in _input.output_signals:
                 _signal.remove_destination()
-                _signal.set_destination(operation.input(_index))
+                _signal.set_destination(operation.input(index_in))
 
-        for _index, _out in enumerate(outputs):
-            for _signal in _out.input_signals:
-                _signal.remove_destination()
-                _signal.set_source(operation.output(_index))
-            
+        for index_out, _output in enumerate(outputs):
+            for _signal in _output.input_signals:
+                _signal.remove_source()
+                _signal.set_source(operation.output(index_out))
+        
         return self()
 
     def _evaluate_source(self, src: OutputPort, results: MutableResultMap, registers: MutableRegisterMap, prefix: str) -> Number:
diff --git a/test/test_sfg.py b/test/test_sfg.py
index 8c00cd73..2023ffcf 100644
--- a/test/test_sfg.py
+++ b/test/test_sfg.py
@@ -268,22 +268,22 @@ class TestReplaceOperations:
         mad1 = MAD()
         _sfg = sfg.replace_operations(['add1', 'mul1'], mad1)
 
-        assert mad1 in _sfg.operations
+        assert _sfg.find_by_id('mad1') is not None
         assert {add1, mul1} not in _sfg.operations
 
     def test_replace_neg_add_with_sub(self):
         in1 = Input()
         in2 = Input()
-        neg1 = ConstantMultiplication(-1, in1)
+        neg1 = ConstantMultiplication(-1, in1, 'neg1')
         add1 = neg1 + in2
         out1 = Output(add1)
         sfg = SFG(inputs=[in1, in2], outputs=[out1])
 
         sub1 = Subtraction()
-        sfg.replace_operations(['add1, neg1'], sub1)
+        _sfg = sfg.replace_operations(['add1', 'neg1'], sub1)
 
-        assert sub1 in sfg.operations
-        assert {add1, neg1} not in sfg.operations
+        assert _sfg.find_by_id('sub1') is not None
+        assert {add1, neg1} not in _sfg.operations
 
     def test_different_input_output_count(self, operation_tree):
         sfg = SFG(outputs=[Output(operation_tree)])
-- 
GitLab