diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index fd75359d32afe8eb9914dae699549d980b08d536..3a1474cc7513955e8bd7536e7bf4350fe61d94cd 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -1102,6 +1102,7 @@ class MAD(AbstractOperation): class MADS(AbstractOperation): __slots__ = ( "_is_add", + "_override_zero_on_src0", "_src0", "_src1", "_src2", @@ -1111,6 +1112,7 @@ class MADS(AbstractOperation): "_execution_time", ) _is_add: Optional[bool] + _override_zero_on_src0: Optional[bool] _src0: Optional[SignalSourceProvider] _src1: Optional[SignalSourceProvider] _src2: Optional[SignalSourceProvider] @@ -1124,6 +1126,7 @@ class MADS(AbstractOperation): def __init__( self, is_add: Optional[bool] = True, + override_zero_on_src0: Optional[bool] = False, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, src2: Optional[SignalSourceProvider] = None, @@ -1143,13 +1146,23 @@ class MADS(AbstractOperation): execution_time=execution_time, ) self.set_param("is_add", is_add) + self.set_param("override_zero_on_src0", override_zero_on_src0) @classmethod def type_name(cls) -> TypeName: return TypeName("mads") def evaluate(self, a, b, c): - return a + b * c if self.is_add else a - b * c + if self.is_add: + if self.override_zero_on_src0: + return b * c + else: + return a + b * c + else: + if self.override_zero_on_src0: + return -b * c + else: + return a - b * c @property def is_add(self) -> bool: @@ -1161,11 +1174,21 @@ class MADS(AbstractOperation): """Set if operation is an addition.""" self.set_param("is_add", is_add) + @property + def override_zero_on_src0(self) -> bool: + """Get if operation is overriding a zero on port src0.""" + return self.param("override_zero_on_src0") + + @override_zero_on_src0.setter + def override_zero_on_src0(self, override_zero_on_src0: bool) -> None: + """Set if operation is overriding a zero on port src0.""" + self.set_param("override_zero_on_src0", override_zero_on_src0) + @property def is_linear(self) -> bool: return ( - self.input(0).connected_source.operation.is_constant - or self.input(1).connected_source.operation.is_constant + self.input(1).connected_source.operation.is_constant + or self.input(2).connected_source.operation.is_constant ) def swap_io(self) -> None: @@ -1598,6 +1621,51 @@ class Shift(AbstractOperation): self.set_param("value", value) +class DontCare(AbstractOperation): + r""" + Dont-care operation + + Used for ignoring the input to another operation and thus avoiding dangling input nodes. + + Parameters + ---------- + name : Name, optional + Operation name. + + """ + + __slots__ = "_name" + _name: Name + + is_linear = True + + def __init__(self, name: Name = ""): + """Construct a DontCare operation.""" + super().__init__( + input_count=0, + output_count=1, + name=name, + latency_offsets={"out0": 0}, + ) + + @classmethod + def type_name(cls) -> TypeName: + return TypeName("dontcare") + + def evaluate(self): + return 0 + + @property + def latency(self) -> int: + return self.latency_offsets["out0"] + + def __repr__(self) -> str: + return "DontCare()" + + def __str__(self) -> str: + return "dontcare" + + class Sink(AbstractOperation): r""" Sink operation. diff --git a/b_asic/sfg_generators.py b/b_asic/sfg_generators.py index 0f7d26c50c5354553713d1655fc3d0dafc5ea70d..3e76f159706a72d349d8006c27c894372f5a1ce3 100644 --- a/b_asic/sfg_generators.py +++ b/b_asic/sfg_generators.py @@ -12,9 +12,8 @@ from b_asic.core_operations import ( MADS, Addition, Butterfly, - ComplexConjugate, - Constant, ConstantMultiplication, + DontCare, Name, Reciprocal, SymmetricTwoportAdaptor, @@ -436,7 +435,19 @@ def radix_2_dif_fft(points: int) -> SFG: return SFG(inputs=inputs, outputs=outputs) -def ldlt_matrix_inverse(N: int, is_complex: bool) -> SFG: +def ldlt_matrix_inverse(N: int) -> SFG: + """Generates an SFG for the LDLT matrix inverse algorithm. + + Parameters + ---------- + N : int + Dimension of the square input matrix. + + Returns + ------- + SFG + Signal Flow Graph + """ inputs = [] A = [[None for _ in range(N)] for _ in range(N)] for i in range(N): @@ -457,7 +468,7 @@ def ldlt_matrix_inverse(N: int, is_complex: bool) -> SFG: # R*di*R^T factorization for i in range(N): for k in range(i): - D[i] = MADS(False, D[i], M[k][i], R[k][i]) + D[i] = MADS(False, False, D[i], M[k][i], R[k][i]) D_inv[i] = Reciprocal(D[i]) @@ -465,14 +476,14 @@ def ldlt_matrix_inverse(N: int, is_complex: bool) -> SFG: R[i][j] = A[i][j] for k in range(i): - R[i][j] = MADS(False, R[i][j], M[k][i], R[k][j]) + R[i][j] = MADS(False, False, R[i][j], M[k][i], R[k][j]) - if is_complex: - M[i][j] = ComplexConjugate(R[i][j]) - else: - M[i][j] = R[i][j] + # if is_complex: + # M[i][j] = ComplexConjugate(R[i][j]) + # else: + M[i][j] = R[i][j] - R[i][j] = MADS(True, Constant(0, name="0"), R[i][j], D_inv[i]) + R[i][j] = MADS(True, True, DontCare(), R[i][j], D_inv[i]) # back substitution A_inv = [[None for _ in range(N)] for _ in range(N)] @@ -481,14 +492,16 @@ def ldlt_matrix_inverse(N: int, is_complex: bool) -> SFG: for j in reversed(range(i + 1)): for k in reversed(range(j + 1, N)): if k == N - 1 and i != j: - A_inv[j][i] = MADS( - False, Constant(0, name="0"), R[j][k], A_inv[i][k] - ) + A_inv[j][i] = MADS(False, True, DontCare(), R[j][k], A_inv[i][k]) else: if A_inv[i][k]: - A_inv[j][i] = MADS(False, A_inv[j][i], R[j][k], A_inv[i][k]) + A_inv[j][i] = MADS( + False, False, A_inv[j][i], R[j][k], A_inv[i][k] + ) else: - A_inv[j][i] = MADS(False, A_inv[j][i], R[j][k], A_inv[k][i]) + A_inv[j][i] = MADS( + False, False, A_inv[j][i], R[j][k], A_inv[k][i] + ) outputs = [] for i in range(N): diff --git a/test/test_core_operations.py b/test/test_core_operations.py index 328acad27a38b16f27c6a424d0e7af45f42be1d9..4b34ac58fa029c5093a6c8674cc484aca5586724 100644 --- a/test/test_core_operations.py +++ b/test/test_core_operations.py @@ -5,6 +5,7 @@ import pytest from b_asic import ( MAD, MADS, + SFG, Absolute, Addition, AddSub, @@ -13,11 +14,13 @@ from b_asic import ( Constant, ConstantMultiplication, Division, + DontCare, Input, LeftShift, Max, Min, Multiplication, + Output, Reciprocal, RightShift, Shift, @@ -343,19 +346,33 @@ class TestMADS: test_operation = MADS(is_add=True) assert test_operation.evaluate_output(0, [3 + 6j, 2 + 6j, 1 + 1j]) == -1 + 14j + def test_mads_zero_override(self): + test_operation = MADS(is_add=True, override_zero_on_src0=True) + assert test_operation.evaluate_output(0, [1, 1, 1]) == 1 + + def test_mads_sub_zero_override(self): + test_operation = MADS(is_add=False, override_zero_on_src0=True) + assert test_operation.evaluate_output(0, [1, 1, 1]) == -1 + def test_mads_is_linear(self): test_operation = MADS( - Constant(3), Addition(Input(), Constant(3)), Addition(Input(), Constant(3)) + src0=Constant(3), + src1=Addition(Input(), Constant(3)), + src2=Addition(Input(), Constant(3)), ) assert not test_operation.is_linear test_operation = MADS( - Addition(Input(), Constant(3)), Constant(3), Addition(Input(), Constant(3)) + src0=Addition(Input(), Constant(3)), + src1=Constant(3), + src2=Addition(Input(), Constant(3)), ) assert test_operation.is_linear test_operation = MADS( - Addition(Input(), Constant(3)), Addition(Input(), Constant(3)), Constant(3) + src0=Addition(Input(), Constant(3)), + src1=Addition(Input(), Constant(3)), + src2=Constant(3), ) assert test_operation.is_linear @@ -381,6 +398,22 @@ class TestMADS: test_operation.is_add = False assert not test_operation.is_add + def test_mads_override_zero_on_src0_getter(self): + test_operation = MADS(override_zero_on_src0=False) + assert not test_operation.override_zero_on_src0 + + test_operation = MADS(override_zero_on_src0=True) + assert test_operation.override_zero_on_src0 + + def test_mads_override_zero_on_src0_setter(self): + test_operation = MADS(override_zero_on_src0=False) + test_operation.override_zero_on_src0 = True + assert test_operation.override_zero_on_src0 + + test_operation = MADS(override_zero_on_src0=True) + test_operation.override_zero_on_src0 = False + assert not test_operation.override_zero_on_src0 + class TestRightShift: """Tests for RightShift class.""" @@ -556,6 +589,33 @@ class TestDepends: assert set(bfly1.inputs_required_for_output(1)) == {0, 1} +class TestDontCare: + def test_create_sfg_with_dontcare(self): + i1 = Input() + dc = DontCare() + a = Addition(i1, dc) + o = Output(a) + sfg = SFG([i1], [o]) + + assert sfg.output_count == 1 + assert sfg.input_count == 1 + + assert sfg.evaluate_output(0, [0]) == 0 + assert sfg.evaluate_output(0, [1]) == 1 + + def test_dontcare_latency_getter(self): + test_operation = DontCare() + assert test_operation.latency == 0 + + def test_dontcare_repr(self): + test_operation = DontCare() + assert repr(test_operation) == "DontCare()" + + def test_dontcare_str(self): + test_operation = DontCare() + assert str(test_operation) == "dontcare" + + class TestSink: def test_create_sfg_with_sink(self): bfly = Butterfly() diff --git a/test/test_sfg_generators.py b/test/test_sfg_generators.py index 3d876d7595c92132fe8f247c50719da9a5b385e7..bbd8916ef3a47df39d49a6d0bfaaa96022dd44d5 100644 --- a/test/test_sfg_generators.py +++ b/test/test_sfg_generators.py @@ -644,7 +644,7 @@ class TestRadix2FFT: class TestLdltMatrixInverse: def test_1x1(self): - sfg = ldlt_matrix_inverse(N=1, is_complex=False) + sfg = ldlt_matrix_inverse(N=1) assert len(sfg.inputs) == 1 assert len(sfg.outputs) == 1 @@ -661,7 +661,7 @@ class TestLdltMatrixInverse: assert np.isclose(res["0"], 0.2) def test_2x2_simple_spd(self): - sfg = ldlt_matrix_inverse(N=2, is_complex=False) + sfg = ldlt_matrix_inverse(N=2) assert len(sfg.inputs) == 3 assert len(sfg.outputs) == 3 @@ -683,7 +683,7 @@ class TestLdltMatrixInverse: assert np.isclose(res["2"], A_inv[1, 1]) def test_3x3_simple_spd(self): - sfg = ldlt_matrix_inverse(N=3, is_complex=False) + sfg = ldlt_matrix_inverse(N=3) assert len(sfg.inputs) == 6 assert len(sfg.outputs) == 6 @@ -717,7 +717,7 @@ class TestLdltMatrixInverse: def test_5x5_random_spd(self): N = 5 - sfg = ldlt_matrix_inverse(N=N, is_complex=False) + sfg = ldlt_matrix_inverse(N=N) assert len(sfg.inputs) == 15 assert len(sfg.outputs) == 15 @@ -746,7 +746,7 @@ class TestLdltMatrixInverse: def test_20x20_random_spd(self): N = 20 - sfg = ldlt_matrix_inverse(N=N, is_complex=False) + sfg = ldlt_matrix_inverse(N=N) A = self._generate_random_spd_matrix(N)