From 0cf7c24e85fa336aa96f9bc8f9ffa3d5847344e0 Mon Sep 17 00:00:00 2001
From: Jacob Wahlman <jacwa448@student.liu.se>
Date: Fri, 17 Apr 2020 12:03:28 +0200
Subject: [PATCH] removed the ability to multi-replace several operations

---
 b_asic/signal_flow_graph.py | 58 +++++++++++++++----------------------
 test/test_sfg.py            | 33 +++++++--------------
 2 files changed, 35 insertions(+), 56 deletions(-)

diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py
index d0ac9fe3..8ff53287 100644
--- a/b_asic/signal_flow_graph.py
+++ b/b_asic/signal_flow_graph.py
@@ -5,7 +5,7 @@ TODO: More info.
 
 from typing import NewType, List, Iterable, Sequence, Dict, Optional, DefaultDict, Set
 from numbers import Number
-from collections import defaultdict, deque, Iterable as CIterable
+from collections import defaultdict, deque
 
 from b_asic.port import SignalSourceProvider, OutputPort
 from b_asic.operation import Operation, AbstractOperation
@@ -350,50 +350,40 @@ class SFG(AbstractOperation):
             input_values.append(self._evaluate_source(input_src))
         return src.operation.evaluate_output(src.index, input_values)
 
-    def replace_component(self, _component: List[Operation], _id: List[GraphID] = None, _type: List[Operation] = None):
+    def replace_component(self, component: Operation, _component: Operation = None, _id: GraphID = None):
         """Find and replace all components matching either on GraphID, Type or both.
         Then return a new deepcopy of the sfg with the replaced component.
 
         Arguments:
-        _component: The list of new component(s), e.g Multiplication
-    
+        component: The new component(s), e.g Multiplication
+
         Keyword arguments:
-        _id: The list of GraphID(s) to match the component to replace.
-        _type: The list of Type(s) to match the component to replace.
+        _component: The specific component to replace.
+        _id: The GraphID to match the component to replace.
         """
 
-        _id = [_id] if not isinstance(_id, CIterable) or isinstance(_id, str) else _id
-        _type = [_type] if not isinstance(_type, CIterable) else _type
-        _component = [_component] if not isinstance(_component, CIterable) else _component
-        components = set()
-
-        for comp_id in _id:
-            if comp_id is None:
-                continue
+        assert _component is not None or _id is not None, \
+            "Define either operation to replace or GraphID of operation"
 
-            components |= {value for key, value in self._components_by_id.items() if key == comp_id}
+        if _id is not None:
+            _component = self.find_by_id(_id)
 
-        for comp_type in _type:
-            if comp_type is None:
-                continue
+        assert _component is not None and isinstance(_component, Operation), \
+            "No operation matching the criteria found"
+        assert _component.output_count == component.output_count, \
+            "The output count may not differ between the operations"
+        assert _component.input_count == component.input_count, \
+            "The input count may not differ between the operations"
 
-            components |= {comp for comp in self.components if isinstance(comp, comp_type)}
+        for index_in, _inp in enumerate(_component.inputs):
+            for _signal in _inp.signals:
+                _signal.remove_destination()
+                _signal.set_destination(component.input(index_in))
         
-        assert sum([comp.output_count for comp in _component]) == sum([comp.output_count for comp in components])
-        assert sum([comp.input_count for comp in _component]) == sum([comp.input_count for comp in components])
-
-        for index, comp in enumerate(components):
-            component = _component if not isinstance(_component, CIterable) else _component[index]
-
-            for index_in, _inp in enumerate(comp.inputs):
-                for _signal in _inp.signals:
-                    _signal.remove_destination()
-                    _signal.set_destination(component.input(index_in))
-            
-            for index_out, _out in enumerate(comp.outputs):
-                for _signal in _out.signals:
-                    _signal.remove_source()
-                    _signal.set_source(component.output(index_out))
+        for index_out, _out in enumerate(_component.outputs):
+            for _signal in _out.signals:
+                _signal.remove_source()
+                _signal.set_source(component.output(index_out))
 
         # The old SFG will be deleted by Python GC
         return self.deep_copy()
diff --git a/test/test_sfg.py b/test/test_sfg.py
index bf656e04..cb94ab27 100644
--- a/test/test_sfg.py
+++ b/test/test_sfg.py
@@ -118,7 +118,7 @@ class TestComponents:
 
 class TestReplaceComponents:
 
-    def test_replace_addition(self, operation_tree):
+    def test_replace_addition_by_id(self, operation_tree):
         sfg = SFG(outputs=[Output(operation_tree)])
         component_id = "add1"
 
@@ -126,6 +126,15 @@ class TestReplaceComponents:
         assert component_id not in sfg._components_by_id.keys()
         assert "Multi" in sfg._components_by_name.keys()
 
+    def test_replace_addition_by_component(self, operation_tree):
+        sfg = SFG(outputs=[Output(operation_tree)])
+        component_id = "add1"
+        component = sfg.find_by_id(component_id)
+
+        sfg = sfg.replace_component(Multiplication(name="Multi"), _component=component)
+        assert component_id not in sfg._components_by_id.keys()
+        assert "Multi" in sfg._components_by_name.keys()
+
     def test_replace_addition_large_tree(self, large_operation_tree):
         sfg = SFG(outputs=[Output(large_operation_tree)])
         component_id = "add2"
@@ -148,26 +157,6 @@ class TestReplaceComponents:
         sfg = sfg.replace_component(Multiplication(name="Multi"), _id=component_id)
         assert "Multi" in sfg._components_by_name.keys()
 
-    def test_replace_several_components(self, large_operation_tree):
-        sfg = SFG(outputs=[Output(large_operation_tree)])
-        component_id = ("add1", "add2", "add3")
-        replace_comp = (Multiplication(name="Multi"), Multiplication(name="Multi"), Multiplication(name="Multi"))
-        
-        sfg = sfg.replace_component(replace_comp, _id=component_id)
-        assert all([_id not in sfg._components_by_id.keys() for _id in component_id])
-        assert "Multi" in sfg._components_by_name.keys()
-        assert len(sfg._components_by_name["Multi"]) == 3
-
-    def test_replace_all_of_type_components(self, large_operation_tree):
-        sfg = SFG(outputs=[Output(large_operation_tree)])
-        component_type = Addition
-        replace_comp = (Multiplication(name="Multi"), Multiplication(name="Multi"), Multiplication(name="Multi"))
-
-        sfg = sfg.replace_component(replace_comp, _type=component_type)
-        assert all([not isinstance(_type, Addition) for _type in sfg.components])
-        assert "Multi" in sfg._components_by_name.keys()
-        assert len(sfg._components_by_name["Multi"]) == 3
-
     def test_no_match_on_replace(self, large_operation_tree):
         sfg = SFG(outputs=[Output(large_operation_tree)])
         component_id = "addd1"
@@ -184,7 +173,7 @@ class TestReplaceComponents:
         component_id = "c1"
 
         try:
-            sfg = sfg.replace_component(Multiplication(name="Multi"), component_id)
+            sfg = sfg.replace_component(Multiplication(name="Multi"), _id=component_id)
         except AssertionError:
             assert True
         else:
-- 
GitLab