From 45c7bb7836384c73b13da46998cf61b6c3568470 Mon Sep 17 00:00:00 2001
From: Frans Skarman <frans.skarman@liu.se>
Date: Tue, 14 Feb 2023 10:43:12 +0000
Subject: [PATCH] Fix type errors

---
 b_asic/core_operations.py    |  24 ++--
 b_asic/graph_component.py    |   7 +-
 b_asic/operation.py          | 212 +++++++++++++++++++----------------
 b_asic/port.py               |   2 +-
 b_asic/special_operations.py |  22 ++--
 b_asic/types.py              |  22 ++++
 test/test_operation.py       |   4 +-
 7 files changed, 169 insertions(+), 124 deletions(-)
 create mode 100644 b_asic/types.py

diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py
index b4c9642b..aa9e33d4 100644
--- a/b_asic/core_operations.py
+++ b/b_asic/core_operations.py
@@ -4,7 +4,6 @@ B-ASIC Core Operations Module.
 Contains some of the most commonly used mathematical operations.
 """
 
-from numbers import Number
 from typing import Dict, Optional
 
 from numpy import abs as np_abs
@@ -13,6 +12,7 @@ from numpy import conjugate, sqrt
 from b_asic.graph_component import Name, TypeName
 from b_asic.operation import AbstractOperation
 from b_asic.port import SignalSourceProvider
+from b_asic.types import Num
 
 
 class Constant(AbstractOperation):
@@ -35,12 +35,12 @@ class Constant(AbstractOperation):
 
     _execution_time = 0
 
-    def __init__(self, value: Number = 0, name: Name = Name("")):
+    def __init__(self, value: Num = 0, name: Name = ""):
         """Construct a Constant operation with the given value."""
         super().__init__(
             input_count=0,
             output_count=1,
-            name=Name(name),
+            name=name,
             latency_offsets={"out0": 0},
         )
         self.set_param("value", value)
@@ -53,12 +53,12 @@ class Constant(AbstractOperation):
         return self.param("value")
 
     @property
-    def value(self) -> Number:
+    def value(self) -> Num:
         """Get the constant value of this operation."""
         return self.param("value")
 
     @value.setter
-    def value(self, value: Number) -> None:
+    def value(self, value: Num) -> None:
         """Set the constant value of this operation."""
         self.set_param("value", value)
 
@@ -257,7 +257,7 @@ class AddSub(AbstractOperation):
         return a + b if self.is_add else a - b
 
     @property
-    def is_add(self) -> Number:
+    def is_add(self) -> Num:
         """Get if operation is add."""
         return self.param("is_add")
 
@@ -582,7 +582,7 @@ class ConstantMultiplication(AbstractOperation):
 
     def __init__(
         self,
-        value: Number = 0,
+        value: Num = 0,
         src0: Optional[SignalSourceProvider] = None,
         name: Name = Name(""),
         latency: Optional[int] = None,
@@ -610,12 +610,12 @@ class ConstantMultiplication(AbstractOperation):
         return a * self.param("value")
 
     @property
-    def value(self) -> Number:
+    def value(self) -> Num:
         """Get the constant value of this operation."""
         return self.param("value")
 
     @value.setter
-    def value(self, value: Number) -> None:
+    def value(self, value: Num) -> None:
         """Set the constant value of this operation."""
         self.set_param("value", value)
 
@@ -714,7 +714,7 @@ class SymmetricTwoportAdaptor(AbstractOperation):
 
     def __init__(
         self,
-        value: Number = 0,
+        value: Num = 0,
         src0: Optional[SignalSourceProvider] = None,
         src1: Optional[SignalSourceProvider] = None,
         name: Name = Name(""),
@@ -743,12 +743,12 @@ class SymmetricTwoportAdaptor(AbstractOperation):
         return b + tmp, a + tmp
 
     @property
-    def value(self) -> Number:
+    def value(self) -> Num:
         """Get the constant value of this operation."""
         return self.param("value")
 
     @value.setter
-    def value(self, value: Number) -> None:
+    def value(self, value: Num) -> None:
         """Set the constant value of this operation."""
         self.set_param("value", value)
 
diff --git a/b_asic/graph_component.py b/b_asic/graph_component.py
index 86018c8d..1f910c30 100644
--- a/b_asic/graph_component.py
+++ b/b_asic/graph_component.py
@@ -7,12 +7,9 @@ Contains the base for all components with an ID in a signal flow graph.
 from abc import ABC, abstractmethod
 from collections import deque
 from copy import copy, deepcopy
-from typing import Any, Dict, Generator, Iterable, Mapping, NewType, cast
+from typing import Any, Dict, Generator, Iterable, Mapping, cast
 
-Name = NewType("Name", str)
-TypeName = NewType("TypeName", str)
-GraphID = NewType("GraphID", str)
-GraphIDNumber = NewType("GraphIDNumber", int)
+from b_asic.types import GraphID, GraphIDNumber, Name, Num, TypeName
 
 
 class GraphComponent(ABC):
diff --git a/b_asic/operation.py b/b_asic/operation.py
index e39e9d69..8b8251dc 100644
--- a/b_asic/operation.py
+++ b/b_asic/operation.py
@@ -5,6 +5,7 @@ Contains the base for operations that are used by B-ASIC.
 """
 
 import collections
+import collections.abc
 import itertools as it
 from abc import abstractmethod
 from numbers import Number
@@ -22,6 +23,7 @@ from typing import (
     Tuple,
     Union,
     cast,
+    overload,
 )
 
 from b_asic.graph_component import (
@@ -32,6 +34,7 @@ from b_asic.graph_component import (
 )
 from b_asic.port import InputPort, OutputPort, SignalSourceProvider
 from b_asic.signal import Signal
+from b_asic.types import Num, NumRuntime
 
 if TYPE_CHECKING:
     # Conditionally imported to avoid circular imports
@@ -47,10 +50,10 @@ if TYPE_CHECKING:
 
 
 ResultKey = NewType("ResultKey", str)
-ResultMap = Mapping[ResultKey, Optional[Number]]
-MutableResultMap = MutableMapping[ResultKey, Optional[Number]]
-DelayMap = Mapping[ResultKey, Number]
-MutableDelayMap = MutableMapping[ResultKey, Number]
+ResultMap = Mapping[ResultKey, Optional[Num]]
+MutableResultMap = MutableMapping[ResultKey, Optional[Num]]
+DelayMap = Mapping[ResultKey, Num]
+MutableDelayMap = MutableMapping[ResultKey, Num]
 
 
 class Operation(GraphComponent, SignalSourceProvider):
@@ -66,7 +69,7 @@ class Operation(GraphComponent, SignalSourceProvider):
     """
 
     @abstractmethod
-    def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Addition":
+    def __add__(self, src: Union[SignalSourceProvider, Num]) -> "Addition":
         """
         Overloads the addition operator to make it return a new Addition operation
         object that is connected to the self and other objects.
@@ -74,7 +77,7 @@ class Operation(GraphComponent, SignalSourceProvider):
         raise NotImplementedError
 
     @abstractmethod
-    def __radd__(self, src: Union[SignalSourceProvider, Number]) -> "Addition":
+    def __radd__(self, src: Union[SignalSourceProvider, Num]) -> "Addition":
         """
         Overloads the addition operator to make it return a new Addition operation
         object that is connected to the self and other objects.
@@ -82,9 +85,7 @@ class Operation(GraphComponent, SignalSourceProvider):
         raise NotImplementedError
 
     @abstractmethod
-    def __sub__(
-        self, src: Union[SignalSourceProvider, Number]
-    ) -> "Subtraction":
+    def __sub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction":
         """
         Overloads the subtraction operator to make it return a new Subtraction
         operation object that is connected to the self and other objects.
@@ -92,9 +93,7 @@ class Operation(GraphComponent, SignalSourceProvider):
         raise NotImplementedError
 
     @abstractmethod
-    def __rsub__(
-        self, src: Union[SignalSourceProvider, Number]
-    ) -> "Subtraction":
+    def __rsub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction":
         """
         Overloads the subtraction operator to make it return a new Subtraction
         operation object that is connected to the self and other objects.
@@ -103,7 +102,7 @@ class Operation(GraphComponent, SignalSourceProvider):
 
     @abstractmethod
     def __mul__(
-        self, src: Union[SignalSourceProvider, Number]
+        self, src: Union[SignalSourceProvider, Num]
     ) -> Union["Multiplication", "ConstantMultiplication"]:
         """
         Overloads the multiplication operator to make it return a new Multiplication
@@ -116,7 +115,7 @@ class Operation(GraphComponent, SignalSourceProvider):
 
     @abstractmethod
     def __rmul__(
-        self, src: Union[SignalSourceProvider, Number]
+        self, src: Union[SignalSourceProvider, Num]
     ) -> Union["Multiplication", "ConstantMultiplication"]:
         """
         Overloads the multiplication operator to make it return a new Multiplication
@@ -128,9 +127,7 @@ class Operation(GraphComponent, SignalSourceProvider):
         raise NotImplementedError
 
     @abstractmethod
-    def __truediv__(
-        self, src: Union[SignalSourceProvider, Number]
-    ) -> "Division":
+    def __truediv__(self, src: Union[SignalSourceProvider, Num]) -> "Division":
         """
         Overloads the division operator to make it return a new Division operation
         object that is connected to the self and other objects.
@@ -139,7 +136,7 @@ class Operation(GraphComponent, SignalSourceProvider):
 
     @abstractmethod
     def __rtruediv__(
-        self, src: Union[SignalSourceProvider, Number]
+        self, src: Union[SignalSourceProvider, Num]
     ) -> Union["Division", "Reciprocal"]:
         """
         Overloads the division operator to make it return a new Division operation
@@ -219,7 +216,7 @@ class Operation(GraphComponent, SignalSourceProvider):
     @abstractmethod
     def current_output(
         self, index: int, delays: Optional[DelayMap] = None, prefix: str = ""
-    ) -> Optional[Number]:
+    ) -> Optional[Num]:
         """
         Get the current output at the given index of this operation, if available.
 
@@ -238,13 +235,13 @@ class Operation(GraphComponent, SignalSourceProvider):
     def evaluate_output(
         self,
         index: int,
-        input_values: Sequence[Number],
+        input_values: Sequence[Num],
         results: Optional[MutableResultMap] = None,
         delays: Optional[MutableDelayMap] = None,
         prefix: str = "",
         bits_override: Optional[int] = None,
         truncate: bool = True,
-    ) -> Number:
+    ) -> Num:
         """
         Evaluate the output at the given index of this operation with the given input
         values.
@@ -281,7 +278,7 @@ class Operation(GraphComponent, SignalSourceProvider):
     @abstractmethod
     def current_outputs(
         self, delays: Optional[DelayMap] = None, prefix: str = ""
-    ) -> Sequence[Optional[Number]]:
+    ) -> Sequence[Optional[Num]]:
         """
         Get all current outputs of this operation, if available.
 
@@ -294,13 +291,13 @@ class Operation(GraphComponent, SignalSourceProvider):
     @abstractmethod
     def evaluate_outputs(
         self,
-        input_values: Sequence[Number],
+        input_values: Sequence[Num],
         results: Optional[MutableResultMap] = None,
         delays: Optional[MutableDelayMap] = None,
         prefix: str = "",
         bits_override: Optional[int] = None,
         truncate: bool = True,
-    ) -> Sequence[Number]:
+    ) -> Sequence[Num]:
         """
         Evaluate all outputs of this operation given the input values.
         See evaluate_output for more information.
@@ -336,7 +333,7 @@ class Operation(GraphComponent, SignalSourceProvider):
         raise NotImplementedError
 
     @abstractmethod
-    def truncate_input(self, index: int, value: Number, bits: int) -> Number:
+    def truncate_input(self, index: int, value: Num, bits: int) -> Num:
         """
         Truncate the value to be used as input at the given index to a certain bit
         length.
@@ -397,7 +394,7 @@ class Operation(GraphComponent, SignalSourceProvider):
 
     @execution_time.setter
     @abstractmethod
-    def execution_time(self, latency: int) -> None:
+    def execution_time(self, latency: Optional[int]) -> None:
         """
         Sets the execution time of the operation to the specified integer
         value. The execution time cannot be a negative integer.
@@ -570,52 +567,63 @@ class AbstractOperation(Operation, AbstractGraphComponent):
 
         self._execution_time = execution_time
 
+    @overload
     @abstractmethod
-    def evaluate(self, *inputs) -> Any:  # pylint: disable=arguments-differ
+    def evaluate(
+        self, *inputs: Operation
+    ) -> List[Operation]:  # pylint: disable=arguments-differ
+        ...
+
+    @overload
+    @abstractmethod
+    def evaluate(
+        self, *inputs: Num
+    ) -> List[Num]:  # pylint: disable=arguments-differ
+        ...
+
+    @abstractmethod
+    def evaluate(self, *inputs):  # pylint: disable=arguments-differ
         """
         Evaluate the operation and generate a list of output values given a
         list of input values.
         """
         raise NotImplementedError
 
-    def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Addition":
+    def __add__(self, src: Union[SignalSourceProvider, Num]) -> "Addition":
         # Import here to avoid circular imports.
         from b_asic.core_operations import Addition, Constant
 
-        return Addition(
-            self, Constant(src) if isinstance(src, Number) else src
-        )
+        if isinstance(src, NumRuntime):
+            return Addition(self, Constant(src))
+        else:
+            return Addition(self, src)
 
-    def __radd__(self, src: Union[SignalSourceProvider, Number]) -> "Addition":
+    def __radd__(self, src: Union[SignalSourceProvider, Num]) -> "Addition":
         # Import here to avoid circular imports.
         from b_asic.core_operations import Addition, Constant
 
         return Addition(
-            Constant(src) if isinstance(src, Number) else src, self
+            Constant(src) if isinstance(src, NumRuntime) else src, self
         )
 
-    def __sub__(
-        self, src: Union[SignalSourceProvider, Number]
-    ) -> "Subtraction":
+    def __sub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction":
         # Import here to avoid circular imports.
         from b_asic.core_operations import Constant, Subtraction
 
         return Subtraction(
-            self, Constant(src) if isinstance(src, Number) else src
+            self, Constant(src) if isinstance(src, NumRuntime) else src
         )
 
-    def __rsub__(
-        self, src: Union[SignalSourceProvider, Number]
-    ) -> "Subtraction":
+    def __rsub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction":
         # Import here to avoid circular imports.
         from b_asic.core_operations import Constant, Subtraction
 
         return Subtraction(
-            Constant(src) if isinstance(src, Number) else src, self
+            Constant(src) if isinstance(src, NumRuntime) else src, self
         )
 
     def __mul__(
-        self, src: Union[SignalSourceProvider, Number]
+        self, src: Union[SignalSourceProvider, Num]
     ) -> Union["Multiplication", "ConstantMultiplication"]:
         # Import here to avoid circular imports.
         from b_asic.core_operations import (
@@ -625,12 +633,12 @@ class AbstractOperation(Operation, AbstractGraphComponent):
 
         return (
             ConstantMultiplication(src, self)
-            if isinstance(src, Number)
+            if isinstance(src, NumRuntime)
             else Multiplication(self, src)
         )
 
     def __rmul__(
-        self, src: Union[SignalSourceProvider, Number]
+        self, src: Union[SignalSourceProvider, Num]
     ) -> Union["Multiplication", "ConstantMultiplication"]:
         # Import here to avoid circular imports.
         from b_asic.core_operations import (
@@ -640,27 +648,25 @@ class AbstractOperation(Operation, AbstractGraphComponent):
 
         return (
             ConstantMultiplication(src, self)
-            if isinstance(src, Number)
+            if isinstance(src, NumRuntime)
             else Multiplication(src, self)
         )
 
-    def __truediv__(
-        self, src: Union[SignalSourceProvider, Number]
-    ) -> "Division":
+    def __truediv__(self, src: Union[SignalSourceProvider, Num]) -> "Division":
         # Import here to avoid circular imports.
         from b_asic.core_operations import Constant, Division
 
         return Division(
-            self, Constant(src) if isinstance(src, Number) else src
+            self, Constant(src) if isinstance(src, NumRuntime) else src
         )
 
     def __rtruediv__(
-        self, src: Union[SignalSourceProvider, Number]
+        self, src: Union[SignalSourceProvider, Num]
     ) -> Union["Division", "Reciprocal"]:
         # Import here to avoid circular imports.
         from b_asic.core_operations import Constant, Division, Reciprocal
 
-        if isinstance(src, Number):
+        if isinstance(src, NumRuntime):
             if src == 1:
                 return Reciprocal(self)
             else:
@@ -771,19 +777,19 @@ class AbstractOperation(Operation, AbstractGraphComponent):
 
     def current_output(
         self, index: int, delays: Optional[DelayMap] = None, prefix: str = ""
-    ) -> Optional[Number]:
+    ) -> Optional[Num]:
         return None
 
     def evaluate_output(
         self,
         index: int,
-        input_values: Sequence[Number],
+        input_values: Sequence[Num],
         results: Optional[MutableResultMap] = None,
         delays: Optional[MutableDelayMap] = None,
         prefix: str = "",
         bits_override: Optional[int] = None,
         truncate: bool = True,
-    ) -> Number:
+    ) -> Num:
         if index < 0 or index >= self.output_count:
             raise IndexError(
                 "Output index out of range (expected"
@@ -828,7 +834,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
 
     def current_outputs(
         self, delays: Optional[DelayMap] = None, prefix: str = ""
-    ) -> Sequence[Optional[Number]]:
+    ) -> Sequence[Optional[Num]]:
         return [
             self.current_output(i, delays, prefix)
             for i in range(self.output_count)
@@ -836,13 +842,13 @@ class AbstractOperation(Operation, AbstractGraphComponent):
 
     def evaluate_outputs(
         self,
-        input_values: Sequence[Number],
+        input_values: Sequence[Num],
         results: Optional[MutableResultMap] = None,
         delays: Optional[MutableDelayMap] = None,
         prefix: str = "",
         bits_override: Optional[int] = None,
         truncate: bool = True,
-    ) -> Sequence[Number]:
+    ) -> Sequence[Num]:
         return [
             self.evaluate_output(
                 i,
@@ -860,18 +866,11 @@ class AbstractOperation(Operation, AbstractGraphComponent):
         # Import here to avoid circular imports.
         from b_asic.special_operations import Input
 
-        try:
-            result = self.evaluate(*([Input()] * self.input_count))
-            if isinstance(result, collections.abc.Sequence) and all(
-                isinstance(e, Operation) for e in result
-            ):
-                return result
-            if isinstance(result, Operation):
-                return [result]
-        except TypeError:
-            pass
-        except ValueError:
-            pass
+        result = self.evaluate(*([Input()] * self.input_count))
+        if isinstance(result, collections.abc.Sequence) and all(
+            isinstance(e, Operation) for e in result
+        ):
+            return cast(List[Operation], result)
         return [self]
 
     def to_sfg(self) -> "SFG":
@@ -966,14 +965,19 @@ class AbstractOperation(Operation, AbstractGraphComponent):
             )
         return self.input(0)
 
-    def truncate_input(self, index: int, value: Number, bits: int) -> Number:
-        return int(value) & ((2**bits) - 1)
+    # TODO: Fix
+    def truncate_input(self, index: int, value: Num, bits: int) -> Num:
+        if isinstance(value, (float, int)):
+            return round(value) & ((2**bits) - 1)
+        else:
+            raise TypeError
 
+    # TODO: Seems wrong??? - Oscar
     def truncate_inputs(
         self,
-        input_values: Sequence[Number],
+        input_values: Sequence[Num],
         bits_override: Optional[int] = None,
-    ) -> Sequence[Number]:
+    ) -> Sequence[Num]:
         """
         Truncate the values to be used as inputs to the bit lengths specified
         by the respective signals connected to each input.
