From 4fee14b18c09d106b1c552a6a718bc663c8cbbb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ivar=20H=C3=A4rnqvist?= <ivaha717@student.liu.se> Date: Tue, 14 Apr 2020 22:15:30 +0200 Subject: [PATCH] add support for reverse operator overloads, remove constant operations other than ConstantMultiplication (as discussed with client) to make the core operations less bloated/confusing --- b_asic/core_operations.py | 59 +++--------------- b_asic/operation.py | 98 +++++++++++++++++++----------- test/fixtures/signal_flow_graph.py | 3 +- test/test_abstract_operation.py | 41 +++++++++---- test/test_core_operations.py | 49 +-------------- test/test_operation.py | 2 +- 6 files changed, 104 insertions(+), 148 deletions(-) diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index f696126c..11013c84 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -30,12 +30,12 @@ class Constant(AbstractOperation): @property def value(self) -> Number: - """TODO: docstring""" + """Get the constant value of this operation.""" return self.param("value") @value.setter def value(self, value: Number): - """TODO: docstring""" + """Set the constant value of this operation.""" return self.set_param("value", value) @@ -203,56 +203,15 @@ class ConstantMultiplication(AbstractOperation): def evaluate(self, a): return a * self.param("value") - -class ConstantAddition(AbstractOperation): - """Unary constant addition operation. - TODO: More info. - """ - - def __init__(self, value: Number, src0: Optional[SignalSourceProvider] = None, name: Name = ""): - super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) - self.set_param("value", value) - - @property - def type_name(self) -> TypeName: - return "cadd" - - def evaluate(self, a): - return a + self.param("value") - - -class ConstantSubtraction(AbstractOperation): - """Unary constant subtraction operation. - TODO: More info. - """ - - def __init__(self, value: Number, src0: Optional[SignalSourceProvider] = None, name: Name = ""): - super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) - self.set_param("value", value) - @property - def type_name(self) -> TypeName: - return "csub" - - def evaluate(self, a): - return a - self.param("value") - - -class ConstantDivision(AbstractOperation): - """Unary constant division operation. - TODO: More info. - """ - - def __init__(self, value: Number, src0: Optional[SignalSourceProvider] = None, name: Name = ""): - super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) - self.set_param("value", value) - - @property - def type_name(self) -> TypeName: - return "cdiv" + def value(self) -> Number: + """Get the constant value of this operation.""" + return self.param("value") - def evaluate(self, a): - return a / self.param("value") + @value.setter + def value(self, value: Number): + """Set the constant value of this operation.""" + return self.set_param("value", value) class Butterfly(AbstractOperation): diff --git a/b_asic/operation.py b/b_asic/operation.py index 79798a52..ed0e7dfd 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -22,18 +22,30 @@ class Operation(GraphComponent, SignalSourceProvider): """ @abstractmethod - def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Addition, ConstantAddition]": + def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Addition": """Overloads the addition operator to make it return a new Addition operation - object that is connected to the self and other objects. If other is a number then - returns a ConstantAddition operation object instead. + object that is connected to the self and other objects. + """ + raise NotImplementedError + + @abstractmethod + def __radd__(self, src: Union[SignalSourceProvider, Number]) -> "Addition": + """Overloads the addition operator to make it return a new Addition operation + object that is connected to the self and other objects. """ raise NotImplementedError @abstractmethod - def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Subtraction, ConstantSubtraction]": + def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction": """Overloads the subtraction operator to make it return a new Subtraction operation - object that is connected to the self and other objects. If other is a number then - returns a ConstantSubtraction operation object instead. + object that is connected to the self and other objects. + """ + raise NotImplementedError + + @abstractmethod + def __rsub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction": + """Overloads the subtraction operator to make it return a new Subtraction operation + object that is connected to the self and other objects. """ raise NotImplementedError @@ -46,10 +58,24 @@ class Operation(GraphComponent, SignalSourceProvider): raise NotImplementedError @abstractmethod - def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Division, ConstantDivision]": - """Overloads the division operator to make it return a new Division operation + def __rmul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]": + """Overloads the multiplication operator to make it return a new Multiplication operation object that is connected to the self and other objects. If other is a number then - returns a ConstantDivision operation object instead. + returns a ConstantMultiplication operation object instead. + """ + raise NotImplementedError + + @abstractmethod + def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division": + """Overloads the division operator to make it return a new Division operation + object that is connected to the self and other objects. + """ + raise NotImplementedError + + @abstractmethod + def __rtruediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division": + """Overloads the division operator to make it return a new Division operation + object that is connected to the self and other objects. """ raise NotImplementedError @@ -202,37 +228,37 @@ class AbstractOperation(Operation, AbstractGraphComponent): """Evaluate the operation and generate a list of output values given a list of input values.""" raise NotImplementedError - def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Addition, ConstantAddition]": - # Import here to avoid circular imports. - from b_asic.core_operations import Addition, ConstantAddition - - if isinstance(src, Number): - return ConstantAddition(src, self) - return Addition(self, src) - - def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Subtraction, ConstantSubtraction]": - # Import here to avoid circular imports. - from b_asic.core_operations import Subtraction, ConstantSubtraction + def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Addition": + from b_asic.core_operations import Constant, Addition # Import here to avoid circular imports. + return Addition(self, Constant(src) if isinstance(src, Number) else src) + + def __radd__(self, src: Union[SignalSourceProvider, Number]) -> "Addition": + from b_asic.core_operations import Constant, Addition # Import here to avoid circular imports. + return Addition(Constant(src) if isinstance(src, Number) else src, self) - if isinstance(src, Number): - return ConstantSubtraction(src, self) - return Subtraction(self, src) + def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction": + from b_asic.core_operations import Constant, Subtraction # Import here to avoid circular imports. + return Subtraction(self, Constant(src) if isinstance(src, Number) else src) + + def __rsub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction": + from b_asic.core_operations import Constant, Subtraction # Import here to avoid circular imports. + return Subtraction(Constant(src) if isinstance(src, Number) else src, self) def __mul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]": - # Import here to avoid circular imports. - from b_asic.core_operations import Multiplication, ConstantMultiplication - - if isinstance(src, Number): - return ConstantMultiplication(src, self) - return Multiplication(self, src) - - def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Division, ConstantDivision]": - # Import here to avoid circular imports. - from b_asic.core_operations import Division, ConstantDivision + from b_asic.core_operations import Multiplication, ConstantMultiplication # Import here to avoid circular imports. + return ConstantMultiplication(src, self) if isinstance(src, Number) else Multiplication(self, src) + + def __rmul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]": + from b_asic.core_operations import Multiplication, ConstantMultiplication # Import here to avoid circular imports. + return ConstantMultiplication(src, self) if isinstance(src, Number) else Multiplication(src, self) - if isinstance(src, Number): - return ConstantDivision(src, self) - return Division(self, src) + def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division": + from b_asic.core_operations import Constant, Division # Import here to avoid circular imports. + return Division(self, Constant(src) if isinstance(src, Number) else src) + + def __rtruediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division": + from b_asic.core_operations import Constant, Division # Import here to avoid circular imports. + return Division(Constant(src) if isinstance(src, Number) else src, self) @property def inputs(self) -> Sequence[InputPort]: diff --git a/test/fixtures/signal_flow_graph.py b/test/fixtures/signal_flow_graph.py index df7d82b3..0a6c554d 100644 --- a/test/fixtures/signal_flow_graph.py +++ b/test/fixtures/signal_flow_graph.py @@ -65,8 +65,7 @@ def sfg_accumulator(): """ data_in = Input() reset = Input() - reset_inverted = Constant(1) - reset reg = Register() - reg.input(0).connect((reg + data_in) * reset_inverted) + reg.input(0).connect((reg + data_in) * (1 - reset)) data_out = Output(reg) return SFG(inputs = [data_in, reset], outputs = [data_out]) \ No newline at end of file diff --git a/test/test_abstract_operation.py b/test/test_abstract_operation.py index ab53dabf..5423ecdf 100644 --- a/test/test_abstract_operation.py +++ b/test/test_abstract_operation.py @@ -4,8 +4,7 @@ B-ASIC test suite for the AbstractOperation class. import pytest -from b_asic import Addition, ConstantAddition, Subtraction, ConstantSubtraction, \ - Multiplication, ConstantMultiplication, Division, ConstantDivision +from b_asic import Addition, Subtraction, Multiplication, ConstantMultiplication, Division def test_addition_overload(): @@ -14,15 +13,19 @@ def test_addition_overload(): add2 = Addition(None, None, "add2") add3 = add1 + add2 - assert isinstance(add3, Addition) assert add3.input(0).signals == add1.output(0).signals assert add3.input(1).signals == add2.output(0).signals add4 = add3 + 5 - - assert isinstance(add4, ConstantAddition) + assert isinstance(add4, Addition) assert add4.input(0).signals == add3.output(0).signals + assert add4.input(1).signals[0].source.operation.value == 5 + + add5 = 5 + add4 + assert isinstance(add5, Addition) + assert add5.input(0).signals[0].source.operation.value == 5 + assert add5.input(1).signals == add4.output(0).signals def test_subtraction_overload(): @@ -31,15 +34,19 @@ def test_subtraction_overload(): add2 = Addition(None, None, "add2") sub1 = add1 - add2 - assert isinstance(sub1, Subtraction) assert sub1.input(0).signals == add1.output(0).signals assert sub1.input(1).signals == add2.output(0).signals sub2 = sub1 - 5 - - assert isinstance(sub2, ConstantSubtraction) + assert isinstance(sub2, Subtraction) assert sub2.input(0).signals == sub1.output(0).signals + assert sub2.input(1).signals[0].source.operation.value == 5 + + sub3 = 5 - sub2 + assert isinstance(sub3, Subtraction) + assert sub3.input(0).signals[0].source.operation.value == 5 + assert sub3.input(1).signals == sub2.output(0).signals def test_multiplication_overload(): @@ -48,15 +55,19 @@ def test_multiplication_overload(): add2 = Addition(None, None, "add2") mul1 = add1 * add2 - assert isinstance(mul1, Multiplication) assert mul1.input(0).signals == add1.output(0).signals assert mul1.input(1).signals == add2.output(0).signals mul2 = mul1 * 5 - assert isinstance(mul2, ConstantMultiplication) assert mul2.input(0).signals == mul1.output(0).signals + assert mul2.value == 5 + + mul3 = 5 * mul2 + assert isinstance(mul3, ConstantMultiplication) + assert mul3.input(0).signals == mul2.output(0).signals + assert mul3.value == 5 def test_division_overload(): @@ -65,13 +76,17 @@ def test_division_overload(): add2 = Addition(None, None, "add2") div1 = add1 / add2 - assert isinstance(div1, Division) assert div1.input(0).signals == add1.output(0).signals assert div1.input(1).signals == add2.output(0).signals div2 = div1 / 5 - - assert isinstance(div2, ConstantDivision) + assert isinstance(div2, Division) assert div2.input(0).signals == div1.output(0).signals + assert div2.input(1).signals[0].source.operation.value == 5 + + div3 = 5 / div2 + assert isinstance(div3, Division) + assert div3.input(0).signals[0].source.operation.value == 5 + assert div3.input(1).signals == div2.output(0).signals diff --git a/test/test_core_operations.py b/test/test_core_operations.py index 2e7506c7..4d0039b5 100644 --- a/test/test_core_operations.py +++ b/test/test_core_operations.py @@ -2,10 +2,9 @@ B-ASIC test suite for the core operations. """ -from b_asic import Constant, Addition, Subtraction, \ - Multiplication, Division, SquareRoot, ComplexConjugate, Max, Min, \ - Absolute, ConstantMultiplication, ConstantAddition, ConstantSubtraction, \ - ConstantDivision, Butterfly +from b_asic import \ + Constant, Addition, Subtraction, Multiplication, ConstantMultiplication, Division, \ + SquareRoot, ComplexConjugate, Max, Min, Absolute, Butterfly class TestConstant: @@ -150,48 +149,6 @@ class TestConstantMultiplication: assert test_operation.evaluate_output(0, [3+4j]) == 1+18j -class TestConstantAddition: - def test_constantaddition_positive(self): - test_operation = ConstantAddition(5) - assert test_operation.evaluate_output(0, [20]) == 25 - - def test_constantaddition_negative(self): - test_operation = ConstantAddition(4) - assert test_operation.evaluate_output(0, [-5]) == -1 - - def test_constantaddition_complex(self): - test_operation = ConstantAddition(3+2j) - assert test_operation.evaluate_output(0, [3+2j]) == 6+4j - - -class TestConstantSubtraction: - def test_constantsubtraction_positive(self): - test_operation = ConstantSubtraction(5) - assert test_operation.evaluate_output(0, [20]) == 15 - - def test_constantsubtraction_negative(self): - test_operation = ConstantSubtraction(4) - assert test_operation.evaluate_output(0, [-5]) == -9 - - def test_constantsubtraction_complex(self): - test_operation = ConstantSubtraction(4+6j) - assert test_operation.evaluate_output(0, [3+4j]) == -1-2j - - -class TestConstantDivision: - def test_constantdivision_positive(self): - test_operation = ConstantDivision(5) - assert test_operation.evaluate_output(0, [20]) == 4 - - def test_constantdivision_negative(self): - test_operation = ConstantDivision(4) - assert test_operation.evaluate_output(0, [-20]) == -5 - - def test_constantdivision_complex(self): - test_operation = ConstantDivision(2+2j) - assert test_operation.evaluate_output(0, [10+10j]) == 5 - - class TestButterfly: def test_butterfly_positive(self): test_operation = Butterfly() diff --git a/test/test_operation.py b/test/test_operation.py index 4b258e00..b76ba16d 100644 --- a/test/test_operation.py +++ b/test/test_operation.py @@ -1,6 +1,6 @@ import pytest -from b_asic import Constant, Addition, ConstantAddition, Butterfly, Signal, InputPort, OutputPort +from b_asic import Constant, Addition class TestTraverse: def test_traverse_single_tree(self, operation): -- GitLab