diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index c43a12c738abe05daee5fc8de7bfbd19fbb1b118..3741d90265b5f7a04bcd02ace789a7e1c6697f6d 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -235,8 +235,8 @@ class MAD(AbstractOperation): TODO: More info. """ - def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""): - super().__init__(input_count = 3, output_count = 1, name = name, input_sources = [src0, src1]) + def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, src2: Optional[SignalSourceProvider] = None, name: Name = ""): + super().__init__(input_count = 3, output_count = 1, name = name, input_sources = [src0, src1, src2]) @property def type_name(self) -> TypeName: diff --git a/b_asic/operation.py b/b_asic/operation.py index a0d0f48a1f7429ce0d393ad4e93ef24c84914f7b..e5f91fe7fd710523644d8c6859489ddf6f0e0bb2 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -329,7 +329,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): # Import here to avoid circular imports. from b_asic.special_operations import Input try: - result = self.evaluate([Input()] * self.input_count) + result = self.evaluate(*[Input()] * self.input_count) if isinstance(result, collections.Sequence) and all(isinstance(e, Operation) for e in result): return result if isinstance(result, Operation): @@ -338,7 +338,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): pass except ValueError: pass - return [self] + return [self] @property def neighbors(self) -> Iterable[GraphComponent]: diff --git a/test/test_core_operations.py b/test/test_core_operations.py index 4d0039b558e81c5cd74f151f93f0bc0194a702d5..1482938ef989da53203c8d211c1643946af95383 100644 --- a/test/test_core_operations.py +++ b/test/test_core_operations.py @@ -4,8 +4,7 @@ B-ASIC test suite for the core operations. from b_asic import \ Constant, Addition, Subtraction, Multiplication, ConstantMultiplication, Division, \ - SquareRoot, ComplexConjugate, Max, Min, Absolute, Butterfly - + SquareRoot, ComplexConjugate, Max, Min, Absolute, Butterfly, MAD class TestConstant: def test_constant_positive(self): @@ -164,3 +163,14 @@ class TestButterfly: test_operation = Butterfly() assert test_operation.evaluate_output(0, [2+1j, 3-2j]) == 5-1j assert test_operation.evaluate_output(1, [2+1j, 3-2j]) == -1+3j + + def test_split(self): + but1 = Butterfly() + split = but1.split() + assert len(split) == 2 + +class TestMad: + def test_split_mad(self): + mad1 = MAD() + res = mad1.split() + assert len(res) == 2