From aac6c47ab27707db9fc0f237839084f3fbab9461 Mon Sep 17 00:00:00 2001
From: Oscar Gustafsson <oscar.gustafsson@gmail.com>
Date: Mon, 20 Feb 2023 12:14:12 +0100
Subject: [PATCH] Add testing for operation errors

---
 b_asic/core_operations.py          | 15 +++++++--------
 b_asic/operation.py                |  6 +++---
 b_asic/special_operations.py       | 24 ++++++++++++------------
 test/fixtures/signal_flow_graph.py |  8 ++++----
 test/test_operation.py             | 13 ++++++++++++-
 test/test_schedule.py              | 24 +++++++++++-------------
 6 files changed, 49 insertions(+), 41 deletions(-)

diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py
index d67b0242..161b979c 100644
--- a/b_asic/core_operations.py
+++ b/b_asic/core_operations.py
@@ -62,6 +62,10 @@ class Constant(AbstractOperation):
         """Set the constant value of this operation."""
         self.set_param("value", value)
 
+    @property
+    def latency(self) -> int:
+        return self.latency_offsets["out0"]
+
 
 class Addition(AbstractOperation):
     """
@@ -410,9 +414,7 @@ class Min(AbstractOperation):
 
     def evaluate(self, a, b):
         if isinstance(a, complex) or isinstance(b, complex):
-            raise ValueError(
-                "core_operations.Min does not support complex numbers."
-            )
+            raise ValueError("core_operations.Min does not support complex numbers.")
         return a if a < b else b
 
 
@@ -457,9 +459,7 @@ class Max(AbstractOperation):
 
     def evaluate(self, a, b):
         if isinstance(a, complex) or isinstance(b, complex):
-            raise ValueError(
-                "core_operations.Max does not support complex numbers."
-            )
+            raise ValueError("core_operations.Max does not support complex numbers.")
         return a if a > b else b
 
 
@@ -589,8 +589,7 @@ class ConstantMultiplication(AbstractOperation):
         latency_offsets: Optional[Dict[str, int]] = None,
         execution_time: Optional[int] = None,
     ):
-        """Construct a ConstantMultiplication operation with the given value.
-        """
+        """Construct a ConstantMultiplication operation with the given value."""
         super().__init__(
             input_count=1,
             output_count=1,
diff --git a/b_asic/operation.py b/b_asic/operation.py
index 4ad5b7f9..77087d23 100644
--- a/b_asic/operation.py
+++ b/b_asic/operation.py
@@ -988,7 +988,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
 
         if any(val is None for val in latency_offsets):
             raise ValueError(
-                "Missing latencies for inputs"
+                "Missing latencies for input(s)"
                 f" {[i for (i, latency) in enumerate(latency_offsets) if latency is None]}"
             )
 
@@ -999,8 +999,8 @@ class AbstractOperation(Operation, AbstractGraphComponent):
 
         if any(val is None for val in latency_offsets):
             raise ValueError(
-                "Missing latencies for outputs"
-                f" {[i for i in latency_offsets if i is not None]}"
+                "Missing latencies for output(s)"
+                f" {[i for (i, latency) in enumerate(latency_offsets) if latency is None]}"
             )
 
         return cast(List[int], latency_offsets)
diff --git a/b_asic/special_operations.py b/b_asic/special_operations.py
index 35b78efa..387e7b3f 100644
--- a/b_asic/special_operations.py
+++ b/b_asic/special_operations.py
@@ -44,6 +44,10 @@ class Input(AbstractOperation):
     def evaluate(self):
         return self.param("value")
 
+    @property
+    def latency(self) -> int:
+        return self.latency_offsets["out0"]
+
     @property
     def value(self) -> Num:
         """Get the current value of this input."""
@@ -56,9 +60,7 @@ class Input(AbstractOperation):
 
     def get_plot_coordinates(
         self,
-    ) -> Tuple[
-        Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]
-    ]:
+    ) -> Tuple[Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]]:
         # Doc-string inherited
         return (
             (
@@ -122,9 +124,7 @@ class Output(AbstractOperation):
 
     def get_plot_coordinates(
         self,
-    ) -> Tuple[
-        Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]
-    ]:
+    ) -> Tuple[Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]]:
         # Doc-string inherited
         return (
             ((0, 0), (0, 1), (0.25, 1), (0.5, 0.5), (0.25, 0), (0, 0)),
@@ -139,6 +139,10 @@ class Output(AbstractOperation):
         # doc-string inherited
         return tuple()
 
+    @property
+    def latency(self) -> int:
+        return self.latency_offsets["in0"]
+
 
 class Delay(AbstractOperation):
     """
@@ -174,9 +178,7 @@ class Delay(AbstractOperation):
         self, index: int, delays: Optional[DelayMap] = None, prefix: str = ""
     ) -> Optional[Num]:
         if delays is not None:
-            return delays.get(
-                self.key(index, prefix), self.param("initial_value")
-            )
+            return delays.get(self.key(index, prefix), self.param("initial_value"))
         return self.param("initial_value")
 
     def evaluate_output(
@@ -190,9 +192,7 @@ class Delay(AbstractOperation):
         truncate: bool = True,
     ) -> Num:
         if index != 0:
-            raise IndexError(
-                f"Output index out of range (expected 0-0, got {index})"
-            )
+            raise IndexError(f"Output index out of range (expected 0-0, got {index})")
         if len(input_values) != 1:
             raise ValueError(
                 "Wrong number of inputs supplied to SFG for evaluation"
diff --git a/test/fixtures/signal_flow_graph.py b/test/fixtures/signal_flow_graph.py
index 321e7523..c8da0eeb 100644
--- a/test/fixtures/signal_flow_graph.py
+++ b/test/fixtures/signal_flow_graph.py
@@ -94,10 +94,10 @@ def sfg_two_inputs_two_outputs_independent_with_cmul():
     in1 = Input("IN1")
     in2 = Input("IN2")
     c1 = Constant(3, "C1")
-    add1 = Addition(in2, c1, "ADD1", 7)
-    cmul3 = ConstantMultiplication(2, add1, "CMUL3", 3)
-    cmul1 = ConstantMultiplication(5, in1, "CMUL1", 5)
-    cmul2 = ConstantMultiplication(4, cmul1, "CMUL2", 4)
+    add1 = Addition(in2, c1, "ADD1", 7, execution_time=2)
+    cmul3 = ConstantMultiplication(2, add1, "CMUL3", 3, execution_time=1)
+    cmul1 = ConstantMultiplication(5, in1, "CMUL1", 5, execution_time=3)
+    cmul2 = ConstantMultiplication(4, cmul1, "CMUL2", 4, execution_time=1)
     out1 = Output(cmul2, "OUT1")
     out2 = Output(cmul3, "OUT2")
     return SFG(inputs=[in1, in2], outputs=[out1, out2])
diff --git a/test/test_operation.py b/test/test_operation.py
index 0ce438b6..b08e3907 100644
--- a/test/test_operation.py
+++ b/test/test_operation.py
@@ -204,6 +204,10 @@ class TestLatency:
             "out1": 9,
         }
 
+    def test_set_latency_negative(self):
+        with pytest.raises(ValueError, match="Latency cannot be negative"):
+            Butterfly(latency=-1)
+
 
 class TestExecutionTime:
     def test_execution_time_constructor(self):
@@ -292,9 +296,16 @@ class TestIOCoordinates:
         bfly = Butterfly()
 
         bfly.set_latency_offsets({"in0": 3, "out1": 5})
-        with pytest.raises(ValueError, match="Missing latencies for inputs \\[1\\]"):
+        with pytest.raises(
+            ValueError, match="Missing latencies for input\\(s\\) \\[1\\]"
+        ):
             bfly.get_input_coordinates()
 
+        with pytest.raises(
+            ValueError, match="Missing latencies for output\\(s\\) \\[0\\]"
+        ):
+            bfly.get_output_coordinates()
+
 
 class TestSplit:
     def test_simple_case(self):
diff --git a/test/test_schedule.py b/test/test_schedule.py
index dcc04c80..0f2c412c 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -410,21 +410,19 @@ class TestTimeResolution:
 
         start_times_names = {}
         for op_id, start_time in schedule._start_times.items():
-            op_name = sfg_two_inputs_two_outputs_independent_with_cmul.find_by_id(
-                op_id
-            ).name
-            start_times_names[op_name] = start_time
+            op = sfg_two_inputs_two_outputs_independent_with_cmul.find_by_id(op_id)
+            start_times_names[op.name] = (start_time, op.latency, op.execution_time)
 
         assert start_times_names == {
-            "C1": 0,
-            "IN1": 0,
-            "IN2": 0,
-            "CMUL1": 0,
-            "CMUL2": 30,
-            "ADD1": 0,
-            "CMUL3": 42,
-            "OUT1": 54,
-            "OUT2": 60,
+            "C1": (0, 0, None),
+            "IN1": (0, 0, None),
+            "IN2": (0, 0, None),
+            "CMUL1": (0, 30, 18),
+            "CMUL2": (30, 24, 6),
+            "ADD1": (0, 42, 12),
+            "CMUL3": (42, 18, 6),
+            "OUT1": (54, 0, None),
+            "OUT2": (60, 0, None),
         }
 
         assert 6 * old_schedule_time == schedule.schedule_time
-- 
GitLab