Skip to content
Snippets Groups Projects
Commit eee37b86 authored by Jacob Wahlman's avatar Jacob Wahlman :ok_hand:
Browse files

added tests and fixed smaller linting

parent 963dc7a8
No related branches found
No related tags found
2 merge requests!42Resolve "Operation to SFG Conversion",!32Resolve "Insert Operation in SFG"
Pipeline #14305 passed
......@@ -407,37 +407,36 @@ class SFG(AbstractOperation):
# The old SFG will be deleted by Python GC
return _sfg_copy()
def insert_operation(self, component: Operation, _output_comp_id: GraphID):
def insert_operation(self, component: Operation, output_comp_id: GraphID):
"""Insert an operation in the SFG after a given source operation.
The source operation output count must match the input count of the operation as well as the output
Then return a new deepcopy of the sfg with the inserted component.
Arguments:
component: The new component, e.g Multiplication.
_output_comp_id: The source operation GraphID to connect from.
output_comp_id: The source operation GraphID to connect from.
"""
# Preserve the original SFG by creating a copy.
_sfg_copy = self()
_output_comp = _sfg_copy.find_by_id(_output_comp_id)
if _output_comp is None:
sfg_copy = self()
output_comp = sfg_copy.find_by_id(output_comp_id)
if output_comp is None:
return None
assert not isinstance(_output_comp, Output), \
assert not isinstance(output_comp, Output), \
"Source operation can not be an output operation."
assert len(_output_comp.output_signals) == component.input_count, \
assert len(output_comp.output_signals) == component.input_count, \
"Source operation output count does not match input count for component."
assert len(_output_comp.output_signals) == component.output_count, \
assert len(output_comp.output_signals) == component.output_count, \
"Destination operation input count does not match output for component."
for index, _signal_in in enumerate(_output_comp.output_signals):
_destination = _signal_in.destination
_signal_in.set_destination(component.input(index))
_signal_out = Signal(component.output(index))
_signal_out.set_destination(_destination)
for index, signal_in in enumerate(output_comp.output_signals):
destination = signal_in.destination
signal_in.set_destination(component.input(index))
destination.connect(component.output(index))
# Recreate the newly coupled SFG so that all attributes are correct.
return _sfg_copy()
return sfg_copy()
def _evaluate_source(self, src: OutputPort, results: MutableResultMap, registers: MutableRegisterMap, prefix: str) -> Number:
src_prefix = prefix
......
import pytest
from b_asic import Addition, Constant, Signal
from b_asic import Addition, Constant, Signal, Butterfly
@pytest.fixture
......@@ -41,6 +41,41 @@ def large_operation_tree():
"""
return Addition(Addition(Constant(2), Constant(3)), Addition(Constant(4), Constant(5)))
@pytest.fixture
def large_operation_tree_names():
"""Valid addition operation connected with a large operation tree with 2 other additions and 4 constants.
With names.
2---+
|
v
add---+
^ |
| |
3---+ v
add = (2 + 3) + (4 + 5) = 14
4---+ ^
| |
v |
add---+
^
|
5---+
"""
return Addition(Addition(Constant(2, name="constant2"), Constant(3, name="constant3")), Addition(Constant(4, name="constant4"), Constant(5, name="constant5")))
@pytest.fixture
def butterfly_operation_tree():
"""Valid butterfly operations connected to eachother with 3 butterfly operations and 2 constants as inputs and 2 outputs.
2 ---+ +--- (2 + 4) ---+ +--- (6 + (-2)) ---+ +--- (4 + 8) ---> out1 = 12
| | | | | |
v ^ v ^ v ^
butterfly butterfly butterfly
^ v ^ v ^ v
| | | | | |
4 ---+ +--- (2 - 4) ---+ +--- (6 - (-2)) ---+ +--- (4 - 8) ---> out2 = -4
"""
return Butterfly(*(Butterfly(*(Butterfly(Constant(2), Constant(4)).outputs)).outputs))
@pytest.fixture
def operation_graph_with_cycle():
"""Invalid addition operation connected with an operation graph containing a cycle.
......
import pytest
from b_asic import SFG, Signal, Input, Output, Constant, Addition, Multiplication, SquareRoot
from b_asic import SFG, Signal, Input, Output, Constant, Addition, Multiplication, SquareRoot, Butterfly
class TestInit:
......@@ -249,19 +249,27 @@ class TestReplaceComponents:
class TestInsertComponent:
def test_insert_component_in_sfg(self, large_operation_tree):
sfg = SFG(outputs=[Output(large_operation_tree)])
def test_insert_component_in_sfg(self, large_operation_tree_names):
sfg = SFG(outputs=[Output(large_operation_tree_names)])
sqrt = SquareRoot()
_sfg = sfg.insert_operation(sqrt, "c1")
_sfg = sfg.insert_operation(sqrt, sfg.find_by_name("constant4")[0].graph_id)
assert _sfg.evaluate() != sfg.evaluate()
assert any([isinstance(comp, SquareRoot) for comp in _sfg.operations])
assert not any([isinstance(comp, SquareRoot) for comp in sfg.operations])
assert not isinstance(sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation, SquareRoot)
assert isinstance(_sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation, SquareRoot)
assert sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation is sfg.find_by_id("add3")
assert _sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation is not _sfg.find_by_id("add3")
assert _sfg.find_by_id("sqrt1").output(0).signals[0].destination.operation is _sfg.find_by_id("add3")
def test_insert_invalid_component_in_sfg(self, large_operation_tree):
sfg = SFG(outputs=[Output(large_operation_tree)])
# Should raise an exception for not matching input count to output count.
add4 = Addition()
with pytest.raises(Exception):
sfg.insert_operation(add4, "c1")
......@@ -269,7 +277,24 @@ class TestInsertComponent:
def test_insert_at_output(self, large_operation_tree):
sfg = SFG(outputs=[Output(large_operation_tree)])
# Should raise an exception for trying to insert an operation after an output.
sqrt = SquareRoot()
with pytest.raises(Exception):
_sfg = sfg.insert_operation(sqrt, "out1")
def test_insert_multiple_output_ports(self, butterfly_operation_tree):
sfg = SFG(outputs=list(map(Output, butterfly_operation_tree.outputs)))
_sfg = sfg.insert_operation(Butterfly(name="New Bfly"), "bfly3")
assert sfg.evaluate() != _sfg.evaluate()
assert len(sfg.find_by_name("New Bfly")) == 0
assert len(_sfg.find_by_name("New Bfly")) == 1
# The old bfly3 becomes bfly4 in the new sfg since it is "moved" back.
assert sfg.find_by_id("bfly4") is None
assert _sfg.find_by_id("bfly4") is not None
assert sfg.find_by_id("bfly3").output(0).signals[0].destination.operation.name is not "New Bfly"
assert _sfg.find_by_id("bfly4").output(0).signals[0].destination.operation.name is "New Bfly"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment