diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index 161b979cacc97a6de719da692b529c71c23b0da6..34b6c7d0d384392c7d2140d3a4fa50eaff1c71cb 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -66,6 +66,20 @@ class Constant(AbstractOperation): def latency(self) -> int: return self.latency_offsets["out0"] + def __repr__(self) -> str: + return f"Constant({self.value})" + + def __str__(self) -> str: + return f"{self.value}" + + @property + def is_linear(self) -> bool: + return True + + @property + def is_constant(self) -> bool: + return True + class Addition(AbstractOperation): """ @@ -129,6 +143,10 @@ class Addition(AbstractOperation): def evaluate(self, a, b): return a + b + @property + def is_linear(self) -> bool: + return True + class Subtraction(AbstractOperation): """ @@ -189,6 +207,10 @@ class Subtraction(AbstractOperation): def evaluate(self, a, b): return a - b + @property + def is_linear(self) -> bool: + return True + class AddSub(AbstractOperation): r""" @@ -270,6 +292,10 @@ class AddSub(AbstractOperation): """Set if operation is an addition.""" self.set_param("is_add", is_add) + @property + def is_linear(self) -> bool: + return True + class Multiplication(AbstractOperation): r""" @@ -331,6 +357,12 @@ class Multiplication(AbstractOperation): def evaluate(self, a, b): return a * b + @property + def is_linear(self) -> bool: + return any( + input.connected_source.operation.is_constant for input in self.inputs + ) + class Division(AbstractOperation): r""" @@ -372,6 +404,10 @@ class Division(AbstractOperation): def evaluate(self, a, b): return a / b + @property + def is_linear(self) -> bool: + return self.input(1).connected_source.operation.is_constant + class Min(AbstractOperation): r""" @@ -618,6 +654,10 @@ class ConstantMultiplication(AbstractOperation): """Set the constant value of this operation.""" self.set_param("value", value) + @property + def is_linear(self) -> bool: + return True + class Butterfly(AbstractOperation): r""" @@ -660,6 +700,10 @@ class Butterfly(AbstractOperation): def evaluate(self, a, b): return a + b, a - b + @property + def is_linear(self) -> bool: + return True + class MAD(AbstractOperation): r""" @@ -699,6 +743,13 @@ class MAD(AbstractOperation): def evaluate(self, a, b, c): return a * b + c + @property + def is_linear(self) -> bool: + return ( + self.input(0).connected_source.operation.is_constant + or self.input(1).connected_source.operation.is_constant + ) + class SymmetricTwoportAdaptor(AbstractOperation): r""" @@ -751,6 +802,10 @@ class SymmetricTwoportAdaptor(AbstractOperation): """Set the constant value of this operation.""" self.set_param("value", value) + @property + def is_linear(self) -> bool: + return True + class Reciprocal(AbstractOperation): r""" diff --git a/b_asic/operation.py b/b_asic/operation.py index 77087d23107cb85e1c355716caf94ce35974c9a9..bba050c6d425fe23f7f94245b03bd1fe0c0ff4f1 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -467,6 +467,22 @@ class Operation(GraphComponent, SignalSourceProvider): def _check_all_latencies_set(self) -> None: raise NotImplementedError + @property + @abstractmethod + def is_linear(self) -> bool: + """ + Return True if the operation is linear. + """ + raise NotImplementedError + + @property + @abstractmethod + def is_constant(self) -> bool: + """ + Return True if the output of the operation is constant. + """ + raise NotImplementedError + class AbstractOperation(Operation, AbstractGraphComponent): """ @@ -1135,3 +1151,15 @@ class AbstractOperation(Operation, AbstractGraphComponent): ) for k in range(num_out) ) + + @property + def is_linear(self) -> bool: + if self.is_constant: + return True + return False + + @property + def is_constant(self) -> bool: + return all( + input.connected_source.operation.is_constant for input in self.inputs + ) diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 04a742268f87b2bd33c6354a22a3f78d8764313e..782baa69b11ed6bcfa6aa5b180cb0012a6743b6f 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -1540,3 +1540,11 @@ class SFG(AbstractOperation): assert len(ids) == len(set(ids)) return SFG(inputs=all_inputs, outputs=all_outputs) + + @property + def is_linear(self) -> bool: + return all(op.is_linear for op in self.split()) + + @property + def is_constant(self) -> bool: + return all(output.is_constant for output in self._output_operations) diff --git a/b_asic/special_operations.py b/b_asic/special_operations.py index 387e7b3f2f9739e2b250a0404416815c3ee5c1b4..6f438299b2fe69ea5f1bfaa4cc2dfb1b320ea687 100644 --- a/b_asic/special_operations.py +++ b/b_asic/special_operations.py @@ -89,6 +89,14 @@ class Input(AbstractOperation): # doc-string inherited return ((0, 0.5),) + @property + def is_constant(self) -> bool: + return False + + @property + def is_linear(self) -> bool: + return True + class Output(AbstractOperation): """ @@ -143,6 +151,10 @@ class Output(AbstractOperation): def latency(self) -> int: return self.latency_offsets["in0"] + @property + def is_linear(self) -> bool: + return True + class Delay(AbstractOperation): """ @@ -221,3 +233,7 @@ class Delay(AbstractOperation): def initial_value(self, value: Num) -> None: """Set the initial value of this delay.""" self.set_param("initial_value", value) + + @property + def is_linear(self) -> bool: + return True diff --git a/test/test_sfg.py b/test/test_sfg.py index a40a111f47d8e495d6676c37cbd898735157e603..a84a280d0037cccb92386418b6f524f820647bff 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -1604,3 +1604,19 @@ class TestUnfold: sfg = sfg_two_inputs_two_outputs with pytest.raises(ValueError, match="Unfolding 0 times removes the SFG"): sfg.unfold(0) + + +class TestIsLinear: + def test_single_accumulator(self, sfg_simple_accumulator: SFG): + assert sfg_simple_accumulator.is_linear + + def test_sfg_nested(self, sfg_nested: SFG): + assert not sfg_nested.is_linear + + +class TestIsConstant: + def test_single_accumulator(self, sfg_simple_accumulator: SFG): + assert not sfg_simple_accumulator.is_constant + + def test_sfg_nested(self, sfg_nested: SFG): + assert not sfg_nested.is_constant