Skip to content
Snippets Groups Projects
Commit 729c11be authored by Kevin's avatar Kevin
Browse files

implemented operation replacement and changed some test since we return a deepcopy of the sfg

parent dca8af0e
No related branches found
No related tags found
1 merge request!44Resolve "Operation Replacement in a SFG"
Pipeline #14841 passed
This commit is part of merge request !44. Comments created here will be created in the context of that merge request.
......@@ -406,7 +406,7 @@ class SFG(AbstractOperation):
return self()
def replace_operations(self, operation_ids: Sequence[GraphID], operation: Operation):
"""Replace multiple operations in the sfg with a operation of equivalent functionallity with the same number of inputs and outputs.
"""Replace multiple operations in the sfg with a operation with the same amount of inputs and outputs.
Then return a new deepcopy of the sfg with the replaced operations.
Arguments:
......@@ -414,28 +414,36 @@ class SFG(AbstractOperation):
operation: The operation used for replacement.
"""
operations = [self.find_by_id(_id) for _id in operation_ids]
inputs = []
outputs = []
assert sum(o.input_count + o.output_count for o in operations) == operation.input_count + operation.output_count, \
"The input and output count must match"
# Create a copy of the sfg for manipulating
_sfg = self()
for _operation in self.operations:
if _operation.graph_id not in operation_ids:
continue
# Retrive input operations
for _signal in _operation.input_signals:
if _signal.source.operation.graph_id not in operation_ids:
inputs.append(_signal.source.operation)
for operation in operations:
operation
# Retrive output operations
for _signal in _operation.output_signals:
if _signal.destination.operation.graph_id not in operation_ids:
outputs.append(_signal.destination.operation)
assert len(inputs) + len(outputs) == \
operation.input_count + operation.output_count, "The input and output count must match"
for _index, _inp in enumerate(inputs):
for _signal in _inp.output_signals:
for index_in, _input in enumerate(inputs):
for _signal in _input.output_signals:
_signal.remove_destination()
_signal.set_destination(operation.input(_index))
_signal.set_destination(operation.input(index_in))
for _index, _out in enumerate(outputs):
for _signal in _out.input_signals:
_signal.remove_destination()
_signal.set_source(operation.output(_index))
for index_out, _output in enumerate(outputs):
for _signal in _output.input_signals:
_signal.remove_source()
_signal.set_source(operation.output(index_out))
return self()
def _evaluate_source(self, src: OutputPort, results: MutableResultMap, registers: MutableRegisterMap, prefix: str) -> Number:
......
......@@ -268,22 +268,22 @@ class TestReplaceOperations:
mad1 = MAD()
_sfg = sfg.replace_operations(['add1', 'mul1'], mad1)
assert mad1 in _sfg.operations
assert _sfg.find_by_id('mad1') is not None
assert {add1, mul1} not in _sfg.operations
def test_replace_neg_add_with_sub(self):
in1 = Input()
in2 = Input()
neg1 = ConstantMultiplication(-1, in1)
neg1 = ConstantMultiplication(-1, in1, 'neg1')
add1 = neg1 + in2
out1 = Output(add1)
sfg = SFG(inputs=[in1, in2], outputs=[out1])
sub1 = Subtraction()
sfg.replace_operations(['add1, neg1'], sub1)
_sfg = sfg.replace_operations(['add1', 'neg1'], sub1)
assert sub1 in sfg.operations
assert {add1, neg1} not in sfg.operations
assert _sfg.find_by_id('sub1') is not None
assert {add1, neg1} not in _sfg.operations
def test_different_input_output_count(self, operation_tree):
sfg = SFG(outputs=[Output(operation_tree)])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment