From 6733a62700280112d67c7b793893678eb9b02eba Mon Sep 17 00:00:00 2001
From: Oscar Gustafsson <oscar.gustafsson@gmail.com>
Date: Mon, 30 Jan 2023 13:40:54 +0100
Subject: [PATCH] Fix typing and add mypy config

---
 b_asic/graph_component.py |  3 ++-
 b_asic/operation.py       | 38 +++++++++++++++++++++++++++-----------
 pyproject.toml            |  6 ++++++
 3 files changed, 35 insertions(+), 12 deletions(-)

diff --git a/b_asic/graph_component.py b/b_asic/graph_component.py
index 28deab5a..1e76ce02 100644
--- a/b_asic/graph_component.py
+++ b/b_asic/graph_component.py
@@ -7,7 +7,7 @@ Contains the base for all components with an ID in a signal flow graph.
 from abc import ABC, abstractmethod
 from collections import deque
 from copy import copy, deepcopy
-from typing import Any, Dict, Generator, Iterable, Mapping, NewType
+from typing import Any, Dict, Generator, Iterable, Mapping, NewType, cast
 
 Name = NewType("Name", str)
 TypeName = NewType("TypeName", str)
@@ -176,6 +176,7 @@ class AbstractGraphComponent(GraphComponent):
             component = fontier.popleft()
             yield component
             for neighbor in component.neighbors:
+                neighbor = cast(AbstractGraphComponent, neighbor)
                 if neighbor not in visited:
                     visited.add(neighbor)
                     fontier.append(neighbor)
diff --git a/b_asic/operation.py b/b_asic/operation.py
index 3da6e4ab..aa84a0ed 100644
--- a/b_asic/operation.py
+++ b/b_asic/operation.py
@@ -21,6 +21,7 @@ from typing import (
     Sequence,
     Tuple,
     Union,
+    cast,
 )
 
 from b_asic.graph_component import (
@@ -324,7 +325,7 @@ class Operation(GraphComponent, SignalSourceProvider):
 
     @property
     @abstractmethod
-    def latency_offsets(self) -> Dict[str, int]:
+    def latency_offsets(self) -> Dict[str, Optional[int]]:
         """
         Get a dictionary with all the operations ports latency-offsets.
         """
@@ -569,7 +570,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
 
     def __str__(self) -> str:
         """Get a string representation of this operation."""
-        inputs_dict = {}
+        inputs_dict: Dict[int, Union[List[GraphID], str]] = {}
         for i, inport in enumerate(self.inputs):
             if inport.signal_count == 0:
                 inputs_dict[i] = "-"
@@ -588,7 +589,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
                         dict_ele.append(GraphID("no_id"))
             inputs_dict[i] = dict_ele
 
-        outputs_dict = {}
+        outputs_dict: Dict[int, Union[List[GraphID], str]] = {}
         for i, outport in enumerate(self.outputs):
             if outport.signal_count == 0:
                 outputs_dict[i] = "-"
@@ -778,7 +779,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
                 last_operations = [last_operations]
             outputs = [Output(o) for o in last_operations]
         except TypeError:
-            operation_copy: Operation = self.copy_component()
+            operation_copy: Operation = cast(Operation, self.copy_component())
             inputs = []
             for i in range(self.input_count):
                 _input = Input()
@@ -790,7 +791,9 @@ class AbstractOperation(Operation, AbstractGraphComponent):
         return SFG(inputs=inputs, outputs=outputs)
 
     def copy_component(self, *args, **kwargs) -> GraphComponent:
-        new_component: Operation = super().copy_component(*args, **kwargs)
+        new_component: Operation = cast(
+            Operation, super().copy_component(*args, **kwargs)
+        )
         for i, inp in enumerate(self.inputs):
             new_component.input(i).latency_offset = inp.latency_offset
         for i, outp in enumerate(self.outputs):
@@ -885,13 +888,16 @@ class AbstractOperation(Operation, AbstractGraphComponent):
 
         return max(
             (
-                (outp.latency_offset - inp.latency_offset)
+                (
+                    cast(int, outp.latency_offset)
+                    - cast(int, inp.latency_offset)
+                )
                 for outp, inp in it.product(self.outputs, self.inputs)
             )
         )
 
     @property
-    def latency_offsets(self) -> Dict[str, int]:
+    def latency_offsets(self) -> Dict[str, Optional[int]]:
         latency_offsets = {}
 
         for i, inp in enumerate(self.inputs):
@@ -948,13 +954,15 @@ class AbstractOperation(Operation, AbstractGraphComponent):
         if self._execution_time is not None:
             self._execution_time *= factor
         for port in [*self.inputs, *self.outputs]:
-            port.latency_offset *= factor
+            if port.latency_offset is not None:
+                port.latency_offset *= factor
 
     def _decrease_time_resolution(self, factor: int) -> None:
         if self._execution_time is not None:
             self._execution_time = self._execution_time // factor
         for port in [*self.inputs, *self.outputs]:
-            port.latency_offset = port.latency_offset // factor
+            if port.latency_offset is not None:
+                port.latency_offset = port.latency_offset // factor
 
     def get_plot_coordinates(
         self,
@@ -978,11 +986,18 @@ class AbstractOperation(Operation, AbstractGraphComponent):
             [0, 0],
         ]
 
+    def _check_all_latencies_set(self):
+        if any(val is None for _, val in self.latency_offsets.items()):
+            raise ValueError(
+                f"All latencies must be set: {self.latency_offsets}"
+            )
+
     def _get_plot_coordinates_for_latency(self) -> List[List[float]]:
+        self._check_all_latencies_set()
         # Points for latency polygon
         latency = []
         # Remember starting point
-        start_point = [self.inputs[0].latency_offset, 0]
+        start_point = [self.inputs[0].latency_offset, 0.0]
         num_in = self.input_count
         latency.append(start_point)
         for k in range(1, num_in):
@@ -995,7 +1010,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
         for k in reversed(range(1, num_out)):
             latency.append([self.outputs[k].latency_offset, k / num_out])
             latency.append([self.outputs[k - 1].latency_offset, k / num_out])
-        latency.append([self.outputs[0].latency_offset, 0])
+        latency.append([self.outputs[0].latency_offset, 0.0])
         # Close the polygon
         latency.append(start_point)
 
@@ -1004,6 +1019,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
     def get_io_coordinates(
         self,
     ) -> Tuple[List[List[float]], List[List[float]]]:
+        self._check_all_latencies_set()
         # Doc-string inherited
         input_coords = [
             [
diff --git a/pyproject.toml b/pyproject.toml
index 4f5b7930..32b7026b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -66,3 +66,9 @@ src_paths = ["b_asic", "test"]
 skip = [
     "test/test_gui"
 ]
+
+
+[tool.mypy]
+packages = ["b_asic", "test"]
+no_site_packages = true
+ignore_missing_imports = true
-- 
GitLab