@@ -981,7 +985,6 @@ class AbstractOperation(Operation, AbstractGraphComponent):
         args = []
         for i, input_port in enumerate(self.inputs):
             value = input_values[i]
-            bits = bits_override
             if bits_override is None and input_port.signal_count >= 1:
                 bits = input_port.signals[0].bits
             if bits_override is not None:
@@ -990,7 +993,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
                         "Complex value cannot be truncated to {bits} bits as"
                         " requested by the signal connected to input #{i}"
                     )
-                value = self.truncate_input(i, value, bits)
+                value = self.truncate_input(i, value, bits_override)
             args.append(value)
         return args
 
@@ -1026,6 +1029,34 @@ class AbstractOperation(Operation, AbstractGraphComponent):
 
         return latency_offsets
 
+    def _check_all_latencies_set(self) -> None:
+        """Raises an exception of an input or output does not have its latency offset set
+        """
+        self.input_latency_offsets()
+        self.output_latency_offsets()
+
+    def input_latency_offsets(self) -> List[int]:
+        latency_offsets = [i.latency_offset for i in self.inputs]
+
+        if any(val is None for val in latency_offsets):
+            raise ValueError(
+                "Missing latencies for inputs"
+                f" {[i for (i, latency) in enumerate(latency_offsets) if latency is None]}"
+            )
+
+        return cast(List[int], latency_offsets)
+
+    def output_latency_offsets(self) -> List[int]:
+        latency_offsets = [i.latency_offset for i in self.outputs]
+
+        if any(val is None for val in latency_offsets):
+            raise ValueError(
+                "Missing latencies for outputs"
+                f" {[i for i in latency_offsets if i is not None]}"
+            )
+
+        return cast(List[int], latency_offsets)
+
     def set_latency(self, latency: int) -> None:
         if latency < 0:
             raise ValueError("Latency cannot be negative")
@@ -1110,33 +1141,28 @@ class AbstractOperation(Operation, AbstractGraphComponent):
             (0, 0),
         )
 
-    def _check_all_latencies_set(self):
-        if any(val is None for val in self.latency_offsets.values()):
-            raise ValueError(
-                f"All latencies must be set: {self.latency_offsets}"
-            )
-
     def _get_plot_coordinates_for_latency(
         self,
     ) -> Tuple[Tuple[float, float], ...]:
-        self._check_all_latencies_set()
         # Points for latency polygon
         latency = []
+        input_latencies = self.input_latency_offsets()
+        output_latencies = self.output_latency_offsets()
         # Remember starting point
-        start_point = (self.inputs[0].latency_offset, 0.0)
+        start_point = (input_latencies[0], 0.0)
         num_in = self.input_count
         latency.append(start_point)
         for k in range(1, num_in):
-            latency.append((self.inputs[k - 1].latency_offset, k / num_in))
-            latency.append((self.inputs[k].latency_offset, k / num_in))
-        latency.append((self.inputs[num_in - 1].latency_offset, 1))
+            latency.append((input_latencies[k - 1], k / num_in))
+            latency.append((input_latencies[k], k / num_in))
+        latency.append((input_latencies[num_in - 1], 1))
 
         num_out = self.output_count
-        latency.append((self.outputs[num_out - 1].latency_offset, 1))
+        latency.append((output_latencies[num_out - 1], 1))
         for k in reversed(range(1, num_out)):
-            latency.append((self.outputs[k].latency_offset, k / num_out))
-            latency.append((self.outputs[k - 1].latency_offset, k / num_out))
-        latency.append((self.outputs[0].latency_offset, 0.0))
+            latency.append((output_latencies[k], k / num_out))
+            latency.append((output_latencies[k - 1], k / num_out))
+        latency.append((output_latencies[0], 0.0))
         # Close the polygon
         latency.append(start_point)
 
@@ -1144,10 +1170,9 @@ class AbstractOperation(Operation, AbstractGraphComponent):
 
     def get_input_coordinates(self) -> Tuple[Tuple[float, float], ...]:
         # doc-string inherited
-        self._check_all_latencies_set()
         return tuple(
             (
-                self.inputs[k].latency_offset,
+                self.input_latency_offsets()[k],
                 (1 + 2 * k) / (2 * len(self.inputs)),
             )
             for k in range(len(self.inputs))
@@ -1155,10 +1180,9 @@ class AbstractOperation(Operation, AbstractGraphComponent):
 
     def get_output_coordinates(self) -> Tuple[Tuple[float, float], ...]:
         # doc-string inherited
-        self._check_all_latencies_set()
         return tuple(
             (
-                self.outputs[k].latency_offset,
+                self.output_latency_offsets()[k],
                 (1 + 2 * k) / (2 * len(self.outputs)),
             )
             for k in range(len(self.outputs))
diff --git a/b_asic/port.py b/b_asic/port.py
index d823a1bb..e1d35e3e 100644
--- a/b_asic/port.py
+++ b/b_asic/port.py
@@ -131,7 +131,7 @@ class AbstractPort(Port):
         return self._latency_offset
 
     @latency_offset.setter
-    def latency_offset(self, latency_offset: int):
+    def latency_offset(self, latency_offset: Optional[int]):
         self._latency_offset = latency_offset
 
 
diff --git a/b_asic/special_operations.py b/b_asic/special_operations.py
index c0609dd5..16edaaea 100644
--- a/b_asic/special_operations.py
+++ b/b_asic/special_operations.py
@@ -5,7 +5,6 @@ Contains operations with special purposes that may be treated differently from
 normal operations in an SFG.
 """
 
-from numbers import Number
 from typing import List, Optional, Sequence, Tuple
 
 from b_asic.graph_component import Name, TypeName
@@ -16,6 +15,7 @@ from b_asic.operation import (
     MutableResultMap,
 )
 from b_asic.port import SignalSourceProvider
+from b_asic.types import Name, Num, TypeName
 
 
 class Input(AbstractOperation):
@@ -28,12 +28,12 @@ class Input(AbstractOperation):
 
     _execution_time = 0
 
-    def __init__(self, name: Name = Name("")):
+    def __init__(self, name: Name = ""):
         """Construct an Input operation."""
         super().__init__(
             input_count=0,
             output_count=1,
-            name=Name(name),
+            name=name,
             latency_offsets={"out0": 0},
         )
         self.set_param("value", 0)
@@ -46,12 +46,12 @@ class Input(AbstractOperation):
         return self.param("value")
 
     @property
-    def value(self) -> Number:
+    def value(self) -> Num:
         """Get the current value of this input."""
         return self.param("value")
 
     @value.setter
-    def value(self, value: Number) -> None:
+    def value(self, value: Num) -> None:
         """Set the current value of this input."""
         self.set_param("value", value)
 
@@ -152,7 +152,7 @@ class Delay(AbstractOperation):
     def __init__(
         self,
         src0: Optional[SignalSourceProvider] = None,
-        initial_value: Number = 0,
+        initial_value: Num = 0,
         name: Name = Name(""),
     ):
         """Construct a Delay operation."""
@@ -173,7 +173,7 @@ class Delay(AbstractOperation):
 
     def current_output(
         self, index: int, delays: Optional[DelayMap] = None, prefix: str = ""
-    ) -> Optional[Number]:
+    ) -> Optional[Num]:
         if delays is not None:
             return delays.get(
                 self.key(index, prefix), self.param("initial_value")
@@ -183,13 +183,13 @@ class Delay(AbstractOperation):
     def evaluate_output(
         self,
         index: int,
-        input_values: Sequence[Number],
+        input_values: Sequence[Num],
         results: Optional[MutableResultMap] = None,
         delays: Optional[MutableDelayMap] = None,
         prefix: str = "",
         bits_override: Optional[int] = None,
         truncate: bool = True,
-    ) -> Number:
+    ) -> Num:
         if index != 0:
             raise IndexError(
                 f"Output index out of range (expected 0-0, got {index})"
@@ -214,11 +214,11 @@ class Delay(AbstractOperation):
         return value
 
     @property
-    def initial_value(self) -> Number:
+    def initial_value(self) -> Num:
         """Get the initial value of this delay."""
         return self.param("initial_value")
 
     @initial_value.setter
-    def initial_value(self, value: Number) -> None:
+    def initial_value(self, value: Num) -> None:
         """Set the initial value of this delay."""
         self.set_param("initial_value", value)
diff --git a/b_asic/types.py b/b_asic/types.py
new file mode 100644
index 00000000..1f37cebb
--- /dev/null
+++ b/b_asic/types.py
@@ -0,0 +1,22 @@
+from typing import NewType, Union
+
+# https://stackoverflow.com/questions/69334475/how-to-hint-at-number-types-i-e-subclasses-of-number-not-numbers-themselv
+Num = Union[int, float, complex]
+
+NumRuntime = (complex, float, int)
+
+
+Name = str
+# # We want to be able to initialize Name with String literals, but still have the
+# # benefit of static type checking that we don't pass non-names to name locations.
+# # However, until python 3.11 a string literal type was not available. In those versions,
+# # we'll fall back on just aliasing `str` => Name.
+# if sys.version_info >= (3, 11):
+#     from typing import LiteralString
+#     Name: TypeAlias = NewType("Name", str) | LiteralString
+# else:
+#     Name = str
+
+TypeName = NewType("TypeName", str)
+GraphID = NewType("GraphID", str)
+GraphIDNumber = NewType("GraphIDNumber", int)
diff --git a/test/test_operation.py b/test/test_operation.py
index ec2fe26e..ae7c2949 100644
--- a/test/test_operation.py
+++ b/test/test_operation.py
@@ -318,7 +318,9 @@ class TestIOCoordinates:
         bfly = Butterfly()
 
         bfly.set_latency_offsets({"in0": 3, "out1": 5})
-        with pytest.raises(ValueError, match="All latencies must be set:"):
+        with pytest.raises(
+            ValueError, match="Missing latencies for inputs \\[1\\]"
+        ):
             bfly.get_io_coordinates()
 
 
-- 
GitLab