diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index d0690ef5d6edfa366484975787a22513535b3297..879af90236311eb0441df4e8eb44212ea06e9fa1 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -407,37 +407,36 @@ class SFG(AbstractOperation): # The old SFG will be deleted by Python GC return _sfg_copy() - def insert_operation(self, component: Operation, _output_comp_id: GraphID): + def insert_operation(self, component: Operation, output_comp_id: GraphID): """Insert an operation in the SFG after a given source operation. The source operation output count must match the input count of the operation as well as the output Then return a new deepcopy of the sfg with the inserted component. Arguments: component: The new component, e.g Multiplication. - _output_comp_id: The source operation GraphID to connect from. + output_comp_id: The source operation GraphID to connect from. """ # Preserve the original SFG by creating a copy. - _sfg_copy = self() - _output_comp = _sfg_copy.find_by_id(_output_comp_id) - if _output_comp is None: + sfg_copy = self() + output_comp = sfg_copy.find_by_id(output_comp_id) + if output_comp is None: return None - assert not isinstance(_output_comp, Output), \ + assert not isinstance(output_comp, Output), \ "Source operation can not be an output operation." - assert len(_output_comp.output_signals) == component.input_count, \ + assert len(output_comp.output_signals) == component.input_count, \ "Source operation output count does not match input count for component." - assert len(_output_comp.output_signals) == component.output_count, \ + assert len(output_comp.output_signals) == component.output_count, \ "Destination operation input count does not match output for component." - for index, _signal_in in enumerate(_output_comp.output_signals): - _destination = _signal_in.destination - _signal_in.set_destination(component.input(index)) - _signal_out = Signal(component.output(index)) - _signal_out.set_destination(_destination) + for index, signal_in in enumerate(output_comp.output_signals): + destination = signal_in.destination + signal_in.set_destination(component.input(index)) + destination.connect(component.output(index)) # Recreate the newly coupled SFG so that all attributes are correct. - return _sfg_copy() + return sfg_copy() def _evaluate_source(self, src: OutputPort, results: MutableResultMap, registers: MutableRegisterMap, prefix: str) -> Number: src_prefix = prefix diff --git a/test/fixtures/operation_tree.py b/test/fixtures/operation_tree.py index fc8008fa4098ca488e23766f5ff7d05711300685..e6584a1087bc84bea97f9f41da8a5bb93d36538b 100644 --- a/test/fixtures/operation_tree.py +++ b/test/fixtures/operation_tree.py @@ -1,6 +1,6 @@ import pytest -from b_asic import Addition, Constant, Signal +from b_asic import Addition, Constant, Signal, Butterfly @pytest.fixture @@ -41,6 +41,41 @@ def large_operation_tree(): """ return Addition(Addition(Constant(2), Constant(3)), Addition(Constant(4), Constant(5))) +@pytest.fixture +def large_operation_tree_names(): + """Valid addition operation connected with a large operation tree with 2 other additions and 4 constants. + With names. + 2---+ + | + v + add---+ + ^ | + | | + 3---+ v + add = (2 + 3) + (4 + 5) = 14 + 4---+ ^ + | | + v | + add---+ + ^ + | + 5---+ + """ + return Addition(Addition(Constant(2, name="constant2"), Constant(3, name="constant3")), Addition(Constant(4, name="constant4"), Constant(5, name="constant5"))) + +@pytest.fixture +def butterfly_operation_tree(): + """Valid butterfly operations connected to eachother with 3 butterfly operations and 2 constants as inputs and 2 outputs. + 2 ---+ +--- (2 + 4) ---+ +--- (6 + (-2)) ---+ +--- (4 + 8) ---> out1 = 12 + | | | | | | + v ^ v ^ v ^ + butterfly butterfly butterfly + ^ v ^ v ^ v + | | | | | | + 4 ---+ +--- (2 - 4) ---+ +--- (6 - (-2)) ---+ +--- (4 - 8) ---> out2 = -4 + """ + return Butterfly(*(Butterfly(*(Butterfly(Constant(2), Constant(4)).outputs)).outputs)) + @pytest.fixture def operation_graph_with_cycle(): """Invalid addition operation connected with an operation graph containing a cycle. diff --git a/test/test_sfg.py b/test/test_sfg.py index bc627cfe9fc39b0cb41e7211c3272b4ae4074304..ea7eb1b878ac48e8e0f5a537b7c3b155de12a9a6 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -1,6 +1,6 @@ import pytest -from b_asic import SFG, Signal, Input, Output, Constant, Addition, Multiplication, SquareRoot +from b_asic import SFG, Signal, Input, Output, Constant, Addition, Multiplication, SquareRoot, Butterfly class TestInit: @@ -249,19 +249,27 @@ class TestReplaceComponents: class TestInsertComponent: - def test_insert_component_in_sfg(self, large_operation_tree): - sfg = SFG(outputs=[Output(large_operation_tree)]) - + def test_insert_component_in_sfg(self, large_operation_tree_names): + sfg = SFG(outputs=[Output(large_operation_tree_names)]) sqrt = SquareRoot() - _sfg = sfg.insert_operation(sqrt, "c1") + _sfg = sfg.insert_operation(sqrt, sfg.find_by_name("constant4")[0].graph_id) assert _sfg.evaluate() != sfg.evaluate() + assert any([isinstance(comp, SquareRoot) for comp in _sfg.operations]) assert not any([isinstance(comp, SquareRoot) for comp in sfg.operations]) + assert not isinstance(sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation, SquareRoot) + assert isinstance(_sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation, SquareRoot) + + assert sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation is sfg.find_by_id("add3") + assert _sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation is not _sfg.find_by_id("add3") + assert _sfg.find_by_id("sqrt1").output(0).signals[0].destination.operation is _sfg.find_by_id("add3") + def test_insert_invalid_component_in_sfg(self, large_operation_tree): sfg = SFG(outputs=[Output(large_operation_tree)]) + # Should raise an exception for not matching input count to output count. add4 = Addition() with pytest.raises(Exception): sfg.insert_operation(add4, "c1") @@ -269,7 +277,24 @@ class TestInsertComponent: def test_insert_at_output(self, large_operation_tree): sfg = SFG(outputs=[Output(large_operation_tree)]) + # Should raise an exception for trying to insert an operation after an output. sqrt = SquareRoot() with pytest.raises(Exception): _sfg = sfg.insert_operation(sqrt, "out1") + def test_insert_multiple_output_ports(self, butterfly_operation_tree): + sfg = SFG(outputs=list(map(Output, butterfly_operation_tree.outputs))) + _sfg = sfg.insert_operation(Butterfly(name="New Bfly"), "bfly3") + + assert sfg.evaluate() != _sfg.evaluate() + + assert len(sfg.find_by_name("New Bfly")) == 0 + assert len(_sfg.find_by_name("New Bfly")) == 1 + + # The old bfly3 becomes bfly4 in the new sfg since it is "moved" back. + assert sfg.find_by_id("bfly4") is None + assert _sfg.find_by_id("bfly4") is not None + + assert sfg.find_by_id("bfly3").output(0).signals[0].destination.operation.name is not "New Bfly" + assert _sfg.find_by_id("bfly4").output(0).signals[0].destination.operation.name is "New Bfly" +