From 0aacbfa39451e240dba318b42894f1d22718cace Mon Sep 17 00:00:00 2001 From: Jacob Wahlman <jacwa448@student.liu.se> Date: Wed, 15 Apr 2020 14:19:39 +0200 Subject: [PATCH] Implemented the replace method and fixed some of the tests --- b_asic/signal_flow_graph.py | 47 ++++++++++++++++++++++++++++++++----- test/test_sfg.py | 31 +++++++++++++----------- 2 files changed, 58 insertions(+), 20 deletions(-) diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 324ffb7d..d0ac9fe3 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -5,7 +5,7 @@ TODO: More info. from typing import NewType, List, Iterable, Sequence, Dict, Optional, DefaultDict, Set from numbers import Number -from collections import defaultdict, deque +from collections import defaultdict, deque, Iterable as CIterable from b_asic.port import SignalSourceProvider, OutputPort from b_asic.operation import Operation, AbstractOperation @@ -350,18 +350,53 @@ class SFG(AbstractOperation): input_values.append(self._evaluate_source(input_src)) return src.operation.evaluate_output(src.index, input_values) - def replace_component(self, component_type, _id=None, _type=None): + def replace_component(self, _component: List[Operation], _id: List[GraphID] = None, _type: List[Operation] = None): """Find and replace all components matching either on GraphID, Type or both. Then return a new deepcopy of the sfg with the replaced component. Arguments: - component_type: The type of the new component, e.g Multiplication + _component: The list of new component(s), e.g Multiplication Keyword arguments: - _id: The GraphID to match the component to replace. - _type: The Type to match the component to replace. + _id: The list of GraphID(s) to match the component to replace. + _type: The list of Type(s) to match the component to replace. """ - pass + + _id = [_id] if not isinstance(_id, CIterable) or isinstance(_id, str) else _id + _type = [_type] if not isinstance(_type, CIterable) else _type + _component = [_component] if not isinstance(_component, CIterable) else _component + components = set() + + for comp_id in _id: + if comp_id is None: + continue + + components |= {value for key, value in self._components_by_id.items() if key == comp_id} + + for comp_type in _type: + if comp_type is None: + continue + + components |= {comp for comp in self.components if isinstance(comp, comp_type)} + + assert sum([comp.output_count for comp in _component]) == sum([comp.output_count for comp in components]) + assert sum([comp.input_count for comp in _component]) == sum([comp.input_count for comp in components]) + + for index, comp in enumerate(components): + component = _component if not isinstance(_component, CIterable) else _component[index] + + for index_in, _inp in enumerate(comp.inputs): + for _signal in _inp.signals: + _signal.remove_destination() + _signal.set_destination(component.input(index_in)) + + for index_out, _out in enumerate(comp.outputs): + for _signal in _out.signals: + _signal.remove_source() + _signal.set_source(component.output(index_out)) + + # The old SFG will be deleted by Python GC + return self.deep_copy() def __str__(self): """Prints operations, inputs and outputs in a SFG diff --git a/test/test_sfg.py b/test/test_sfg.py index a133c087..bf656e04 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -120,7 +120,7 @@ class TestReplaceComponents: def test_replace_addition(self, operation_tree): sfg = SFG(outputs=[Output(operation_tree)]) - component_id = "add3" + component_id = "add1" sfg = sfg.replace_component(Multiplication(name="Multi"), _id=component_id) assert component_id not in sfg._components_by_id.keys() @@ -131,29 +131,29 @@ class TestReplaceComponents: component_id = "add2" sfg = sfg.replace_component(Multiplication(name="Multi"), _id=component_id) - assert component_id not in sfg._components_by_id.keys() assert "Multi" in sfg._components_by_name.keys() def test_replace_no_input_component(self, operation_tree): sfg = SFG(outputs=[Output(operation_tree)]) component_id = "c1" - + _value = sfg.find_by_id(component_id).value + sfg = sfg.replace_component(Constant(10), _id=component_id) - assert component_id not in sfg._components_by_id.keys() + assert _value != sfg.find_by_id(component_id).value def test_replace_no_destination_component(self, operation_tree): sfg = SFG(outputs=[Output(operation_tree)]) component_id = "add1" - sfg.replace_component(Multiplication(name="Multi"), _id=component_id) - assert component_id not in sfg._components_by_id.keys() + sfg = sfg.replace_component(Multiplication(name="Multi"), _id=component_id) assert "Multi" in sfg._components_by_name.keys() def test_replace_several_components(self, large_operation_tree): sfg = SFG(outputs=[Output(large_operation_tree)]) component_id = ("add1", "add2", "add3") + replace_comp = (Multiplication(name="Multi"), Multiplication(name="Multi"), Multiplication(name="Multi")) - sfg = sfg.replace_component(Multiplication(name="Multi"), _id=component_id) + sfg = sfg.replace_component(replace_comp, _id=component_id) assert all([_id not in sfg._components_by_id.keys() for _id in component_id]) assert "Multi" in sfg._components_by_name.keys() assert len(sfg._components_by_name["Multi"]) == 3 @@ -161,8 +161,9 @@ class TestReplaceComponents: def test_replace_all_of_type_components(self, large_operation_tree): sfg = SFG(outputs=[Output(large_operation_tree)]) component_type = Addition + replace_comp = (Multiplication(name="Multi"), Multiplication(name="Multi"), Multiplication(name="Multi")) - sfg.replace_component(Multiplication(name="Multi"), _type=component_type) + sfg = sfg.replace_component(replace_comp, _type=component_type) assert all([not isinstance(_type, Addition) for _type in sfg.components]) assert "Multi" in sfg._components_by_name.keys() assert len(sfg._components_by_name["Multi"]) == 3 @@ -171,18 +172,20 @@ class TestReplaceComponents: sfg = SFG(outputs=[Output(large_operation_tree)]) component_id = "addd1" - _sfg = sfg.replace_component(Multiplication(name="Multi"), _id=component_id) - assert all([comp in sfg.components for comp in _sfg]) - assert "Multi" not in sfg._components_by_name.keys() + try: + sfg = sfg.replace_component(Multiplication(name="Multi"), _id=component_id) + except AssertionError: + assert True + else: + assert False def test_not_equal_input(self, large_operation_tree): sfg = SFG(outputs=[Output(large_operation_tree)]) component_id = "c1" - # Couldn't import pytest.raises try: sfg = sfg.replace_component(Multiplication(name="Multi"), component_id) except AssertionError: assert True - - assert False + else: + assert False -- GitLab