diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index f696126c8a1ad7836c0df57da7e4ffc4f0ab10bb..11013c84b08a0ac7e3f406249082df54e79f97bc 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -30,12 +30,12 @@ class Constant(AbstractOperation): @property def value(self) -> Number: - """TODO: docstring""" + """Get the constant value of this operation.""" return self.param("value") @value.setter def value(self, value: Number): - """TODO: docstring""" + """Set the constant value of this operation.""" return self.set_param("value", value) @@ -203,56 +203,15 @@ class ConstantMultiplication(AbstractOperation): def evaluate(self, a): return a * self.param("value") - -class ConstantAddition(AbstractOperation): - """Unary constant addition operation. - TODO: More info. - """ - - def __init__(self, value: Number, src0: Optional[SignalSourceProvider] = None, name: Name = ""): - super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) - self.set_param("value", value) - - @property - def type_name(self) -> TypeName: - return "cadd" - - def evaluate(self, a): - return a + self.param("value") - - -class ConstantSubtraction(AbstractOperation): - """Unary constant subtraction operation. - TODO: More info. - """ - - def __init__(self, value: Number, src0: Optional[SignalSourceProvider] = None, name: Name = ""): - super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) - self.set_param("value", value) - @property - def type_name(self) -> TypeName: - return "csub" - - def evaluate(self, a): - return a - self.param("value") - - -class ConstantDivision(AbstractOperation): - """Unary constant division operation. - TODO: More info. - """ - - def __init__(self, value: Number, src0: Optional[SignalSourceProvider] = None, name: Name = ""): - super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) - self.set_param("value", value) - - @property - def type_name(self) -> TypeName: - return "cdiv" + def value(self) -> Number: + """Get the constant value of this operation.""" + return self.param("value") - def evaluate(self, a): - return a / self.param("value") + @value.setter + def value(self, value: Number): + """Set the constant value of this operation.""" + return self.set_param("value", value) class Butterfly(AbstractOperation): diff --git a/b_asic/operation.py b/b_asic/operation.py index 79798a5221cb9133124256f9650c070e1ed73f03..ed0e7dfd1deb8bb217417eb71ad21cb0287a3bce 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -22,18 +22,30 @@ class Operation(GraphComponent, SignalSourceProvider): """ @abstractmethod - def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Addition, ConstantAddition]": + def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Addition": """Overloads the addition operator to make it return a new Addition operation - object that is connected to the self and other objects. If other is a number then - returns a ConstantAddition operation object instead. + object that is connected to the self and other objects. + """ + raise NotImplementedError + + @abstractmethod + def __radd__(self, src: Union[SignalSourceProvider, Number]) -> "Addition": + """Overloads the addition operator to make it return a new Addition operation + object that is connected to the self and other objects. """ raise NotImplementedError @abstractmethod - def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Subtraction, ConstantSubtraction]": + def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction": """Overloads the subtraction operator to make it return a new Subtraction operation - object that is connected to the self and other objects. If other is a number then - returns a ConstantSubtraction operation object instead. + object that is connected to the self and other objects. + """ + raise NotImplementedError + + @abstractmethod + def __rsub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction": + """Overloads the subtraction operator to make it return a new Subtraction operation + object that is connected to the self and other objects. """ raise NotImplementedError @@ -46,10 +58,24 @@ class Operation(GraphComponent, SignalSourceProvider): raise NotImplementedError @abstractmethod - def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Division, ConstantDivision]": - """Overloads the division operator to make it return a new Division operation + def __rmul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]": + """Overloads the multiplication operator to make it return a new Multiplication operation object that is connected to the self and other objects. If other is a number then - returns a ConstantDivision operation object instead. + returns a ConstantMultiplication operation object instead. + """ + raise NotImplementedError + + @abstractmethod + def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division": + """Overloads the division operator to make it return a new Division operation + object that is connected to the self and other objects. + """ + raise NotImplementedError + + @abstractmethod + def __rtruediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division": + """Overloads the division operator to make it return a new Division operation + object that is connected to the self and other objects. """ raise NotImplementedError @@ -202,37 +228,37 @@ class AbstractOperation(Operation, AbstractGraphComponent): """Evaluate the operation and generate a list of output values given a list of input values.""" raise NotImplementedError - def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Addition, ConstantAddition]": - # Import here to avoid circular imports. - from b_asic.core_operations import Addition, ConstantAddition - - if isinstance(src, Number): - return ConstantAddition(src, self) - return Addition(self, src) - - def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Subtraction, ConstantSubtraction]": - # Import here to avoid circular imports. - from b_asic.core_operations import Subtraction, ConstantSubtraction + def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Addition": + from b_asic.core_operations import Constant, Addition # Import here to avoid circular imports. + return Addition(self, Constant(src) if isinstance(src, Number) else src) + + def __radd__(self, src: Union[SignalSourceProvider, Number]) -> "Addition": + from b_asic.core_operations import Constant, Addition # Import here to avoid circular imports. + return Addition(Constant(src) if isinstance(src, Number) else src, self) - if isinstance(src, Number): - return ConstantSubtraction(src, self) - return Subtraction(self, src) + def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction": + from b_asic.core_operations import Constant, Subtraction # Import here to avoid circular imports. + return Subtraction(self, Constant(src) if isinstance(src, Number) else src) + + def __rsub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction": + from b_asic.core_operations import Constant, Subtraction # Import here to avoid circular imports. + return Subtraction(Constant(src) if isinstance(src, Number) else src, self) def __mul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]": - # Import here to avoid circular imports. - from b_asic.core_operations import Multiplication, ConstantMultiplication - - if isinstance(src, Number): - return ConstantMultiplication(src, self) - return Multiplication(self, src) - - def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Division, ConstantDivision]": - # Import here to avoid circular imports. - from b_asic.core_operations import Division, ConstantDivision + from b_asic.core_operations import Multiplication, ConstantMultiplication # Import here to avoid circular imports. + return ConstantMultiplication(src, self) if isinstance(src, Number) else Multiplication(self, src) + + def __rmul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]": + from b_asic.core_operations import Multiplication, ConstantMultiplication # Import here to avoid circular imports. + return ConstantMultiplication(src, self) if isinstance(src, Number) else Multiplication(src, self) - if isinstance(src, Number): - return ConstantDivision(src, self) - return Division(self, src) + def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division": + from b_asic.core_operations import Constant, Division # Import here to avoid circular imports. + return Division(self, Constant(src) if isinstance(src, Number) else src) + + def __rtruediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division": + from b_asic.core_operations import Constant, Division # Import here to avoid circular imports. + return Division(Constant(src) if isinstance(src, Number) else src, self) @property def inputs(self) -> Sequence[InputPort]: diff --git a/test/fixtures/signal_flow_graph.py b/test/fixtures/signal_flow_graph.py index df7d82b397eeaa936517371361d09eebc33a067f..0a6c554d1340478dad25a11655f0542bf6fba1d1 100644 --- a/test/fixtures/signal_flow_graph.py +++ b/test/fixtures/signal_flow_graph.py @@ -65,8 +65,7 @@ def sfg_accumulator(): """ data_in = Input() reset = Input() - reset_inverted = Constant(1) - reset reg = Register() - reg.input(0).connect((reg + data_in) * reset_inverted) + reg.input(0).connect((reg + data_in) * (1 - reset)) data_out = Output(reg) return SFG(inputs = [data_in, reset], outputs = [data_out]) \ No newline at end of file diff --git a/test/test_abstract_operation.py b/test/test_abstract_operation.py index ab53dabfd50618364e7b15471cc4753341b16ab9..5423ecdf08c420df5dccc6393c3ad6637961172b 100644 --- a/test/test_abstract_operation.py +++ b/test/test_abstract_operation.py @@ -4,8 +4,7 @@ B-ASIC test suite for the AbstractOperation class. import pytest -from b_asic import Addition, ConstantAddition, Subtraction, ConstantSubtraction, \ - Multiplication, ConstantMultiplication, Division, ConstantDivision +from b_asic import Addition, Subtraction, Multiplication, ConstantMultiplication, Division def test_addition_overload(): @@ -14,15 +13,19 @@ def test_addition_overload(): add2 = Addition(None, None, "add2") add3 = add1 + add2 - assert isinstance(add3, Addition) assert add3.input(0).signals == add1.output(0).signals assert add3.input(1).signals == add2.output(0).signals add4 = add3 + 5 - - assert isinstance(add4, ConstantAddition) + assert isinstance(add4, Addition) assert add4.input(0).signals == add3.output(0).signals + assert add4.input(1).signals[0].source.operation.value == 5 + + add5 = 5 + add4 + assert isinstance(add5, Addition) + assert add5.input(0).signals[0].source.operation.value == 5 + assert add5.input(1).signals == add4.output(0).signals def test_subtraction_overload(): @@ -31,15 +34,19 @@ def test_subtraction_overload(): add2 = Addition(None, None, "add2") sub1 = add1 - add2 - assert isinstance(sub1, Subtraction) assert sub1.input(0).signals == add1.output(0).signals assert sub1.input(1).signals == add2.output(0).signals sub2 = sub1 - 5 - - assert isinstance(sub2, ConstantSubtraction) + assert isinstance(sub2, Subtraction) assert sub2.input(0).signals == sub1.output(0).signals + assert sub2.input(1).signals[0].source.operation.value == 5 + + sub3 = 5 - sub2 + assert isinstance(sub3, Subtraction) + assert sub3.input(0).signals[0].source.operation.value == 5 + assert sub3.input(1).signals == sub2.output(0).signals def test_multiplication_overload(): @@ -48,15 +55,19 @@ def test_multiplication_overload(): add2 = Addition(None, None, "add2") mul1 = add1 * add2 - assert isinstance(mul1, Multiplication) assert mul1.input(0).signals == add1.output(0).signals assert mul1.input(1).signals == add2.output(0).signals mul2 = mul1 * 5 - assert isinstance(mul2, ConstantMultiplication) assert mul2.input(0).signals == mul1.output(0).signals + assert mul2.value == 5 + + mul3 = 5 * mul2 + assert isinstance(mul3, ConstantMultiplication) + assert mul3.input(0).signals == mul2.output(0).signals + assert mul3.value == 5 def test_division_overload(): @@ -65,13 +76,17 @@ def test_division_overload(): add2 = Addition(None, None, "add2") div1 = add1 / add2 - assert isinstance(div1, Division) assert div1.input(0).signals == add1.output(0).signals assert div1.input(1).signals == add2.output(0).signals div2 = div1 / 5 - - assert isinstance(div2, ConstantDivision) + assert isinstance(div2, Division) assert div2.input(0).signals == div1.output(0).signals + assert div2.input(1).signals[0].source.operation.value == 5 + + div3 = 5 / div2 + assert isinstance(div3, Division) + assert div3.input(0).signals[0].source.operation.value == 5 + assert div3.input(1).signals == div2.output(0).signals diff --git a/test/test_core_operations.py b/test/test_core_operations.py index 2e7506c7493c925680d2cb15aedbb3b2c882d437..4d0039b558e81c5cd74f151f93f0bc0194a702d5 100644 --- a/test/test_core_operations.py +++ b/test/test_core_operations.py @@ -2,10 +2,9 @@ B-ASIC test suite for the core operations. """ -from b_asic import Constant, Addition, Subtraction, \ - Multiplication, Division, SquareRoot, ComplexConjugate, Max, Min, \ - Absolute, ConstantMultiplication, ConstantAddition, ConstantSubtraction, \ - ConstantDivision, Butterfly +from b_asic import \ + Constant, Addition, Subtraction, Multiplication, ConstantMultiplication, Division, \ + SquareRoot, ComplexConjugate, Max, Min, Absolute, Butterfly class TestConstant: @@ -150,48 +149,6 @@ class TestConstantMultiplication: assert test_operation.evaluate_output(0, [3+4j]) == 1+18j -class TestConstantAddition: - def test_constantaddition_positive(self): - test_operation = ConstantAddition(5) - assert test_operation.evaluate_output(0, [20]) == 25 - - def test_constantaddition_negative(self): - test_operation = ConstantAddition(4) - assert test_operation.evaluate_output(0, [-5]) == -1 - - def test_constantaddition_complex(self): - test_operation = ConstantAddition(3+2j) - assert test_operation.evaluate_output(0, [3+2j]) == 6+4j - - -class TestConstantSubtraction: - def test_constantsubtraction_positive(self): - test_operation = ConstantSubtraction(5) - assert test_operation.evaluate_output(0, [20]) == 15 - - def test_constantsubtraction_negative(self): - test_operation = ConstantSubtraction(4) - assert test_operation.evaluate_output(0, [-5]) == -9 - - def test_constantsubtraction_complex(self): - test_operation = ConstantSubtraction(4+6j) - assert test_operation.evaluate_output(0, [3+4j]) == -1-2j - - -class TestConstantDivision: - def test_constantdivision_positive(self): - test_operation = ConstantDivision(5) - assert test_operation.evaluate_output(0, [20]) == 4 - - def test_constantdivision_negative(self): - test_operation = ConstantDivision(4) - assert test_operation.evaluate_output(0, [-20]) == -5 - - def test_constantdivision_complex(self): - test_operation = ConstantDivision(2+2j) - assert test_operation.evaluate_output(0, [10+10j]) == 5 - - class TestButterfly: def test_butterfly_positive(self): test_operation = Butterfly() diff --git a/test/test_operation.py b/test/test_operation.py index 4b258e00ab127d6489af81bf455bba411509439d..b76ba16d11425c0ce868e4fa0b4c88d9f862e23f 100644 --- a/test/test_operation.py +++ b/test/test_operation.py @@ -1,6 +1,6 @@ import pytest -from b_asic import Constant, Addition, ConstantAddition, Butterfly, Signal, InputPort, OutputPort +from b_asic import Constant, Addition class TestTraverse: def test_traverse_single_tree(self, operation):