diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 2f2a024053d41b2f15a4eb18ceb239c5832ff95d..324ffb7d12ada379ad7352fb3f9113a0a703b882 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -350,6 +350,18 @@ class SFG(AbstractOperation): input_values.append(self._evaluate_source(input_src)) return src.operation.evaluate_output(src.index, input_values) + def replace_component(self, component_type, _id=None, _type=None): + """Find and replace all components matching either on GraphID, Type or both. + Then return a new deepcopy of the sfg with the replaced component. + + Arguments: + component_type: The type of the new component, e.g Multiplication + + Keyword arguments: + _id: The GraphID to match the component to replace. + _type: The Type to match the component to replace. + """ + pass def __str__(self): """Prints operations, inputs and outputs in a SFG diff --git a/test/test_sfg.py b/test/test_sfg.py index af9dfe179751fd620d5880494215c3b1cfb8571b..a133c0877d76fb4f2969476b2bd7c341b5bcfd7b 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -114,3 +114,75 @@ class TestComponents: assert set([comp.name for comp in mac_sfg.components]) == { "INP1", "INP2", "INP3", "ADD1", "ADD2", "MUL1", "OUT1", "S1", "S2", "S3", "S4", "S5", "S6", "S7"} + + +class TestReplaceComponents: + + def test_replace_addition(self, operation_tree): + sfg = SFG(outputs=[Output(operation_tree)]) + component_id = "add3" + + sfg = sfg.replace_component(Multiplication(name="Multi"), _id=component_id) + assert component_id not in sfg._components_by_id.keys() + assert "Multi" in sfg._components_by_name.keys() + + def test_replace_addition_large_tree(self, large_operation_tree): + sfg = SFG(outputs=[Output(large_operation_tree)]) + component_id = "add2" + + sfg = sfg.replace_component(Multiplication(name="Multi"), _id=component_id) + assert component_id not in sfg._components_by_id.keys() + assert "Multi" in sfg._components_by_name.keys() + + def test_replace_no_input_component(self, operation_tree): + sfg = SFG(outputs=[Output(operation_tree)]) + component_id = "c1" + + sfg = sfg.replace_component(Constant(10), _id=component_id) + assert component_id not in sfg._components_by_id.keys() + + def test_replace_no_destination_component(self, operation_tree): + sfg = SFG(outputs=[Output(operation_tree)]) + component_id = "add1" + + sfg.replace_component(Multiplication(name="Multi"), _id=component_id) + assert component_id not in sfg._components_by_id.keys() + assert "Multi" in sfg._components_by_name.keys() + + def test_replace_several_components(self, large_operation_tree): + sfg = SFG(outputs=[Output(large_operation_tree)]) + component_id = ("add1", "add2", "add3") + + sfg = sfg.replace_component(Multiplication(name="Multi"), _id=component_id) + assert all([_id not in sfg._components_by_id.keys() for _id in component_id]) + assert "Multi" in sfg._components_by_name.keys() + assert len(sfg._components_by_name["Multi"]) == 3 + + def test_replace_all_of_type_components(self, large_operation_tree): + sfg = SFG(outputs=[Output(large_operation_tree)]) + component_type = Addition + + sfg.replace_component(Multiplication(name="Multi"), _type=component_type) + assert all([not isinstance(_type, Addition) for _type in sfg.components]) + assert "Multi" in sfg._components_by_name.keys() + assert len(sfg._components_by_name["Multi"]) == 3 + + def test_no_match_on_replace(self, large_operation_tree): + sfg = SFG(outputs=[Output(large_operation_tree)]) + component_id = "addd1" + + _sfg = sfg.replace_component(Multiplication(name="Multi"), _id=component_id) + assert all([comp in sfg.components for comp in _sfg]) + assert "Multi" not in sfg._components_by_name.keys() + + def test_not_equal_input(self, large_operation_tree): + sfg = SFG(outputs=[Output(large_operation_tree)]) + component_id = "c1" + + # Couldn't import pytest.raises + try: + sfg = sfg.replace_component(Multiplication(name="Multi"), component_id) + except AssertionError: + assert True + + assert False