Skip to content
Snippets Groups Projects
Commit 45c7bb78 authored by Frans Skarman's avatar Frans Skarman :tropical_fish: Committed by Oscar Gustafsson
Browse files

Fix type errors

parent 7c45e4c3
No related branches found
No related tags found
1 merge request!176Name literals
Pipeline #89474 passed
......@@ -4,7 +4,6 @@ B-ASIC Core Operations Module.
Contains some of the most commonly used mathematical operations.
"""
from numbers import Number
from typing import Dict, Optional
from numpy import abs as np_abs
......@@ -13,6 +12,7 @@ from numpy import conjugate, sqrt
from b_asic.graph_component import Name, TypeName
from b_asic.operation import AbstractOperation
from b_asic.port import SignalSourceProvider
from b_asic.types import Num
class Constant(AbstractOperation):
......@@ -35,12 +35,12 @@ class Constant(AbstractOperation):
_execution_time = 0
def __init__(self, value: Number = 0, name: Name = Name("")):
def __init__(self, value: Num = 0, name: Name = ""):
"""Construct a Constant operation with the given value."""
super().__init__(
input_count=0,
output_count=1,
name=Name(name),
name=name,
latency_offsets={"out0": 0},
)
self.set_param("value", value)
......@@ -53,12 +53,12 @@ class Constant(AbstractOperation):
return self.param("value")
@property
def value(self) -> Number:
def value(self) -> Num:
"""Get the constant value of this operation."""
return self.param("value")
@value.setter
def value(self, value: Number) -> None:
def value(self, value: Num) -> None:
"""Set the constant value of this operation."""
self.set_param("value", value)
......@@ -257,7 +257,7 @@ class AddSub(AbstractOperation):
return a + b if self.is_add else a - b
@property
def is_add(self) -> Number:
def is_add(self) -> Num:
"""Get if operation is add."""
return self.param("is_add")
......@@ -582,7 +582,7 @@ class ConstantMultiplication(AbstractOperation):
def __init__(
self,
value: Number = 0,
value: Num = 0,
src0: Optional[SignalSourceProvider] = None,
name: Name = Name(""),
latency: Optional[int] = None,
......@@ -610,12 +610,12 @@ class ConstantMultiplication(AbstractOperation):
return a * self.param("value")
@property
def value(self) -> Number:
def value(self) -> Num:
"""Get the constant value of this operation."""
return self.param("value")
@value.setter
def value(self, value: Number) -> None:
def value(self, value: Num) -> None:
"""Set the constant value of this operation."""
self.set_param("value", value)
......@@ -714,7 +714,7 @@ class SymmetricTwoportAdaptor(AbstractOperation):
def __init__(
self,
value: Number = 0,
value: Num = 0,
src0: Optional[SignalSourceProvider] = None,
src1: Optional[SignalSourceProvider] = None,
name: Name = Name(""),
......@@ -743,12 +743,12 @@ class SymmetricTwoportAdaptor(AbstractOperation):
return b + tmp, a + tmp
@property
def value(self) -> Number:
def value(self) -> Num:
"""Get the constant value of this operation."""
return self.param("value")
@value.setter
def value(self, value: Number) -> None:
def value(self, value: Num) -> None:
"""Set the constant value of this operation."""
self.set_param("value", value)
......
......@@ -7,12 +7,9 @@ Contains the base for all components with an ID in a signal flow graph.
from abc import ABC, abstractmethod
from collections import deque
from copy import copy, deepcopy
from typing import Any, Dict, Generator, Iterable, Mapping, NewType, cast
from typing import Any, Dict, Generator, Iterable, Mapping, cast
Name = NewType("Name", str)
TypeName = NewType("TypeName", str)
GraphID = NewType("GraphID", str)
GraphIDNumber = NewType("GraphIDNumber", int)
from b_asic.types import GraphID, GraphIDNumber, Name, Num, TypeName
class GraphComponent(ABC):
......
......@@ -5,6 +5,7 @@ Contains the base for operations that are used by B-ASIC.
"""
import collections
import collections.abc
import itertools as it
from abc import abstractmethod
from numbers import Number
......@@ -22,6 +23,7 @@ from typing import (
Tuple,
Union,
cast,
overload,
)
from b_asic.graph_component import (
......@@ -32,6 +34,7 @@ from b_asic.graph_component import (
)
from b_asic.port import InputPort, OutputPort, SignalSourceProvider
from b_asic.signal import Signal
from b_asic.types import Num, NumRuntime
if TYPE_CHECKING:
# Conditionally imported to avoid circular imports
......@@ -47,10 +50,10 @@ if TYPE_CHECKING:
ResultKey = NewType("ResultKey", str)
ResultMap = Mapping[ResultKey, Optional[Number]]
MutableResultMap = MutableMapping[ResultKey, Optional[Number]]
DelayMap = Mapping[ResultKey, Number]
MutableDelayMap = MutableMapping[ResultKey, Number]
ResultMap = Mapping[ResultKey, Optional[Num]]
MutableResultMap = MutableMapping[ResultKey, Optional[Num]]
DelayMap = Mapping[ResultKey, Num]
MutableDelayMap = MutableMapping[ResultKey, Num]
class Operation(GraphComponent, SignalSourceProvider):
......@@ -66,7 +69,7 @@ class Operation(GraphComponent, SignalSourceProvider):
"""
@abstractmethod
def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Addition":
def __add__(self, src: Union[SignalSourceProvider, Num]) -> "Addition":
"""
Overloads the addition operator to make it return a new Addition operation
object that is connected to the self and other objects.
......@@ -74,7 +77,7 @@ class Operation(GraphComponent, SignalSourceProvider):
raise NotImplementedError
@abstractmethod
def __radd__(self, src: Union[SignalSourceProvider, Number]) -> "Addition":
def __radd__(self, src: Union[SignalSourceProvider, Num]) -> "Addition":
"""
Overloads the addition operator to make it return a new Addition operation
object that is connected to the self and other objects.
......@@ -82,9 +85,7 @@ class Operation(GraphComponent, SignalSourceProvider):
raise NotImplementedError
@abstractmethod
def __sub__(
self, src: Union[SignalSourceProvider, Number]
) -> "Subtraction":
def __sub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction":
"""
Overloads the subtraction operator to make it return a new Subtraction
operation object that is connected to the self and other objects.
......@@ -92,9 +93,7 @@ class Operation(GraphComponent, SignalSourceProvider):
raise NotImplementedError
@abstractmethod
def __rsub__(
self, src: Union[SignalSourceProvider, Number]
) -> "Subtraction":
def __rsub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction":
"""
Overloads the subtraction operator to make it return a new Subtraction
operation object that is connected to the self and other objects.
......@@ -103,7 +102,7 @@ class Operation(GraphComponent, SignalSourceProvider):
@abstractmethod
def __mul__(
self, src: Union[SignalSourceProvider, Number]
self, src: Union[SignalSourceProvider, Num]
) -> Union["Multiplication", "ConstantMultiplication"]:
"""
Overloads the multiplication operator to make it return a new Multiplication
......@@ -116,7 +115,7 @@ class Operation(GraphComponent, SignalSourceProvider):
@abstractmethod
def __rmul__(
self, src: Union[SignalSourceProvider, Number]
self, src: Union[SignalSourceProvider, Num]
) -> Union["Multiplication", "ConstantMultiplication"]:
"""
Overloads the multiplication operator to make it return a new Multiplication
......@@ -128,9 +127,7 @@ class Operation(GraphComponent, SignalSourceProvider):
raise NotImplementedError
@abstractmethod
def __truediv__(
self, src: Union[SignalSourceProvider, Number]
) -> "Division":
def __truediv__(self, src: Union[SignalSourceProvider, Num]) -> "Division":
"""
Overloads the division operator to make it return a new Division operation
object that is connected to the self and other objects.
......@@ -139,7 +136,7 @@ class Operation(GraphComponent, SignalSourceProvider):
@abstractmethod
def __rtruediv__(
self, src: Union[SignalSourceProvider, Number]
self, src: Union[SignalSourceProvider, Num]
) -> Union["Division", "Reciprocal"]:
"""
Overloads the division operator to make it return a new Division operation
......@@ -219,7 +216,7 @@ class Operation(GraphComponent, SignalSourceProvider):
@abstractmethod
def current_output(
self, index: int, delays: Optional[DelayMap] = None, prefix: str = ""
) -> Optional[Number]:
) -> Optional[Num]:
"""
Get the current output at the given index of this operation, if available.
......@@ -238,13 +235,13 @@ class Operation(GraphComponent, SignalSourceProvider):
def evaluate_output(
self,
index: int,
input_values: Sequence[Number],
input_values: Sequence[Num],
results: Optional[MutableResultMap] = None,
delays: Optional[MutableDelayMap] = None,
prefix: str = "",
bits_override: Optional[int] = None,
truncate: bool = True,
) -> Number:
) -> Num:
"""
Evaluate the output at the given index of this operation with the given input
values.
......@@ -281,7 +278,7 @@ class Operation(GraphComponent, SignalSourceProvider):
@abstractmethod
def current_outputs(
self, delays: Optional[DelayMap] = None, prefix: str = ""
) -> Sequence[Optional[Number]]:
) -> Sequence[Optional[Num]]:
"""
Get all current outputs of this operation, if available.
......@@ -294,13 +291,13 @@ class Operation(GraphComponent, SignalSourceProvider):
@abstractmethod
def evaluate_outputs(
self,
input_values: Sequence[Number],
input_values: Sequence[Num],
results: Optional[MutableResultMap] = None,
delays: Optional[MutableDelayMap] = None,
prefix: str = "",
bits_override: Optional[int] = None,
truncate: bool = True,
) -> Sequence[Number]:
) -> Sequence[Num]:
"""
Evaluate all outputs of this operation given the input values.
See evaluate_output for more information.
......@@ -336,7 +333,7 @@ class Operation(GraphComponent, SignalSourceProvider):
raise NotImplementedError
@abstractmethod
def truncate_input(self, index: int, value: Number, bits: int) -> Number:
def truncate_input(self, index: int, value: Num, bits: int) -> Num:
"""
Truncate the value to be used as input at the given index to a certain bit
length.
......@@ -397,7 +394,7 @@ class Operation(GraphComponent, SignalSourceProvider):
@execution_time.setter
@abstractmethod
def execution_time(self, latency: int) -> None:
def execution_time(self, latency: Optional[int]) -> None:
"""
Sets the execution time of the operation to the specified integer
value. The execution time cannot be a negative integer.
......@@ -570,52 +567,63 @@ class AbstractOperation(Operation, AbstractGraphComponent):
self._execution_time = execution_time
@overload
@abstractmethod
def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ
def evaluate(
self, *inputs: Operation
) -> List[Operation]: # pylint: disable=arguments-differ
...
@overload
@abstractmethod
def evaluate(
self, *inputs: Num
) -> List[Num]: # pylint: disable=arguments-differ
...
@abstractmethod
def evaluate(self, *inputs): # pylint: disable=arguments-differ
"""
Evaluate the operation and generate a list of output values given a
list of input values.
"""
raise NotImplementedError
def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Addition":
def __add__(self, src: Union[SignalSourceProvider, Num]) -> "Addition":
# Import here to avoid circular imports.
from b_asic.core_operations import Addition, Constant
return Addition(
self, Constant(src) if isinstance(src, Number) else src
)
if isinstance(src, NumRuntime):
return Addition(self, Constant(src))
else:
return Addition(self, src)
def __radd__(self, src: Union[SignalSourceProvider, Number]) -> "Addition":
def __radd__(self, src: Union[SignalSourceProvider, Num]) -> "Addition":
# Import here to avoid circular imports.
from b_asic.core_operations import Addition, Constant
return Addition(
Constant(src) if isinstance(src, Number) else src, self
Constant(src) if isinstance(src, NumRuntime) else src, self
)
def __sub__(
self, src: Union[SignalSourceProvider, Number]
) -> "Subtraction":
def __sub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction":
# Import here to avoid circular imports.
from b_asic.core_operations import Constant, Subtraction
return Subtraction(
self, Constant(src) if isinstance(src, Number) else src
self, Constant(src) if isinstance(src, NumRuntime) else src
)
def __rsub__(
self, src: Union[SignalSourceProvider, Number]
) -> "Subtraction":
def __rsub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction":
# Import here to avoid circular imports.
from b_asic.core_operations import Constant, Subtraction
return Subtraction(
Constant(src) if isinstance(src, Number) else src, self
Constant(src) if isinstance(src, NumRuntime) else src, self
)
def __mul__(
self, src: Union[SignalSourceProvider, Number]
self, src: Union[SignalSourceProvider, Num]
) -> Union["Multiplication", "ConstantMultiplication"]:
# Import here to avoid circular imports.
from b_asic.core_operations import (
......@@ -625,12 +633,12 @@ class AbstractOperation(Operation, AbstractGraphComponent):
return (
ConstantMultiplication(src, self)
if isinstance(src, Number)
if isinstance(src, NumRuntime)
else Multiplication(self, src)
)
def __rmul__(
self, src: Union[SignalSourceProvider, Number]
self, src: Union[SignalSourceProvider, Num]
) -> Union["Multiplication", "ConstantMultiplication"]:
# Import here to avoid circular imports.
from b_asic.core_operations import (
......@@ -640,27 +648,25 @@ class AbstractOperation(Operation, AbstractGraphComponent):
return (
ConstantMultiplication(src, self)
if isinstance(src, Number)
if isinstance(src, NumRuntime)
else Multiplication(src, self)
)
def __truediv__(
self, src: Union[SignalSourceProvider, Number]
) -> "Division":
def __truediv__(self, src: Union[SignalSourceProvider, Num]) -> "Division":
# Import here to avoid circular imports.
from b_asic.core_operations import Constant, Division
return Division(
self, Constant(src) if isinstance(src, Number) else src
self, Constant(src) if isinstance(src, NumRuntime) else src
)
def __rtruediv__(
self, src: Union[SignalSourceProvider, Number]
self, src: Union[SignalSourceProvider, Num]
) -> Union["Division", "Reciprocal"]:
# Import here to avoid circular imports.
from b_asic.core_operations import Constant, Division, Reciprocal
if isinstance(src, Number):
if isinstance(src, NumRuntime):
if src == 1:
return Reciprocal(self)
else:
......@@ -771,19 +777,19 @@ class AbstractOperation(Operation, AbstractGraphComponent):
def current_output(
self, index: int, delays: Optional[DelayMap] = None, prefix: str = ""
) -> Optional[Number]:
) -> Optional[Num]:
return None
def evaluate_output(
self,
index: int,
input_values: Sequence[Number],
input_values: Sequence[Num],
results: Optional[MutableResultMap] = None,
delays: Optional[MutableDelayMap] = None,
prefix: str = "",
bits_override: Optional[int] = None,
truncate: bool = True,
) -> Number:
) -> Num:
if index < 0 or index >= self.output_count:
raise IndexError(
"Output index out of range (expected"
......@@ -828,7 +834,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
def current_outputs(
self, delays: Optional[DelayMap] = None, prefix: str = ""
) -> Sequence[Optional[Number]]:
) -> Sequence[Optional[Num]]:
return [
self.current_output(i, delays, prefix)
for i in range(self.output_count)
......@@ -836,13 +842,13 @@ class AbstractOperation(Operation, AbstractGraphComponent):
def evaluate_outputs(
self,
input_values: Sequence[Number],
input_values: Sequence[Num],
results: Optional[MutableResultMap] = None,
delays: Optional[MutableDelayMap] = None,
prefix: str = "",
bits_override: Optional[int] = None,
truncate: bool = True,
) -> Sequence[Number]:
) -> Sequence[Num]:
return [
self.evaluate_output(
i,
......@@ -860,18 +866,11 @@ class AbstractOperation(Operation, AbstractGraphComponent):
# Import here to avoid circular imports.
from b_asic.special_operations import Input
try:
result = self.evaluate(*([Input()] * self.input_count))
if isinstance(result, collections.abc.Sequence) and all(
isinstance(e, Operation) for e in result
):
return result
if isinstance(result, Operation):
return [result]
except TypeError:
pass
except ValueError:
pass
result = self.evaluate(*([Input()] * self.input_count))
if isinstance(result, collections.abc.Sequence) and all(
isinstance(e, Operation) for e in result
):
return cast(List[Operation], result)
return [self]
def to_sfg(self) -> "SFG":
......@@ -966,14 +965,19 @@ class AbstractOperation(Operation, AbstractGraphComponent):
)
return self.input(0)
def truncate_input(self, index: int, value: Number, bits: int) -> Number:
return int(value) & ((2**bits) - 1)
# TODO: Fix
def truncate_input(self, index: int, value: Num, bits: int) -> Num:
if isinstance(value, (float, int)):
return round(value) & ((2**bits) - 1)
else:
raise TypeError
# TODO: Seems wrong??? - Oscar
def truncate_inputs(
self,
input_values: Sequence[Number],
input_values: Sequence[Num],
bits_override: Optional[int] = None,
) -> Sequence[Number]:
) -> Sequence[Num]:
"""
Truncate the values to be used as inputs to the bit lengths specified
by the respective signals connected to each input.
......@@ -981,7 +985,6 @@ class AbstractOperation(Operation, AbstractGraphComponent):
args = []
for i, input_port in enumerate(self.inputs):
value = input_values[i]
bits = bits_override
if bits_override is None and input_port.signal_count >= 1:
bits = input_port.signals[0].bits
if bits_override is not None:
......@@ -990,7 +993,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
"Complex value cannot be truncated to {bits} bits as"
" requested by the signal connected to input #{i}"
)
value = self.truncate_input(i, value, bits)
value = self.truncate_input(i, value, bits_override)
args.append(value)
return args
......@@ -1026,6 +1029,34 @@ class AbstractOperation(Operation, AbstractGraphComponent):
return latency_offsets
def _check_all_latencies_set(self) -> None:
"""Raises an exception of an input or output does not have its latency offset set
"""
self.input_latency_offsets()
self.output_latency_offsets()
def input_latency_offsets(self) -> List[int]:
latency_offsets = [i.latency_offset for i in self.inputs]
if any(val is None for val in latency_offsets):
raise ValueError(
"Missing latencies for inputs"
f" {[i for (i, latency) in enumerate(latency_offsets) if latency is None]}"
)
return cast(List[int], latency_offsets)
def output_latency_offsets(self) -> List[int]:
latency_offsets = [i.latency_offset for i in self.outputs]
if any(val is None for val in latency_offsets):
raise ValueError(
"Missing latencies for outputs"
f" {[i for i in latency_offsets if i is not None]}"
)
return cast(List[int], latency_offsets)
def set_latency(self, latency: int) -> None:
if latency < 0:
raise ValueError("Latency cannot be negative")
......@@ -1110,33 +1141,28 @@ class AbstractOperation(Operation, AbstractGraphComponent):
(0, 0),
)
def _check_all_latencies_set(self):
if any(val is None for val in self.latency_offsets.values()):
raise ValueError(
f"All latencies must be set: {self.latency_offsets}"
)
def _get_plot_coordinates_for_latency(
self,
) -> Tuple[Tuple[float, float], ...]:
self._check_all_latencies_set()
# Points for latency polygon
latency = []
input_latencies = self.input_latency_offsets()
output_latencies = self.output_latency_offsets()
# Remember starting point
start_point = (self.inputs[0].latency_offset, 0.0)
start_point = (input_latencies[0], 0.0)
num_in = self.input_count
latency.append(start_point)
for k in range(1, num_in):
latency.append((self.inputs[k - 1].latency_offset, k / num_in))
latency.append((self.inputs[k].latency_offset, k / num_in))
latency.append((self.inputs[num_in - 1].latency_offset, 1))
latency.append((input_latencies[k - 1], k / num_in))
latency.append((input_latencies[k], k / num_in))
latency.append((input_latencies[num_in - 1], 1))
num_out = self.output_count
latency.append((self.outputs[num_out - 1].latency_offset, 1))
latency.append((output_latencies[num_out - 1], 1))
for k in reversed(range(1, num_out)):
latency.append((self.outputs[k].latency_offset, k / num_out))
latency.append((self.outputs[k - 1].latency_offset, k / num_out))
latency.append((self.outputs[0].latency_offset, 0.0))
latency.append((output_latencies[k], k / num_out))
latency.append((output_latencies[k - 1], k / num_out))
latency.append((output_latencies[0], 0.0))
# Close the polygon
latency.append(start_point)
......@@ -1144,10 +1170,9 @@ class AbstractOperation(Operation, AbstractGraphComponent):
def get_input_coordinates(self) -> Tuple[Tuple[float, float], ...]:
# doc-string inherited
self._check_all_latencies_set()
return tuple(
(
self.inputs[k].latency_offset,
self.input_latency_offsets()[k],
(1 + 2 * k) / (2 * len(self.inputs)),
)
for k in range(len(self.inputs))
......@@ -1155,10 +1180,9 @@ class AbstractOperation(Operation, AbstractGraphComponent):
def get_output_coordinates(self) -> Tuple[Tuple[float, float], ...]:
# doc-string inherited
self._check_all_latencies_set()
return tuple(
(
self.outputs[k].latency_offset,
self.output_latency_offsets()[k],
(1 + 2 * k) / (2 * len(self.outputs)),
)
for k in range(len(self.outputs))
......
......@@ -131,7 +131,7 @@ class AbstractPort(Port):
return self._latency_offset
@latency_offset.setter
def latency_offset(self, latency_offset: int):
def latency_offset(self, latency_offset: Optional[int]):
self._latency_offset = latency_offset
......
......@@ -5,7 +5,6 @@ Contains operations with special purposes that may be treated differently from
normal operations in an SFG.
"""
from numbers import Number
from typing import List, Optional, Sequence, Tuple
from b_asic.graph_component import Name, TypeName
......@@ -16,6 +15,7 @@ from b_asic.operation import (
MutableResultMap,
)
from b_asic.port import SignalSourceProvider
from b_asic.types import Name, Num, TypeName
class Input(AbstractOperation):
......@@ -28,12 +28,12 @@ class Input(AbstractOperation):
_execution_time = 0
def __init__(self, name: Name = Name("")):
def __init__(self, name: Name = ""):
"""Construct an Input operation."""
super().__init__(
input_count=0,
output_count=1,
name=Name(name),
name=name,
latency_offsets={"out0": 0},
)
self.set_param("value", 0)
......@@ -46,12 +46,12 @@ class Input(AbstractOperation):
return self.param("value")
@property
def value(self) -> Number:
def value(self) -> Num:
"""Get the current value of this input."""
return self.param("value")
@value.setter
def value(self, value: Number) -> None:
def value(self, value: Num) -> None:
"""Set the current value of this input."""
self.set_param("value", value)
......@@ -152,7 +152,7 @@ class Delay(AbstractOperation):
def __init__(
self,
src0: Optional[SignalSourceProvider] = None,
initial_value: Number = 0,
initial_value: Num = 0,
name: Name = Name(""),
):
"""Construct a Delay operation."""
......@@ -173,7 +173,7 @@ class Delay(AbstractOperation):
def current_output(
self, index: int, delays: Optional[DelayMap] = None, prefix: str = ""
) -> Optional[Number]:
) -> Optional[Num]:
if delays is not None:
return delays.get(
self.key(index, prefix), self.param("initial_value")
......@@ -183,13 +183,13 @@ class Delay(AbstractOperation):
def evaluate_output(
self,
index: int,
input_values: Sequence[Number],
input_values: Sequence[Num],
results: Optional[MutableResultMap] = None,
delays: Optional[MutableDelayMap] = None,
prefix: str = "",
bits_override: Optional[int] = None,
truncate: bool = True,
) -> Number:
) -> Num:
if index != 0:
raise IndexError(
f"Output index out of range (expected 0-0, got {index})"
......@@ -214,11 +214,11 @@ class Delay(AbstractOperation):
return value
@property
def initial_value(self) -> Number:
def initial_value(self) -> Num:
"""Get the initial value of this delay."""
return self.param("initial_value")
@initial_value.setter
def initial_value(self, value: Number) -> None:
def initial_value(self, value: Num) -> None:
"""Set the initial value of this delay."""
self.set_param("initial_value", value)
from typing import NewType, Union
# https://stackoverflow.com/questions/69334475/how-to-hint-at-number-types-i-e-subclasses-of-number-not-numbers-themselv
Num = Union[int, float, complex]
NumRuntime = (complex, float, int)
Name = str
# # We want to be able to initialize Name with String literals, but still have the
# # benefit of static type checking that we don't pass non-names to name locations.
# # However, until python 3.11 a string literal type was not available. In those versions,
# # we'll fall back on just aliasing `str` => Name.
# if sys.version_info >= (3, 11):
# from typing import LiteralString
# Name: TypeAlias = NewType("Name", str) | LiteralString
# else:
# Name = str
TypeName = NewType("TypeName", str)
GraphID = NewType("GraphID", str)
GraphIDNumber = NewType("GraphIDNumber", int)
......@@ -318,7 +318,9 @@ class TestIOCoordinates:
bfly = Butterfly()
bfly.set_latency_offsets({"in0": 3, "out1": 5})
with pytest.raises(ValueError, match="All latencies must be set:"):
with pytest.raises(
ValueError, match="Missing latencies for inputs \\[1\\]"
):
bfly.get_io_coordinates()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment