From 8ea03312ca09a6ebdce801baac3ed6587c0eaa85 Mon Sep 17 00:00:00 2001
From: Oscar Gustafsson <oscar.gustafsson@gmail.com>
Date: Tue, 21 Feb 2023 09:52:03 +0100
Subject: [PATCH] Add is_linear and is_constant properties

---
 b_asic/core_operations.py    | 55 ++++++++++++++++++++++++++++++++++++
 b_asic/operation.py          | 28 ++++++++++++++++++
 b_asic/signal_flow_graph.py  |  8 ++++++
 b_asic/special_operations.py | 16 +++++++++++
 test/test_sfg.py             | 16 +++++++++++
 5 files changed, 123 insertions(+)

diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py
index 161b979c..34b6c7d0 100644
--- a/b_asic/core_operations.py
+++ b/b_asic/core_operations.py
@@ -66,6 +66,20 @@ class Constant(AbstractOperation):
     def latency(self) -> int:
         return self.latency_offsets["out0"]
 
+    def __repr__(self) -> str:
+        return f"Constant({self.value})"
+
+    def __str__(self) -> str:
+        return f"{self.value}"
+
+    @property
+    def is_linear(self) -> bool:
+        return True
+
+    @property
+    def is_constant(self) -> bool:
+        return True
+
 
 class Addition(AbstractOperation):
     """
@@ -129,6 +143,10 @@ class Addition(AbstractOperation):
     def evaluate(self, a, b):
         return a + b
 
+    @property
+    def is_linear(self) -> bool:
+        return True
+
 
 class Subtraction(AbstractOperation):
     """
@@ -189,6 +207,10 @@ class Subtraction(AbstractOperation):
     def evaluate(self, a, b):
         return a - b
 
+    @property
+    def is_linear(self) -> bool:
+        return True
+
 
 class AddSub(AbstractOperation):
     r"""
@@ -270,6 +292,10 @@ class AddSub(AbstractOperation):
         """Set if operation is an addition."""
         self.set_param("is_add", is_add)
 
+    @property
+    def is_linear(self) -> bool:
+        return True
+
 
 class Multiplication(AbstractOperation):
     r"""
@@ -331,6 +357,12 @@ class Multiplication(AbstractOperation):
     def evaluate(self, a, b):
         return a * b
 
+    @property
+    def is_linear(self) -> bool:
+        return any(
+            input.connected_source.operation.is_constant for input in self.inputs
+        )
+
 
 class Division(AbstractOperation):
     r"""
@@ -372,6 +404,10 @@ class Division(AbstractOperation):
     def evaluate(self, a, b):
         return a / b
 
+    @property
+    def is_linear(self) -> bool:
+        return self.input(1).connected_source.operation.is_constant
+
 
 class Min(AbstractOperation):
     r"""
@@ -618,6 +654,10 @@ class ConstantMultiplication(AbstractOperation):
         """Set the constant value of this operation."""
         self.set_param("value", value)
 
+    @property
+    def is_linear(self) -> bool:
+        return True
+
 
 class Butterfly(AbstractOperation):
     r"""
@@ -660,6 +700,10 @@ class Butterfly(AbstractOperation):
     def evaluate(self, a, b):
         return a + b, a - b
 
+    @property
+    def is_linear(self) -> bool:
+        return True
+
 
 class MAD(AbstractOperation):
     r"""
@@ -699,6 +743,13 @@ class MAD(AbstractOperation):
     def evaluate(self, a, b, c):
         return a * b + c
 
+    @property
+    def is_linear(self) -> bool:
+        return (
+            self.input(0).connected_source.operation.is_constant
+            or self.input(1).connected_source.operation.is_constant
+        )
+
 
 class SymmetricTwoportAdaptor(AbstractOperation):
     r"""
@@ -751,6 +802,10 @@ class SymmetricTwoportAdaptor(AbstractOperation):
         """Set the constant value of this operation."""
         self.set_param("value", value)
 
+    @property
+    def is_linear(self) -> bool:
+        return True
+
 
 class Reciprocal(AbstractOperation):
     r"""
