Skip to content
Snippets Groups Projects
Commit f9cbafb1 authored by Jacob Wahlman's avatar Jacob Wahlman :ok_hand:
Browse files

Merge branch 'develop' of gitlab.liu.se:PUM_TDDD96/B-ASIC into 87-resize-gui-window

parents 39ff67c3 6dcab2af
No related branches found
No related tags found
1 merge request!46Resolve "Resize GUI Window"
Pipeline #14936 passed
...@@ -240,3 +240,18 @@ class Butterfly(AbstractOperation): ...@@ -240,3 +240,18 @@ class Butterfly(AbstractOperation):
def evaluate(self, a, b): def evaluate(self, a, b):
return a + b, a - b return a + b, a - b
class MAD(AbstractOperation):
"""Multiply-and-add operation.
TODO: More info.
"""
def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, src2: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(input_count = 3, output_count = 1, name = name, input_sources = [src0, src1, src2])
@property
def type_name(self) -> TypeName:
return "mad"
def evaluate(self, a, b, c):
return a * b + c
...@@ -186,6 +186,12 @@ class Operation(GraphComponent, SignalSourceProvider): ...@@ -186,6 +186,12 @@ class Operation(GraphComponent, SignalSourceProvider):
"""Get the input indices of all inputs in this operation whose values are required in order to evalueate the output at the given output index.""" """Get the input indices of all inputs in this operation whose values are required in order to evalueate the output at the given output index."""
raise NotImplementedError raise NotImplementedError
@abstractmethod
def to_sfg(self) -> "SFG":
"""Convert the operation into its corresponding SFG.
If the operation is composed by multiple operations, the operation will be split.
"""
raise NotImplementedError
class AbstractOperation(Operation, AbstractGraphComponent): class AbstractOperation(Operation, AbstractGraphComponent):
"""Generic abstract operation class which most implementations will derive from. """Generic abstract operation class which most implementations will derive from.
...@@ -361,6 +367,30 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -361,6 +367,30 @@ class AbstractOperation(Operation, AbstractGraphComponent):
pass pass
return [self] return [self]
def to_sfg(self) -> "SFG":
# Import here to avoid circular imports.
from b_asic.special_operations import Input, Output
from b_asic.signal_flow_graph import SFG
inputs = [Input() for i in range(self.input_count)]
try:
last_operations = self.evaluate(*inputs)
if isinstance(last_operations, Operation):
last_operations = [last_operations]
outputs = [Output(o) for o in last_operations]
except TypeError:
operation_copy = self.copy_component()
inputs = []
for i in range(self.input_count):
_input = Input()
operation_copy.input(i).connect(_input)
inputs.append(_input)
outputs = [Output(operation_copy)]
return SFG(inputs=inputs, outputs=outputs)
def inputs_required_for_output(self, output_index: int) -> Iterable[int]: def inputs_required_for_output(self, output_index: int) -> Iterable[int]:
if output_index < 0 or output_index >= self.output_count: if output_index < 0 or output_index >= self.output_count:
raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {output_index})") raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {output_index})")
......
...@@ -283,6 +283,9 @@ class SFG(AbstractOperation): ...@@ -283,6 +283,9 @@ class SFG(AbstractOperation):
def split(self) -> Iterable[Operation]: def split(self) -> Iterable[Operation]:
return self.operations return self.operations
def to_sfg(self) -> 'SFG':
return self
def inputs_required_for_output(self, output_index: int) -> Iterable[int]: def inputs_required_for_output(self, output_index: int) -> Iterable[int]:
if output_index < 0 or output_index >= self.output_count: if output_index < 0 or output_index >= self.output_count:
......
...@@ -89,4 +89,3 @@ def test_division_overload(): ...@@ -89,4 +89,3 @@ def test_division_overload():
assert isinstance(div3, Division) assert isinstance(div3, Division)
assert div3.input(0).signals[0].source.operation.value == 5 assert div3.input(0).signals[0].source.operation.value == 5
assert div3.input(1).signals == div2.output(0).signals assert div3.input(1).signals == div2.output(0).signals
...@@ -6,7 +6,6 @@ from b_asic import \ ...@@ -6,7 +6,6 @@ from b_asic import \
Constant, Addition, Subtraction, Multiplication, ConstantMultiplication, Division, \ Constant, Addition, Subtraction, Multiplication, ConstantMultiplication, Division, \
SquareRoot, ComplexConjugate, Max, Min, Absolute, Butterfly SquareRoot, ComplexConjugate, Max, Min, Absolute, Butterfly
class TestConstant: class TestConstant:
def test_constant_positive(self): def test_constant_positive(self):
test_operation = Constant(3) test_operation = Constant(3)
......
import pytest import pytest
from b_asic import Constant, Addition from b_asic import Constant, Addition, MAD, Butterfly, SquareRoot
class TestTraverse: class TestTraverse:
def test_traverse_single_tree(self, operation): def test_traverse_single_tree(self, operation):
...@@ -22,4 +22,32 @@ class TestTraverse: ...@@ -22,4 +22,32 @@ class TestTraverse:
assert len(list(filter(lambda type_: isinstance(type_, Constant), result))) == 4 assert len(list(filter(lambda type_: isinstance(type_, Constant), result))) == 4
def test_traverse_loop(self, operation_graph_with_cycle): def test_traverse_loop(self, operation_graph_with_cycle):
assert len(list(operation_graph_with_cycle.traverse())) == 8 assert len(list(operation_graph_with_cycle.traverse())) == 8
\ No newline at end of file
class TestToSfg:
def test_convert_mad_to_sfg(self):
mad1 = MAD()
mad1_sfg = mad1.to_sfg()
assert mad1.evaluate(1,1,1) == mad1_sfg.evaluate(1,1,1)
assert len(mad1_sfg.operations) == 6
def test_butterfly_to_sfg(self):
but1 = Butterfly()
but1_sfg = but1.to_sfg()
assert but1.evaluate(1,1)[0] == but1_sfg.evaluate(1,1)[0]
assert but1.evaluate(1,1)[1] == but1_sfg.evaluate(1,1)[1]
assert len(but1_sfg.operations) == 8
def test_add_to_sfg(self):
add1 = Addition()
add1_sfg = add1.to_sfg()
assert len(add1_sfg.operations) == 4
def test_sqrt_to_sfg(self):
sqrt1 = SquareRoot()
sqrt1_sfg = sqrt1.to_sfg()
assert len(sqrt1_sfg.operations) == 3
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment