From 45c7bb7836384c73b13da46998cf61b6c3568470 Mon Sep 17 00:00:00 2001 From: Frans Skarman <frans.skarman@liu.se> Date: Tue, 14 Feb 2023 10:43:12 +0000 Subject: [PATCH] Fix type errors --- b_asic/core_operations.py | 24 ++-- b_asic/graph_component.py | 7 +- b_asic/operation.py | 212 +++++++++++++++++++---------------- b_asic/port.py | 2 +- b_asic/special_operations.py | 22 ++-- b_asic/types.py | 22 ++++ test/test_operation.py | 4 +- 7 files changed, 169 insertions(+), 124 deletions(-) create mode 100644 b_asic/types.py diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index b4c9642b..aa9e33d4 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -4,7 +4,6 @@ B-ASIC Core Operations Module. Contains some of the most commonly used mathematical operations. """ -from numbers import Number from typing import Dict, Optional from numpy import abs as np_abs @@ -13,6 +12,7 @@ from numpy import conjugate, sqrt from b_asic.graph_component import Name, TypeName from b_asic.operation import AbstractOperation from b_asic.port import SignalSourceProvider +from b_asic.types import Num class Constant(AbstractOperation): @@ -35,12 +35,12 @@ class Constant(AbstractOperation): _execution_time = 0 - def __init__(self, value: Number = 0, name: Name = Name("")): + def __init__(self, value: Num = 0, name: Name = ""): """Construct a Constant operation with the given value.""" super().__init__( input_count=0, output_count=1, - name=Name(name), + name=name, latency_offsets={"out0": 0}, ) self.set_param("value", value) @@ -53,12 +53,12 @@ class Constant(AbstractOperation): return self.param("value") @property - def value(self) -> Number: + def value(self) -> Num: """Get the constant value of this operation.""" return self.param("value") @value.setter - def value(self, value: Number) -> None: + def value(self, value: Num) -> None: """Set the constant value of this operation.""" self.set_param("value", value) @@ -257,7 +257,7 @@ class AddSub(AbstractOperation): return a + b if self.is_add else a - b @property - def is_add(self) -> Number: + def is_add(self) -> Num: """Get if operation is add.""" return self.param("is_add") @@ -582,7 +582,7 @@ class ConstantMultiplication(AbstractOperation): def __init__( self, - value: Number = 0, + value: Num = 0, src0: Optional[SignalSourceProvider] = None, name: Name = Name(""), latency: Optional[int] = None, @@ -610,12 +610,12 @@ class ConstantMultiplication(AbstractOperation): return a * self.param("value") @property - def value(self) -> Number: + def value(self) -> Num: """Get the constant value of this operation.""" return self.param("value") @value.setter - def value(self, value: Number) -> None: + def value(self, value: Num) -> None: """Set the constant value of this operation.""" self.set_param("value", value) @@ -714,7 +714,7 @@ class SymmetricTwoportAdaptor(AbstractOperation): def __init__( self, - value: Number = 0, + value: Num = 0, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = Name(""), @@ -743,12 +743,12 @@ class SymmetricTwoportAdaptor(AbstractOperation): return b + tmp, a + tmp @property - def value(self) -> Number: + def value(self) -> Num: """Get the constant value of this operation.""" return self.param("value") @value.setter - def value(self, value: Number) -> None: + def value(self, value: Num) -> None: """Set the constant value of this operation.""" self.set_param("value", value) diff --git a/b_asic/graph_component.py b/b_asic/graph_component.py index 86018c8d..1f910c30 100644 --- a/b_asic/graph_component.py +++ b/b_asic/graph_component.py @@ -7,12 +7,9 @@ Contains the base for all components with an ID in a signal flow graph. from abc import ABC, abstractmethod from collections import deque from copy import copy, deepcopy -from typing import Any, Dict, Generator, Iterable, Mapping, NewType, cast +from typing import Any, Dict, Generator, Iterable, Mapping, cast -Name = NewType("Name", str) -TypeName = NewType("TypeName", str) -GraphID = NewType("GraphID", str) -GraphIDNumber = NewType("GraphIDNumber", int) +from b_asic.types import GraphID, GraphIDNumber, Name, Num, TypeName class GraphComponent(ABC): diff --git a/b_asic/operation.py b/b_asic/operation.py index e39e9d69..8b8251dc 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -5,6 +5,7 @@ Contains the base for operations that are used by B-ASIC. """ import collections +import collections.abc import itertools as it from abc import abstractmethod from numbers import Number @@ -22,6 +23,7 @@ from typing import ( Tuple, Union, cast, + overload, ) from b_asic.graph_component import ( @@ -32,6 +34,7 @@ from b_asic.graph_component import ( ) from b_asic.port import InputPort, OutputPort, SignalSourceProvider from b_asic.signal import Signal +from b_asic.types import Num, NumRuntime if TYPE_CHECKING: # Conditionally imported to avoid circular imports @@ -47,10 +50,10 @@ if TYPE_CHECKING: ResultKey = NewType("ResultKey", str) -ResultMap = Mapping[ResultKey, Optional[Number]] -MutableResultMap = MutableMapping[ResultKey, Optional[Number]] -DelayMap = Mapping[ResultKey, Number] -MutableDelayMap = MutableMapping[ResultKey, Number] +ResultMap = Mapping[ResultKey, Optional[Num]] +MutableResultMap = MutableMapping[ResultKey, Optional[Num]] +DelayMap = Mapping[ResultKey, Num] +MutableDelayMap = MutableMapping[ResultKey, Num] class Operation(GraphComponent, SignalSourceProvider): @@ -66,7 +69,7 @@ class Operation(GraphComponent, SignalSourceProvider): """ @abstractmethod - def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Addition": + def __add__(self, src: Union[SignalSourceProvider, Num]) -> "Addition": """ Overloads the addition operator to make it return a new Addition operation object that is connected to the self and other objects. @@ -74,7 +77,7 @@ class Operation(GraphComponent, SignalSourceProvider): raise NotImplementedError @abstractmethod - def __radd__(self, src: Union[SignalSourceProvider, Number]) -> "Addition": + def __radd__(self, src: Union[SignalSourceProvider, Num]) -> "Addition": """ Overloads the addition operator to make it return a new Addition operation object that is connected to the self and other objects. @@ -82,9 +85,7 @@ class Operation(GraphComponent, SignalSourceProvider): raise NotImplementedError @abstractmethod - def __sub__( - self, src: Union[SignalSourceProvider, Number] - ) -> "Subtraction": + def __sub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction": """ Overloads the subtraction operator to make it return a new Subtraction operation object that is connected to the self and other objects. @@ -92,9 +93,7 @@ class Operation(GraphComponent, SignalSourceProvider): raise NotImplementedError @abstractmethod - def __rsub__( - self, src: Union[SignalSourceProvider, Number] - ) -> "Subtraction": + def __rsub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction": """ Overloads the subtraction operator to make it return a new Subtraction operation object that is connected to the self and other objects. @@ -103,7 +102,7 @@ class Operation(GraphComponent, SignalSourceProvider): @abstractmethod def __mul__( - self, src: Union[SignalSourceProvider, Number] + self, src: Union[SignalSourceProvider, Num] ) -> Union["Multiplication", "ConstantMultiplication"]: """ Overloads the multiplication operator to make it return a new Multiplication @@ -116,7 +115,7 @@ class Operation(GraphComponent, SignalSourceProvider): @abstractmethod def __rmul__( - self, src: Union[SignalSourceProvider, Number] + self, src: Union[SignalSourceProvider, Num] ) -> Union["Multiplication", "ConstantMultiplication"]: """ Overloads the multiplication operator to make it return a new Multiplication @@ -128,9 +127,7 @@ class Operation(GraphComponent, SignalSourceProvider): raise NotImplementedError @abstractmethod - def __truediv__( - self, src: Union[SignalSourceProvider, Number] - ) -> "Division": + def __truediv__(self, src: Union[SignalSourceProvider, Num]) -> "Division": """ Overloads the division operator to make it return a new Division operation object that is connected to the self and other objects. @@ -139,7 +136,7 @@ class Operation(GraphComponent, SignalSourceProvider): @abstractmethod def __rtruediv__( - self, src: Union[SignalSourceProvider, Number] + self, src: Union[SignalSourceProvider, Num] ) -> Union["Division", "Reciprocal"]: """ Overloads the division operator to make it return a new Division operation @@ -219,7 +216,7 @@ class Operation(GraphComponent, SignalSourceProvider): @abstractmethod def current_output( self, index: int, delays: Optional[DelayMap] = None, prefix: str = "" - ) -> Optional[Number]: + ) -> Optional[Num]: """ Get the current output at the given index of this operation, if available. @@ -238,13 +235,13 @@ class Operation(GraphComponent, SignalSourceProvider): def evaluate_output( self, index: int, - input_values: Sequence[Number], + input_values: Sequence[Num], results: Optional[MutableResultMap] = None, delays: Optional[MutableDelayMap] = None, prefix: str = "", bits_override: Optional[int] = None, truncate: bool = True, - ) -> Number: + ) -> Num: """ Evaluate the output at the given index of this operation with the given input values. @@ -281,7 +278,7 @@ class Operation(GraphComponent, SignalSourceProvider): @abstractmethod def current_outputs( self, delays: Optional[DelayMap] = None, prefix: str = "" - ) -> Sequence[Optional[Number]]: + ) -> Sequence[Optional[Num]]: """ Get all current outputs of this operation, if available. @@ -294,13 +291,13 @@ class Operation(GraphComponent, SignalSourceProvider): @abstractmethod def evaluate_outputs( self, - input_values: Sequence[Number], + input_values: Sequence[Num], results: Optional[MutableResultMap] = None, delays: Optional[MutableDelayMap] = None, prefix: str = "", bits_override: Optional[int] = None, truncate: bool = True, - ) -> Sequence[Number]: + ) -> Sequence[Num]: """ Evaluate all outputs of this operation given the input values. See evaluate_output for more information. @@ -336,7 +333,7 @@ class Operation(GraphComponent, SignalSourceProvider): raise NotImplementedError @abstractmethod - def truncate_input(self, index: int, value: Number, bits: int) -> Number: + def truncate_input(self, index: int, value: Num, bits: int) -> Num: """ Truncate the value to be used as input at the given index to a certain bit length. @@ -397,7 +394,7 @@ class Operation(GraphComponent, SignalSourceProvider): @execution_time.setter @abstractmethod - def execution_time(self, latency: int) -> None: + def execution_time(self, latency: Optional[int]) -> None: """ Sets the execution time of the operation to the specified integer value. The execution time cannot be a negative integer. @@ -570,52 +567,63 @@ class AbstractOperation(Operation, AbstractGraphComponent): self._execution_time = execution_time + @overload @abstractmethod - def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ + def evaluate( + self, *inputs: Operation + ) -> List[Operation]: # pylint: disable=arguments-differ + ... + + @overload + @abstractmethod + def evaluate( + self, *inputs: Num + ) -> List[Num]: # pylint: disable=arguments-differ + ... + + @abstractmethod + def evaluate(self, *inputs): # pylint: disable=arguments-differ """ Evaluate the operation and generate a list of output values given a list of input values. """ raise NotImplementedError - def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Addition": + def __add__(self, src: Union[SignalSourceProvider, Num]) -> "Addition": # Import here to avoid circular imports. from b_asic.core_operations import Addition, Constant - return Addition( - self, Constant(src) if isinstance(src, Number) else src - ) + if isinstance(src, NumRuntime): + return Addition(self, Constant(src)) + else: + return Addition(self, src) - def __radd__(self, src: Union[SignalSourceProvider, Number]) -> "Addition": + def __radd__(self, src: Union[SignalSourceProvider, Num]) -> "Addition": # Import here to avoid circular imports. from b_asic.core_operations import Addition, Constant return Addition( - Constant(src) if isinstance(src, Number) else src, self + Constant(src) if isinstance(src, NumRuntime) else src, self ) - def __sub__( - self, src: Union[SignalSourceProvider, Number] - ) -> "Subtraction": + def __sub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction": # Import here to avoid circular imports. from b_asic.core_operations import Constant, Subtraction return Subtraction( - self, Constant(src) if isinstance(src, Number) else src + self, Constant(src) if isinstance(src, NumRuntime) else src ) - def __rsub__( - self, src: Union[SignalSourceProvider, Number] - ) -> "Subtraction": + def __rsub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction": # Import here to avoid circular imports. from b_asic.core_operations import Constant, Subtraction return Subtraction( - Constant(src) if isinstance(src, Number) else src, self + Constant(src) if isinstance(src, NumRuntime) else src, self ) def __mul__( - self, src: Union[SignalSourceProvider, Number] + self, src: Union[SignalSourceProvider, Num] ) -> Union["Multiplication", "ConstantMultiplication"]: # Import here to avoid circular imports. from b_asic.core_operations import ( @@ -625,12 +633,12 @@ class AbstractOperation(Operation, AbstractGraphComponent): return ( ConstantMultiplication(src, self) - if isinstance(src, Number) + if isinstance(src, NumRuntime) else Multiplication(self, src) ) def __rmul__( - self, src: Union[SignalSourceProvider, Number] + self, src: Union[SignalSourceProvider, Num] ) -> Union["Multiplication", "ConstantMultiplication"]: # Import here to avoid circular imports. from b_asic.core_operations import ( @@ -640,27 +648,25 @@ class AbstractOperation(Operation, AbstractGraphComponent): return ( ConstantMultiplication(src, self) - if isinstance(src, Number) + if isinstance(src, NumRuntime) else Multiplication(src, self) ) - def __truediv__( - self, src: Union[SignalSourceProvider, Number] - ) -> "Division": + def __truediv__(self, src: Union[SignalSourceProvider, Num]) -> "Division": # Import here to avoid circular imports. from b_asic.core_operations import Constant, Division return Division( - self, Constant(src) if isinstance(src, Number) else src + self, Constant(src) if isinstance(src, NumRuntime) else src ) def __rtruediv__( - self, src: Union[SignalSourceProvider, Number] + self, src: Union[SignalSourceProvider, Num] ) -> Union["Division", "Reciprocal"]: # Import here to avoid circular imports. from b_asic.core_operations import Constant, Division, Reciprocal - if isinstance(src, Number): + if isinstance(src, NumRuntime): if src == 1: return Reciprocal(self) else: @@ -771,19 +777,19 @@ class AbstractOperation(Operation, AbstractGraphComponent): def current_output( self, index: int, delays: Optional[DelayMap] = None, prefix: str = "" - ) -> Optional[Number]: + ) -> Optional[Num]: return None def evaluate_output( self, index: int, - input_values: Sequence[Number], + input_values: Sequence[Num], results: Optional[MutableResultMap] = None, delays: Optional[MutableDelayMap] = None, prefix: str = "", bits_override: Optional[int] = None, truncate: bool = True, - ) -> Number: + ) -> Num: if index < 0 or index >= self.output_count: raise IndexError( "Output index out of range (expected" @@ -828,7 +834,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): def current_outputs( self, delays: Optional[DelayMap] = None, prefix: str = "" - ) -> Sequence[Optional[Number]]: + ) -> Sequence[Optional[Num]]: return [ self.current_output(i, delays, prefix) for i in range(self.output_count) @@ -836,13 +842,13 @@ class AbstractOperation(Operation, AbstractGraphComponent): def evaluate_outputs( self, - input_values: Sequence[Number], + input_values: Sequence[Num], results: Optional[MutableResultMap] = None, delays: Optional[MutableDelayMap] = None, prefix: str = "", bits_override: Optional[int] = None, truncate: bool = True, - ) -> Sequence[Number]: + ) -> Sequence[Num]: return [ self.evaluate_output( i, @@ -860,18 +866,11 @@ class AbstractOperation(Operation, AbstractGraphComponent): # Import here to avoid circular imports. from b_asic.special_operations import Input - try: - result = self.evaluate(*([Input()] * self.input_count)) - if isinstance(result, collections.abc.Sequence) and all( - isinstance(e, Operation) for e in result - ): - return result - if isinstance(result, Operation): - return [result] - except TypeError: - pass - except ValueError: - pass + result = self.evaluate(*([Input()] * self.input_count)) + if isinstance(result, collections.abc.Sequence) and all( + isinstance(e, Operation) for e in result + ): + return cast(List[Operation], result) return [self] def to_sfg(self) -> "SFG": @@ -966,14 +965,19 @@ class AbstractOperation(Operation, AbstractGraphComponent): ) return self.input(0) - def truncate_input(self, index: int, value: Number, bits: int) -> Number: - return int(value) & ((2**bits) - 1) + # TODO: Fix + def truncate_input(self, index: int, value: Num, bits: int) -> Num: + if isinstance(value, (float, int)): + return round(value) & ((2**bits) - 1) + else: + raise TypeError + # TODO: Seems wrong??? - Oscar def truncate_inputs( self, - input_values: Sequence[Number], + input_values: Sequence[Num], bits_override: Optional[int] = None, - ) -> Sequence[Number]: + ) -> Sequence[Num]: """ Truncate the values to be used as inputs to the bit lengths specified by the respective signals connected to each input. @@ -981,7 +985,6 @@ class AbstractOperation(Operation, AbstractGraphComponent): args = [] for i, input_port in enumerate(self.inputs): value = input_values[i] - bits = bits_override if bits_override is None and input_port.signal_count >= 1: bits = input_port.signals[0].bits if bits_override is not None: @@ -990,7 +993,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): "Complex value cannot be truncated to {bits} bits as" " requested by the signal connected to input #{i}" ) - value = self.truncate_input(i, value, bits) + value = self.truncate_input(i, value, bits_override) args.append(value) return args @@ -1026,6 +1029,34 @@ class AbstractOperation(Operation, AbstractGraphComponent): return latency_offsets + def _check_all_latencies_set(self) -> None: + """Raises an exception of an input or output does not have its latency offset set + """ + self.input_latency_offsets() + self.output_latency_offsets() + + def input_latency_offsets(self) -> List[int]: + latency_offsets = [i.latency_offset for i in self.inputs] + + if any(val is None for val in latency_offsets): + raise ValueError( + "Missing latencies for inputs" + f" {[i for (i, latency) in enumerate(latency_offsets) if latency is None]}" + ) + + return cast(List[int], latency_offsets) + + def output_latency_offsets(self) -> List[int]: + latency_offsets = [i.latency_offset for i in self.outputs] + + if any(val is None for val in latency_offsets): + raise ValueError( + "Missing latencies for outputs" + f" {[i for i in latency_offsets if i is not None]}" + ) + + return cast(List[int], latency_offsets) + def set_latency(self, latency: int) -> None: if latency < 0: raise ValueError("Latency cannot be negative") @@ -1110,33 +1141,28 @@ class AbstractOperation(Operation, AbstractGraphComponent): (0, 0), ) - def _check_all_latencies_set(self): - if any(val is None for val in self.latency_offsets.values()): - raise ValueError( - f"All latencies must be set: {self.latency_offsets}" - ) - def _get_plot_coordinates_for_latency( self, ) -> Tuple[Tuple[float, float], ...]: - self._check_all_latencies_set() # Points for latency polygon latency = [] + input_latencies = self.input_latency_offsets() + output_latencies = self.output_latency_offsets() # Remember starting point - start_point = (self.inputs[0].latency_offset, 0.0) + start_point = (input_latencies[0], 0.0) num_in = self.input_count latency.append(start_point) for k in range(1, num_in): - latency.append((self.inputs[k - 1].latency_offset, k / num_in)) - latency.append((self.inputs[k].latency_offset, k / num_in)) - latency.append((self.inputs[num_in - 1].latency_offset, 1)) + latency.append((input_latencies[k - 1], k / num_in)) + latency.append((input_latencies[k], k / num_in)) + latency.append((input_latencies[num_in - 1], 1)) num_out = self.output_count - latency.append((self.outputs[num_out - 1].latency_offset, 1)) + latency.append((output_latencies[num_out - 1], 1)) for k in reversed(range(1, num_out)): - latency.append((self.outputs[k].latency_offset, k / num_out)) - latency.append((self.outputs[k - 1].latency_offset, k / num_out)) - latency.append((self.outputs[0].latency_offset, 0.0)) + latency.append((output_latencies[k], k / num_out)) + latency.append((output_latencies[k - 1], k / num_out)) + latency.append((output_latencies[0], 0.0)) # Close the polygon latency.append(start_point) @@ -1144,10 +1170,9 @@ class AbstractOperation(Operation, AbstractGraphComponent): def get_input_coordinates(self) -> Tuple[Tuple[float, float], ...]: # doc-string inherited - self._check_all_latencies_set() return tuple( ( - self.inputs[k].latency_offset, + self.input_latency_offsets()[k], (1 + 2 * k) / (2 * len(self.inputs)), ) for k in range(len(self.inputs)) @@ -1155,10 +1180,9 @@ class AbstractOperation(Operation, AbstractGraphComponent): def get_output_coordinates(self) -> Tuple[Tuple[float, float], ...]: # doc-string inherited - self._check_all_latencies_set() return tuple( ( - self.outputs[k].latency_offset, + self.output_latency_offsets()[k], (1 + 2 * k) / (2 * len(self.outputs)), ) for k in range(len(self.outputs)) diff --git a/b_asic/port.py b/b_asic/port.py index d823a1bb..e1d35e3e 100644 --- a/b_asic/port.py +++ b/b_asic/port.py @@ -131,7 +131,7 @@ class AbstractPort(Port): return self._latency_offset @latency_offset.setter - def latency_offset(self, latency_offset: int): + def latency_offset(self, latency_offset: Optional[int]): self._latency_offset = latency_offset diff --git a/b_asic/special_operations.py b/b_asic/special_operations.py index c0609dd5..16edaaea 100644 --- a/b_asic/special_operations.py +++ b/b_asic/special_operations.py @@ -5,7 +5,6 @@ Contains operations with special purposes that may be treated differently from normal operations in an SFG. """ -from numbers import Number from typing import List, Optional, Sequence, Tuple from b_asic.graph_component import Name, TypeName @@ -16,6 +15,7 @@ from b_asic.operation import ( MutableResultMap, ) from b_asic.port import SignalSourceProvider +from b_asic.types import Name, Num, TypeName class Input(AbstractOperation): @@ -28,12 +28,12 @@ class Input(AbstractOperation): _execution_time = 0 - def __init__(self, name: Name = Name("")): + def __init__(self, name: Name = ""): """Construct an Input operation.""" super().__init__( input_count=0, output_count=1, - name=Name(name), + name=name, latency_offsets={"out0": 0}, ) self.set_param("value", 0) @@ -46,12 +46,12 @@ class Input(AbstractOperation): return self.param("value") @property - def value(self) -> Number: + def value(self) -> Num: """Get the current value of this input.""" return self.param("value") @value.setter - def value(self, value: Number) -> None: + def value(self, value: Num) -> None: """Set the current value of this input.""" self.set_param("value", value) @@ -152,7 +152,7 @@ class Delay(AbstractOperation): def __init__( self, src0: Optional[SignalSourceProvider] = None, - initial_value: Number = 0, + initial_value: Num = 0, name: Name = Name(""), ): """Construct a Delay operation.""" @@ -173,7 +173,7 @@ class Delay(AbstractOperation): def current_output( self, index: int, delays: Optional[DelayMap] = None, prefix: str = "" - ) -> Optional[Number]: + ) -> Optional[Num]: if delays is not None: return delays.get( self.key(index, prefix), self.param("initial_value") @@ -183,13 +183,13 @@ class Delay(AbstractOperation): def evaluate_output( self, index: int, - input_values: Sequence[Number], + input_values: Sequence[Num], results: Optional[MutableResultMap] = None, delays: Optional[MutableDelayMap] = None, prefix: str = "", bits_override: Optional[int] = None, truncate: bool = True, - ) -> Number: + ) -> Num: if index != 0: raise IndexError( f"Output index out of range (expected 0-0, got {index})" @@ -214,11 +214,11 @@ class Delay(AbstractOperation): return value @property - def initial_value(self) -> Number: + def initial_value(self) -> Num: """Get the initial value of this delay.""" return self.param("initial_value") @initial_value.setter - def initial_value(self, value: Number) -> None: + def initial_value(self, value: Num) -> None: """Set the initial value of this delay.""" self.set_param("initial_value", value) diff --git a/b_asic/types.py b/b_asic/types.py new file mode 100644 index 00000000..1f37cebb --- /dev/null +++ b/b_asic/types.py @@ -0,0 +1,22 @@ +from typing import NewType, Union + +# https://stackoverflow.com/questions/69334475/how-to-hint-at-number-types-i-e-subclasses-of-number-not-numbers-themselv +Num = Union[int, float, complex] + +NumRuntime = (complex, float, int) + + +Name = str +# # We want to be able to initialize Name with String literals, but still have the +# # benefit of static type checking that we don't pass non-names to name locations. +# # However, until python 3.11 a string literal type was not available. In those versions, +# # we'll fall back on just aliasing `str` => Name. +# if sys.version_info >= (3, 11): +# from typing import LiteralString +# Name: TypeAlias = NewType("Name", str) | LiteralString +# else: +# Name = str + +TypeName = NewType("TypeName", str) +GraphID = NewType("GraphID", str) +GraphIDNumber = NewType("GraphIDNumber", int) diff --git a/test/test_operation.py b/test/test_operation.py index ec2fe26e..ae7c2949 100644 --- a/test/test_operation.py +++ b/test/test_operation.py @@ -318,7 +318,9 @@ class TestIOCoordinates: bfly = Butterfly() bfly.set_latency_offsets({"in0": 3, "out1": 5}) - with pytest.raises(ValueError, match="All latencies must be set:"): + with pytest.raises( + ValueError, match="Missing latencies for inputs \\[1\\]" + ): bfly.get_io_coordinates() -- GitLab