From 729c11beced60deaa90e31163bbfcfc0e7c13769 Mon Sep 17 00:00:00 2001 From: Kevin <Kevin> Date: Sun, 3 May 2020 14:18:52 +0200 Subject: [PATCH] implemented operation replacement and changed some test since we return a deepcopy of the sfg --- b_asic/signal_flow_graph.py | 42 ++++++++++++++++++++++--------------- test/test_sfg.py | 10 ++++----- 2 files changed, 30 insertions(+), 22 deletions(-) diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 19675e33..28e0f66b 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -406,7 +406,7 @@ class SFG(AbstractOperation): return self() def replace_operations(self, operation_ids: Sequence[GraphID], operation: Operation): - """Replace multiple operations in the sfg with a operation of equivalent functionallity with the same number of inputs and outputs. + """Replace multiple operations in the sfg with a operation with the same amount of inputs and outputs. Then return a new deepcopy of the sfg with the replaced operations. Arguments: @@ -414,28 +414,36 @@ class SFG(AbstractOperation): operation: The operation used for replacement. """ - operations = [self.find_by_id(_id) for _id in operation_ids] + inputs = [] + outputs = [] - assert sum(o.input_count + o.output_count for o in operations) == operation.input_count + operation.output_count, \ - "The input and output count must match" - - # Create a copy of the sfg for manipulating - _sfg = self() + for _operation in self.operations: + if _operation.graph_id not in operation_ids: + continue + + # Retrive input operations + for _signal in _operation.input_signals: + if _signal.source.operation.graph_id not in operation_ids: + inputs.append(_signal.source.operation) - for operation in operations: - operation + # Retrive output operations + for _signal in _operation.output_signals: + if _signal.destination.operation.graph_id not in operation_ids: + outputs.append(_signal.destination.operation) + assert len(inputs) + len(outputs) == \ + operation.input_count + operation.output_count, "The input and output count must match" - for _index, _inp in enumerate(inputs): - for _signal in _inp.output_signals: + for index_in, _input in enumerate(inputs): + for _signal in _input.output_signals: _signal.remove_destination() - _signal.set_destination(operation.input(_index)) + _signal.set_destination(operation.input(index_in)) - for _index, _out in enumerate(outputs): - for _signal in _out.input_signals: - _signal.remove_destination() - _signal.set_source(operation.output(_index)) - + for index_out, _output in enumerate(outputs): + for _signal in _output.input_signals: + _signal.remove_source() + _signal.set_source(operation.output(index_out)) + return self() def _evaluate_source(self, src: OutputPort, results: MutableResultMap, registers: MutableRegisterMap, prefix: str) -> Number: diff --git a/test/test_sfg.py b/test/test_sfg.py index 8c00cd73..2023ffcf 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -268,22 +268,22 @@ class TestReplaceOperations: mad1 = MAD() _sfg = sfg.replace_operations(['add1', 'mul1'], mad1) - assert mad1 in _sfg.operations + assert _sfg.find_by_id('mad1') is not None assert {add1, mul1} not in _sfg.operations def test_replace_neg_add_with_sub(self): in1 = Input() in2 = Input() - neg1 = ConstantMultiplication(-1, in1) + neg1 = ConstantMultiplication(-1, in1, 'neg1') add1 = neg1 + in2 out1 = Output(add1) sfg = SFG(inputs=[in1, in2], outputs=[out1]) sub1 = Subtraction() - sfg.replace_operations(['add1, neg1'], sub1) + _sfg = sfg.replace_operations(['add1', 'neg1'], sub1) - assert sub1 in sfg.operations - assert {add1, neg1} not in sfg.operations + assert _sfg.find_by_id('sub1') is not None + assert {add1, neg1} not in _sfg.operations def test_different_input_output_count(self, operation_tree): sfg = SFG(outputs=[Output(operation_tree)]) -- GitLab