diff --git a/b_asic/operation.py b/b_asic/operation.py
index 77087d23..bba050c6 100644
--- a/b_asic/operation.py
+++ b/b_asic/operation.py
@@ -467,6 +467,22 @@ class Operation(GraphComponent, SignalSourceProvider):
     def _check_all_latencies_set(self) -> None:
         raise NotImplementedError
 
+    @property
+    @abstractmethod
+    def is_linear(self) -> bool:
+        """
+        Return True if the operation is linear.
+        """
+        raise NotImplementedError
+
+    @property
+    @abstractmethod
+    def is_constant(self) -> bool:
+        """
+        Return True if the output of the operation is constant.
+        """
+        raise NotImplementedError
+
 
 class AbstractOperation(Operation, AbstractGraphComponent):
     """
@@ -1135,3 +1151,15 @@ class AbstractOperation(Operation, AbstractGraphComponent):
             )
             for k in range(num_out)
         )
+
+    @property
+    def is_linear(self) -> bool:
+        if self.is_constant:
+            return True
+        return False
+
+    @property
+    def is_constant(self) -> bool:
+        return all(
+            input.connected_source.operation.is_constant for input in self.inputs
+        )
diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py
index 04a74226..782baa69 100644
--- a/b_asic/signal_flow_graph.py
+++ b/b_asic/signal_flow_graph.py
@@ -1540,3 +1540,11 @@ class SFG(AbstractOperation):
         assert len(ids) == len(set(ids))
 
         return SFG(inputs=all_inputs, outputs=all_outputs)
+
+    @property
+    def is_linear(self) -> bool:
+        return all(op.is_linear for op in self.split())
+
+    @property
+    def is_constant(self) -> bool:
+        return all(output.is_constant for output in self._output_operations)
diff --git a/b_asic/special_operations.py b/b_asic/special_operations.py
index 387e7b3f..6f438299 100644
--- a/b_asic/special_operations.py
+++ b/b_asic/special_operations.py
@@ -89,6 +89,14 @@ class Input(AbstractOperation):
         # doc-string inherited
         return ((0, 0.5),)
 
+    @property
+    def is_constant(self) -> bool:
+        return False
+
+    @property
+    def is_linear(self) -> bool:
+        return True
+
 
 class Output(AbstractOperation):
     """
@@ -143,6 +151,10 @@ class Output(AbstractOperation):
     def latency(self) -> int:
         return self.latency_offsets["in0"]
 
+    @property
+    def is_linear(self) -> bool:
+        return True
+
 
 class Delay(AbstractOperation):
     """
@@ -221,3 +233,7 @@ class Delay(AbstractOperation):
     def initial_value(self, value: Num) -> None:
         """Set the initial value of this delay."""
         self.set_param("initial_value", value)
+
+    @property
+    def is_linear(self) -> bool:
+        return True
diff --git a/test/test_sfg.py b/test/test_sfg.py
index a40a111f..a84a280d 100644
--- a/test/test_sfg.py
+++ b/test/test_sfg.py
@@ -1604,3 +1604,19 @@ class TestUnfold:
         sfg = sfg_two_inputs_two_outputs
         with pytest.raises(ValueError, match="Unfolding 0 times removes the SFG"):
             sfg.unfold(0)
+
+
+class TestIsLinear:
+    def test_single_accumulator(self, sfg_simple_accumulator: SFG):
+        assert sfg_simple_accumulator.is_linear
+
+    def test_sfg_nested(self, sfg_nested: SFG):
+        assert not sfg_nested.is_linear
+
+
+class TestIsConstant:
+    def test_single_accumulator(self, sfg_simple_accumulator: SFG):
+        assert not sfg_simple_accumulator.is_constant
+
+    def test_sfg_nested(self, sfg_nested: SFG):
+        assert not sfg_nested.is_constant
-- 
GitLab