diff --git a/b_asic/operation.py b/b_asic/operation.py index bb66e26b30a4a14b116800300d5a00d0855945f2..92c7b2b04029d039b125970774d84993ae1ae88b 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -334,7 +334,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): # Import here to avoid circular imports. from b_asic.special_operations import Input try: - result = self.evaluate([Input()] * self.input_count) + result = self.evaluate(*([Input()] * self.input_count)) if isinstance(result, collections.Sequence) and all(isinstance(e, Operation) for e in result): return result if isinstance(result, Operation): diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index bcfc9eebd7599128a92c83beafa64209fce6e43a..61449fc3e5cf87167ddfbd2203eb8eb2d3f03978 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -234,6 +234,40 @@ class SFG(AbstractOperation): results[self.key(index, prefix)] = value return value + def connect_external_signals_to_components(self) -> bool: + """ Connects any external signals to this SFG's internal operations. This SFG becomes unconnected to the SFG + it is a component off, causing it to become invalid afterwards. Returns True if succesful, False otherwise. """ + if len(self.inputs) != len(self.input_operations): + raise IndexError(f"Number of inputs does not match the number of input_operations in SFG.") + if len(self.outputs) != len(self.output_operations): + raise IndexError(f"Number of outputs does not match the number of output_operations SFG.") + if len(self.input_signals) == 0: + return False + if len(self.output_signals) == 0: + return False + + # For each input_signal, connect it to the corresponding operation + for port, input_operation in zip(self.inputs, self.input_operations): + dest = input_operation.output(0).signals[0].destination + dest.clear() + port.signals[0].set_destination(dest) + # For each output_signal, connect it to the corresponding operation + for port, output_operation in zip(self.outputs, self.output_operations): + src = output_operation.input(0).signals[0].source + src.clear() + port.signals[0].set_source(src) + return True + + @property + def input_operations(self) -> Sequence[Operation]: + """Get the internal input operations in the same order as their respective input ports.""" + return self._input_operations + + @property + def output_operations(self) -> Sequence[Operation]: + """Get the internal output operations in the same order as their respective output ports.""" + return self._output_operations + def split(self) -> Iterable[Operation]: return self.operations diff --git a/test/test_sfg.py b/test/test_sfg.py index 222bfe237d287891eaf72442e83cec5c4012178c..c188351f916ace6803b2f81eae24011e4aa4fcc7 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -1,6 +1,6 @@ import pytest -from b_asic import SFG, Signal, Input, Output, Constant, Addition, Multiplication +from b_asic import SFG, Signal, Input, Output, Constant, Addition, Multiplication, Subtraction class TestInit: @@ -254,3 +254,153 @@ class TestReplaceComponents: assert True else: assert False + +class TestConnectExternalSignalsToComponentsSoloComp: + + def test_connect_external_signals_to_components_mac(self): + """ Replace a MAC with inner components in an SFG """ + inp1 = Input("INP1") + inp2 = Input("INP2") + inp3 = Input("INP3") + add1 = Addition(None, None, "ADD1") + add2 = Addition(None, None, "ADD2") + mul1 = Multiplication(None, None, "MUL1") + out1 = Output(None, "OUT1") + + add1.input(0).connect(inp1, "S1") + add1.input(1).connect(inp2, "S2") + add2.input(0).connect(add1, "S3") + add2.input(1).connect(inp3, "S4") + mul1.input(0).connect(add1, "S5") + mul1.input(1).connect(add2, "S6") + out1.input(0).connect(mul1, "S7") + + mac_sfg = SFG(inputs = [inp1, inp2], outputs = [out1]) + + inp4 = Input("INP4") + inp5 = Input("INP5") + out2 = Output(None, "OUT2") + + mac_sfg.input(0).connect(inp4, "S8") + mac_sfg.input(1).connect(inp5, "S9") + out2.input(0).connect(mac_sfg.outputs[0], "S10") + + test_sfg = SFG(inputs = [inp4, inp5], outputs = [out2]) + assert test_sfg.evaluate(1,2) == 9 + mac_sfg.connect_external_signals_to_components() + assert test_sfg.evaluate(1,2) == 9 + assert test_sfg.connect_external_signals_to_components() == False + + def test_connect_external_signals_to_components_operation_tree(self, operation_tree): + """ Replaces an SFG with only a operation_tree component with its inner components """ + sfg1 = SFG(outputs = [Output(operation_tree)]) + out1 = Output(None, "OUT1") + out1.input(0).connect(sfg1.outputs[0], "S1") + test_sfg = SFG(outputs = [out1]) + assert test_sfg.evaluate_output(0, []) == 5 + sfg1.connect_external_signals_to_components() + assert test_sfg.evaluate_output(0, []) == 5 + assert test_sfg.connect_external_signals_to_components() == False + + def test_connect_external_signals_to_components_large_operation_tree(self, large_operation_tree): + """ Replaces an SFG with only a large_operation_tree component with its inner components """ + sfg1 = SFG(outputs = [Output(large_operation_tree)]) + out1 = Output(None, "OUT1") + out1.input(0).connect(sfg1.outputs[0], "S1") + test_sfg = SFG(outputs = [out1]) + assert test_sfg.evaluate_output(0, []) == 14 + sfg1.connect_external_signals_to_components() + assert test_sfg.evaluate_output(0, []) == 14 + assert test_sfg.connect_external_signals_to_components() == False + +class TestConnectExternalSignalsToComponentsMultipleComp: + + def test_connect_external_signals_to_components_operation_tree(self, operation_tree): + """ Replaces a operation_tree in an SFG with other components """ + sfg1 = SFG(outputs = [Output(operation_tree)]) + + inp1 = Input("INP1") + inp2 = Input("INP2") + out1 = Output(None, "OUT1") + + add1 = Addition(None, None, "ADD1") + add2 = Addition(None, None, "ADD2") + + add1.input(0).connect(inp1, "S1") + add1.input(1).connect(inp2, "S2") + add2.input(0).connect(add1, "S3") + add2.input(1).connect(sfg1.outputs[0], "S4") + out1.input(0).connect(add2, "S5") + + test_sfg = SFG(inputs = [inp1, inp2], outputs = [out1]) + assert test_sfg.evaluate(1, 2) == 8 + sfg1.connect_external_signals_to_components() + assert test_sfg.evaluate(1, 2) == 8 + assert test_sfg.connect_external_signals_to_components() == False + + def test_connect_external_signals_to_components_large_operation_tree(self, large_operation_tree): + """ Replaces a large_operation_tree in an SFG with other components """ + sfg1 = SFG(outputs = [Output(large_operation_tree)]) + + inp1 = Input("INP1") + inp2 = Input("INP2") + out1 = Output(None, "OUT1") + add1 = Addition(None, None, "ADD1") + add2 = Addition(None, None, "ADD2") + + add1.input(0).connect(inp1, "S1") + add1.input(1).connect(inp2, "S2") + add2.input(0).connect(add1, "S3") + add2.input(1).connect(sfg1.outputs[0], "S4") + out1.input(0).connect(add2, "S5") + + test_sfg = SFG(inputs = [inp1, inp2], outputs = [out1]) + assert test_sfg.evaluate(1, 2) == 17 + sfg1.connect_external_signals_to_components() + assert test_sfg.evaluate(1, 2) == 17 + assert test_sfg.connect_external_signals_to_components() == False + + def create_sfg(self, op_tree): + """ Create a simple SFG with either operation_tree or large_operation_tree """ + sfg1 = SFG(outputs = [Output(op_tree)]) + + inp1 = Input("INP1") + inp2 = Input("INP2") + out1 = Output(None, "OUT1") + add1 = Addition(None, None, "ADD1") + add2 = Addition(None, None, "ADD2") + + add1.input(0).connect(inp1, "S1") + add1.input(1).connect(inp2, "S2") + add2.input(0).connect(add1, "S3") + add2.input(1).connect(sfg1.outputs[0], "S4") + out1.input(0).connect(add2, "S5") + + return SFG(inputs = [inp1, inp2], outputs = [out1]) + + def test_connect_external_signals_to_components_many_op(self, large_operation_tree): + """ Replaces an sfg component in a larger SFG with several component operations """ + inp1 = Input("INP1") + inp2 = Input("INP2") + inp3 = Input("INP3") + inp4 = Input("INP4") + out1 = Output(None, "OUT1") + add1 = Addition(None, None, "ADD1") + sub1 = Subtraction(None, None, "SUB1") + + add1.input(0).connect(inp1, "S1") + add1.input(1).connect(inp2, "S2") + + sfg1 = self.create_sfg(large_operation_tree) + + sfg1.input(0).connect(add1, "S3") + sfg1.input(1).connect(inp3, "S4") + sub1.input(0).connect(sfg1.outputs[0], "S5") + sub1.input(1).connect(inp4, "S6") + out1.input(0).connect(sub1, "S7") + + test_sfg = SFG(inputs = [inp1, inp2, inp3, inp4], outputs = [out1]) + assert test_sfg.evaluate(1, 2, 3, 4) == 16 + sfg1.connect_external_signals_to_components() + assert test_sfg.evaluate(1, 2, 3, 4) == 16 + assert test_sfg.connect_external_signals_to_components() == False \ No newline at end of file