Skip to content
Snippets Groups Projects
Commit 4fee14b1 authored by Ivar Härnqvist's avatar Ivar Härnqvist
Browse files

add support for reverse operator overloads, remove constant operations other...

add support for reverse operator overloads, remove constant operations other than ConstantMultiplication (as discussed with client) to make the core operations less bloated/confusing
parent 40aa3a4d
No related branches found
No related tags found
4 merge requests!31Resolve "Specify internal input/output dependencies of an Operation",!25Resolve "System tests iteration 1",!24Resolve "System tests iteration 1",!23Resolve "Simulate SFG"
Pipeline #12902 passed
...@@ -30,12 +30,12 @@ class Constant(AbstractOperation): ...@@ -30,12 +30,12 @@ class Constant(AbstractOperation):
@property @property
def value(self) -> Number: def value(self) -> Number:
"""TODO: docstring""" """Get the constant value of this operation."""
return self.param("value") return self.param("value")
@value.setter @value.setter
def value(self, value: Number): def value(self, value: Number):
"""TODO: docstring""" """Set the constant value of this operation."""
return self.set_param("value", value) return self.set_param("value", value)
...@@ -203,56 +203,15 @@ class ConstantMultiplication(AbstractOperation): ...@@ -203,56 +203,15 @@ class ConstantMultiplication(AbstractOperation):
def evaluate(self, a): def evaluate(self, a):
return a * self.param("value") return a * self.param("value")
class ConstantAddition(AbstractOperation):
"""Unary constant addition operation.
TODO: More info.
"""
def __init__(self, value: Number, src0: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0])
self.set_param("value", value)
@property
def type_name(self) -> TypeName:
return "cadd"
def evaluate(self, a):
return a + self.param("value")
class ConstantSubtraction(AbstractOperation):
"""Unary constant subtraction operation.
TODO: More info.
"""
def __init__(self, value: Number, src0: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0])
self.set_param("value", value)
@property @property
def type_name(self) -> TypeName: def value(self) -> Number:
return "csub" """Get the constant value of this operation."""
return self.param("value")
def evaluate(self, a):
return a - self.param("value")
class ConstantDivision(AbstractOperation):
"""Unary constant division operation.
TODO: More info.
"""
def __init__(self, value: Number, src0: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0])
self.set_param("value", value)
@property
def type_name(self) -> TypeName:
return "cdiv"
def evaluate(self, a): @value.setter
return a / self.param("value") def value(self, value: Number):
"""Set the constant value of this operation."""
return self.set_param("value", value)
class Butterfly(AbstractOperation): class Butterfly(AbstractOperation):
......
...@@ -22,18 +22,30 @@ class Operation(GraphComponent, SignalSourceProvider): ...@@ -22,18 +22,30 @@ class Operation(GraphComponent, SignalSourceProvider):
""" """
@abstractmethod @abstractmethod
def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Addition, ConstantAddition]": def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Addition":
"""Overloads the addition operator to make it return a new Addition operation """Overloads the addition operator to make it return a new Addition operation
object that is connected to the self and other objects. If other is a number then object that is connected to the self and other objects.
returns a ConstantAddition operation object instead. """
raise NotImplementedError
@abstractmethod
def __radd__(self, src: Union[SignalSourceProvider, Number]) -> "Addition":
"""Overloads the addition operator to make it return a new Addition operation
object that is connected to the self and other objects.
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Subtraction, ConstantSubtraction]": def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction":
"""Overloads the subtraction operator to make it return a new Subtraction operation """Overloads the subtraction operator to make it return a new Subtraction operation
object that is connected to the self and other objects. If other is a number then object that is connected to the self and other objects.
returns a ConstantSubtraction operation object instead. """
raise NotImplementedError
@abstractmethod
def __rsub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction":
"""Overloads the subtraction operator to make it return a new Subtraction operation
object that is connected to the self and other objects.
""" """
raise NotImplementedError raise NotImplementedError
...@@ -46,10 +58,24 @@ class Operation(GraphComponent, SignalSourceProvider): ...@@ -46,10 +58,24 @@ class Operation(GraphComponent, SignalSourceProvider):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Division, ConstantDivision]": def __rmul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]":
"""Overloads the division operator to make it return a new Division operation """Overloads the multiplication operator to make it return a new Multiplication operation
object that is connected to the self and other objects. If other is a number then object that is connected to the self and other objects. If other is a number then
returns a ConstantDivision operation object instead. returns a ConstantMultiplication operation object instead.
"""
raise NotImplementedError
@abstractmethod
def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division":
"""Overloads the division operator to make it return a new Division operation
object that is connected to the self and other objects.
"""
raise NotImplementedError
@abstractmethod
def __rtruediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division":
"""Overloads the division operator to make it return a new Division operation
object that is connected to the self and other objects.
""" """
raise NotImplementedError raise NotImplementedError
...@@ -202,37 +228,37 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -202,37 +228,37 @@ class AbstractOperation(Operation, AbstractGraphComponent):
"""Evaluate the operation and generate a list of output values given a list of input values.""" """Evaluate the operation and generate a list of output values given a list of input values."""
raise NotImplementedError raise NotImplementedError
def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Addition, ConstantAddition]": def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Addition":
# Import here to avoid circular imports. from b_asic.core_operations import Constant, Addition # Import here to avoid circular imports.
from b_asic.core_operations import Addition, ConstantAddition return Addition(self, Constant(src) if isinstance(src, Number) else src)
if isinstance(src, Number): def __radd__(self, src: Union[SignalSourceProvider, Number]) -> "Addition":
return ConstantAddition(src, self) from b_asic.core_operations import Constant, Addition # Import here to avoid circular imports.
return Addition(self, src) return Addition(Constant(src) if isinstance(src, Number) else src, self)
def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Subtraction, ConstantSubtraction]":
# Import here to avoid circular imports.
from b_asic.core_operations import Subtraction, ConstantSubtraction
if isinstance(src, Number): def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction":
return ConstantSubtraction(src, self) from b_asic.core_operations import Constant, Subtraction # Import here to avoid circular imports.
return Subtraction(self, src) return Subtraction(self, Constant(src) if isinstance(src, Number) else src)
def __rsub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction":
from b_asic.core_operations import Constant, Subtraction # Import here to avoid circular imports.
return Subtraction(Constant(src) if isinstance(src, Number) else src, self)
def __mul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]": def __mul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]":
# Import here to avoid circular imports. from b_asic.core_operations import Multiplication, ConstantMultiplication # Import here to avoid circular imports.
from b_asic.core_operations import Multiplication, ConstantMultiplication return ConstantMultiplication(src, self) if isinstance(src, Number) else Multiplication(self, src)
if isinstance(src, Number): def __rmul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]":
return ConstantMultiplication(src, self) from b_asic.core_operations import Multiplication, ConstantMultiplication # Import here to avoid circular imports.
return Multiplication(self, src) return ConstantMultiplication(src, self) if isinstance(src, Number) else Multiplication(src, self)
def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Division, ConstantDivision]":
# Import here to avoid circular imports.
from b_asic.core_operations import Division, ConstantDivision
if isinstance(src, Number): def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division":
return ConstantDivision(src, self) from b_asic.core_operations import Constant, Division # Import here to avoid circular imports.
return Division(self, src) return Division(self, Constant(src) if isinstance(src, Number) else src)
def __rtruediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division":
from b_asic.core_operations import Constant, Division # Import here to avoid circular imports.
return Division(Constant(src) if isinstance(src, Number) else src, self)
@property @property
def inputs(self) -> Sequence[InputPort]: def inputs(self) -> Sequence[InputPort]:
......
...@@ -65,8 +65,7 @@ def sfg_accumulator(): ...@@ -65,8 +65,7 @@ def sfg_accumulator():
""" """
data_in = Input() data_in = Input()
reset = Input() reset = Input()
reset_inverted = Constant(1) - reset
reg = Register() reg = Register()
reg.input(0).connect((reg + data_in) * reset_inverted) reg.input(0).connect((reg + data_in) * (1 - reset))
data_out = Output(reg) data_out = Output(reg)
return SFG(inputs = [data_in, reset], outputs = [data_out]) return SFG(inputs = [data_in, reset], outputs = [data_out])
\ No newline at end of file
...@@ -4,8 +4,7 @@ B-ASIC test suite for the AbstractOperation class. ...@@ -4,8 +4,7 @@ B-ASIC test suite for the AbstractOperation class.
import pytest import pytest
from b_asic import Addition, ConstantAddition, Subtraction, ConstantSubtraction, \ from b_asic import Addition, Subtraction, Multiplication, ConstantMultiplication, Division
Multiplication, ConstantMultiplication, Division, ConstantDivision
def test_addition_overload(): def test_addition_overload():
...@@ -14,15 +13,19 @@ def test_addition_overload(): ...@@ -14,15 +13,19 @@ def test_addition_overload():
add2 = Addition(None, None, "add2") add2 = Addition(None, None, "add2")
add3 = add1 + add2 add3 = add1 + add2
assert isinstance(add3, Addition) assert isinstance(add3, Addition)
assert add3.input(0).signals == add1.output(0).signals assert add3.input(0).signals == add1.output(0).signals
assert add3.input(1).signals == add2.output(0).signals assert add3.input(1).signals == add2.output(0).signals
add4 = add3 + 5 add4 = add3 + 5
assert isinstance(add4, Addition)
assert isinstance(add4, ConstantAddition)
assert add4.input(0).signals == add3.output(0).signals assert add4.input(0).signals == add3.output(0).signals
assert add4.input(1).signals[0].source.operation.value == 5
add5 = 5 + add4
assert isinstance(add5, Addition)
assert add5.input(0).signals[0].source.operation.value == 5
assert add5.input(1).signals == add4.output(0).signals
def test_subtraction_overload(): def test_subtraction_overload():
...@@ -31,15 +34,19 @@ def test_subtraction_overload(): ...@@ -31,15 +34,19 @@ def test_subtraction_overload():
add2 = Addition(None, None, "add2") add2 = Addition(None, None, "add2")
sub1 = add1 - add2 sub1 = add1 - add2
assert isinstance(sub1, Subtraction) assert isinstance(sub1, Subtraction)
assert sub1.input(0).signals == add1.output(0).signals assert sub1.input(0).signals == add1.output(0).signals
assert sub1.input(1).signals == add2.output(0).signals assert sub1.input(1).signals == add2.output(0).signals
sub2 = sub1 - 5 sub2 = sub1 - 5
assert isinstance(sub2, Subtraction)
assert isinstance(sub2, ConstantSubtraction)
assert sub2.input(0).signals == sub1.output(0).signals assert sub2.input(0).signals == sub1.output(0).signals
assert sub2.input(1).signals[0].source.operation.value == 5
sub3 = 5 - sub2
assert isinstance(sub3, Subtraction)
assert sub3.input(0).signals[0].source.operation.value == 5
assert sub3.input(1).signals == sub2.output(0).signals
def test_multiplication_overload(): def test_multiplication_overload():
...@@ -48,15 +55,19 @@ def test_multiplication_overload(): ...@@ -48,15 +55,19 @@ def test_multiplication_overload():
add2 = Addition(None, None, "add2") add2 = Addition(None, None, "add2")
mul1 = add1 * add2 mul1 = add1 * add2
assert isinstance(mul1, Multiplication) assert isinstance(mul1, Multiplication)
assert mul1.input(0).signals == add1.output(0).signals assert mul1.input(0).signals == add1.output(0).signals
assert mul1.input(1).signals == add2.output(0).signals assert mul1.input(1).signals == add2.output(0).signals
mul2 = mul1 * 5 mul2 = mul1 * 5
assert isinstance(mul2, ConstantMultiplication) assert isinstance(mul2, ConstantMultiplication)
assert mul2.input(0).signals == mul1.output(0).signals assert mul2.input(0).signals == mul1.output(0).signals
assert mul2.value == 5
mul3 = 5 * mul2
assert isinstance(mul3, ConstantMultiplication)
assert mul3.input(0).signals == mul2.output(0).signals
assert mul3.value == 5
def test_division_overload(): def test_division_overload():
...@@ -65,13 +76,17 @@ def test_division_overload(): ...@@ -65,13 +76,17 @@ def test_division_overload():
add2 = Addition(None, None, "add2") add2 = Addition(None, None, "add2")
div1 = add1 / add2 div1 = add1 / add2
assert isinstance(div1, Division) assert isinstance(div1, Division)
assert div1.input(0).signals == add1.output(0).signals assert div1.input(0).signals == add1.output(0).signals
assert div1.input(1).signals == add2.output(0).signals assert div1.input(1).signals == add2.output(0).signals
div2 = div1 / 5 div2 = div1 / 5
assert isinstance(div2, Division)
assert isinstance(div2, ConstantDivision)
assert div2.input(0).signals == div1.output(0).signals assert div2.input(0).signals == div1.output(0).signals
assert div2.input(1).signals[0].source.operation.value == 5
div3 = 5 / div2
assert isinstance(div3, Division)
assert div3.input(0).signals[0].source.operation.value == 5
assert div3.input(1).signals == div2.output(0).signals
...@@ -2,10 +2,9 @@ ...@@ -2,10 +2,9 @@
B-ASIC test suite for the core operations. B-ASIC test suite for the core operations.
""" """
from b_asic import Constant, Addition, Subtraction, \ from b_asic import \
Multiplication, Division, SquareRoot, ComplexConjugate, Max, Min, \ Constant, Addition, Subtraction, Multiplication, ConstantMultiplication, Division, \
Absolute, ConstantMultiplication, ConstantAddition, ConstantSubtraction, \ SquareRoot, ComplexConjugate, Max, Min, Absolute, Butterfly
ConstantDivision, Butterfly
class TestConstant: class TestConstant:
...@@ -150,48 +149,6 @@ class TestConstantMultiplication: ...@@ -150,48 +149,6 @@ class TestConstantMultiplication:
assert test_operation.evaluate_output(0, [3+4j]) == 1+18j assert test_operation.evaluate_output(0, [3+4j]) == 1+18j
class TestConstantAddition:
def test_constantaddition_positive(self):
test_operation = ConstantAddition(5)
assert test_operation.evaluate_output(0, [20]) == 25
def test_constantaddition_negative(self):
test_operation = ConstantAddition(4)
assert test_operation.evaluate_output(0, [-5]) == -1
def test_constantaddition_complex(self):
test_operation = ConstantAddition(3+2j)
assert test_operation.evaluate_output(0, [3+2j]) == 6+4j
class TestConstantSubtraction:
def test_constantsubtraction_positive(self):
test_operation = ConstantSubtraction(5)
assert test_operation.evaluate_output(0, [20]) == 15
def test_constantsubtraction_negative(self):
test_operation = ConstantSubtraction(4)
assert test_operation.evaluate_output(0, [-5]) == -9
def test_constantsubtraction_complex(self):
test_operation = ConstantSubtraction(4+6j)
assert test_operation.evaluate_output(0, [3+4j]) == -1-2j
class TestConstantDivision:
def test_constantdivision_positive(self):
test_operation = ConstantDivision(5)
assert test_operation.evaluate_output(0, [20]) == 4
def test_constantdivision_negative(self):
test_operation = ConstantDivision(4)
assert test_operation.evaluate_output(0, [-20]) == -5
def test_constantdivision_complex(self):
test_operation = ConstantDivision(2+2j)
assert test_operation.evaluate_output(0, [10+10j]) == 5
class TestButterfly: class TestButterfly:
def test_butterfly_positive(self): def test_butterfly_positive(self):
test_operation = Butterfly() test_operation = Butterfly()
......
import pytest import pytest
from b_asic import Constant, Addition, ConstantAddition, Butterfly, Signal, InputPort, OutputPort from b_asic import Constant, Addition
class TestTraverse: class TestTraverse:
def test_traverse_single_tree(self, operation): def test_traverse_single_tree(self, operation):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment