Skip to content
Snippets Groups Projects
Commit 6733a627 authored by Oscar Gustafsson's avatar Oscar Gustafsson :bicyclist:
Browse files

Fix typing and add mypy config

parent 686e1f56
No related branches found
No related tags found
1 merge request!125Fix typing and add mypy config
Pipeline #88426 passed
......@@ -7,7 +7,7 @@ Contains the base for all components with an ID in a signal flow graph.
from abc import ABC, abstractmethod
from collections import deque
from copy import copy, deepcopy
from typing import Any, Dict, Generator, Iterable, Mapping, NewType
from typing import Any, Dict, Generator, Iterable, Mapping, NewType, cast
Name = NewType("Name", str)
TypeName = NewType("TypeName", str)
......@@ -176,6 +176,7 @@ class AbstractGraphComponent(GraphComponent):
component = fontier.popleft()
yield component
for neighbor in component.neighbors:
neighbor = cast(AbstractGraphComponent, neighbor)
if neighbor not in visited:
visited.add(neighbor)
fontier.append(neighbor)
......@@ -21,6 +21,7 @@ from typing import (
Sequence,
Tuple,
Union,
cast,
)
from b_asic.graph_component import (
......@@ -324,7 +325,7 @@ class Operation(GraphComponent, SignalSourceProvider):
@property
@abstractmethod
def latency_offsets(self) -> Dict[str, int]:
def latency_offsets(self) -> Dict[str, Optional[int]]:
"""
Get a dictionary with all the operations ports latency-offsets.
"""
......@@ -569,7 +570,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
def __str__(self) -> str:
"""Get a string representation of this operation."""
inputs_dict = {}
inputs_dict: Dict[int, Union[List[GraphID], str]] = {}
for i, inport in enumerate(self.inputs):
if inport.signal_count == 0:
inputs_dict[i] = "-"
......@@ -588,7 +589,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
dict_ele.append(GraphID("no_id"))
inputs_dict[i] = dict_ele
outputs_dict = {}
outputs_dict: Dict[int, Union[List[GraphID], str]] = {}
for i, outport in enumerate(self.outputs):
if outport.signal_count == 0:
outputs_dict[i] = "-"
......@@ -778,7 +779,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
last_operations = [last_operations]
outputs = [Output(o) for o in last_operations]
except TypeError:
operation_copy: Operation = self.copy_component()
operation_copy: Operation = cast(Operation, self.copy_component())
inputs = []
for i in range(self.input_count):
_input = Input()
......@@ -790,7 +791,9 @@ class AbstractOperation(Operation, AbstractGraphComponent):
return SFG(inputs=inputs, outputs=outputs)
def copy_component(self, *args, **kwargs) -> GraphComponent:
new_component: Operation = super().copy_component(*args, **kwargs)
new_component: Operation = cast(
Operation, super().copy_component(*args, **kwargs)
)
for i, inp in enumerate(self.inputs):
new_component.input(i).latency_offset = inp.latency_offset
for i, outp in enumerate(self.outputs):
......@@ -885,13 +888,16 @@ class AbstractOperation(Operation, AbstractGraphComponent):
return max(
(
(outp.latency_offset - inp.latency_offset)
(
cast(int, outp.latency_offset)
- cast(int, inp.latency_offset)
)
for outp, inp in it.product(self.outputs, self.inputs)
)
)
@property
def latency_offsets(self) -> Dict[str, int]:
def latency_offsets(self) -> Dict[str, Optional[int]]:
latency_offsets = {}
for i, inp in enumerate(self.inputs):
......@@ -948,13 +954,15 @@ class AbstractOperation(Operation, AbstractGraphComponent):
if self._execution_time is not None:
self._execution_time *= factor
for port in [*self.inputs, *self.outputs]:
port.latency_offset *= factor
if port.latency_offset is not None:
port.latency_offset *= factor
def _decrease_time_resolution(self, factor: int) -> None:
if self._execution_time is not None:
self._execution_time = self._execution_time // factor
for port in [*self.inputs, *self.outputs]:
port.latency_offset = port.latency_offset // factor
if port.latency_offset is not None:
port.latency_offset = port.latency_offset // factor
def get_plot_coordinates(
self,
......@@ -978,11 +986,18 @@ class AbstractOperation(Operation, AbstractGraphComponent):
[0, 0],
]
def _check_all_latencies_set(self):
if any(val is None for _, val in self.latency_offsets.items()):
raise ValueError(
f"All latencies must be set: {self.latency_offsets}"
)
def _get_plot_coordinates_for_latency(self) -> List[List[float]]:
self._check_all_latencies_set()
# Points for latency polygon
latency = []
# Remember starting point
start_point = [self.inputs[0].latency_offset, 0]
start_point = [self.inputs[0].latency_offset, 0.0]
num_in = self.input_count
latency.append(start_point)
for k in range(1, num_in):
......@@ -995,7 +1010,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
for k in reversed(range(1, num_out)):
latency.append([self.outputs[k].latency_offset, k / num_out])
latency.append([self.outputs[k - 1].latency_offset, k / num_out])
latency.append([self.outputs[0].latency_offset, 0])
latency.append([self.outputs[0].latency_offset, 0.0])
# Close the polygon
latency.append(start_point)
......@@ -1004,6 +1019,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
def get_io_coordinates(
self,
) -> Tuple[List[List[float]], List[List[float]]]:
self._check_all_latencies_set()
# Doc-string inherited
input_coords = [
[
......
......@@ -66,3 +66,9 @@ src_paths = ["b_asic", "test"]
skip = [
"test/test_gui"
]
[tool.mypy]
packages = ["b_asic", "test"]
no_site_packages = true
ignore_missing_imports = true
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment