Skip to content
Snippets Groups Projects
Commit 6733a627 authored by Oscar Gustafsson's avatar Oscar Gustafsson :bicyclist:
Browse files

Fix typing and add mypy config

parent 686e1f56
No related branches found
No related tags found
1 merge request!125Fix typing and add mypy config
Pipeline #88426 passed
...@@ -7,7 +7,7 @@ Contains the base for all components with an ID in a signal flow graph. ...@@ -7,7 +7,7 @@ Contains the base for all components with an ID in a signal flow graph.
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import deque from collections import deque
from copy import copy, deepcopy from copy import copy, deepcopy
from typing import Any, Dict, Generator, Iterable, Mapping, NewType from typing import Any, Dict, Generator, Iterable, Mapping, NewType, cast
Name = NewType("Name", str) Name = NewType("Name", str)
TypeName = NewType("TypeName", str) TypeName = NewType("TypeName", str)
...@@ -176,6 +176,7 @@ class AbstractGraphComponent(GraphComponent): ...@@ -176,6 +176,7 @@ class AbstractGraphComponent(GraphComponent):
component = fontier.popleft() component = fontier.popleft()
yield component yield component
for neighbor in component.neighbors: for neighbor in component.neighbors:
neighbor = cast(AbstractGraphComponent, neighbor)
if neighbor not in visited: if neighbor not in visited:
visited.add(neighbor) visited.add(neighbor)
fontier.append(neighbor) fontier.append(neighbor)
...@@ -21,6 +21,7 @@ from typing import ( ...@@ -21,6 +21,7 @@ from typing import (
Sequence, Sequence,
Tuple, Tuple,
Union, Union,
cast,
) )
from b_asic.graph_component import ( from b_asic.graph_component import (
...@@ -324,7 +325,7 @@ class Operation(GraphComponent, SignalSourceProvider): ...@@ -324,7 +325,7 @@ class Operation(GraphComponent, SignalSourceProvider):
@property @property
@abstractmethod @abstractmethod
def latency_offsets(self) -> Dict[str, int]: def latency_offsets(self) -> Dict[str, Optional[int]]:
""" """
Get a dictionary with all the operations ports latency-offsets. Get a dictionary with all the operations ports latency-offsets.
""" """
...@@ -569,7 +570,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -569,7 +570,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
def __str__(self) -> str: def __str__(self) -> str:
"""Get a string representation of this operation.""" """Get a string representation of this operation."""
inputs_dict = {} inputs_dict: Dict[int, Union[List[GraphID], str]] = {}
for i, inport in enumerate(self.inputs): for i, inport in enumerate(self.inputs):
if inport.signal_count == 0: if inport.signal_count == 0:
inputs_dict[i] = "-" inputs_dict[i] = "-"
...@@ -588,7 +589,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -588,7 +589,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
dict_ele.append(GraphID("no_id")) dict_ele.append(GraphID("no_id"))
inputs_dict[i] = dict_ele inputs_dict[i] = dict_ele
outputs_dict = {} outputs_dict: Dict[int, Union[List[GraphID], str]] = {}
for i, outport in enumerate(self.outputs): for i, outport in enumerate(self.outputs):
if outport.signal_count == 0: if outport.signal_count == 0:
outputs_dict[i] = "-" outputs_dict[i] = "-"
...@@ -778,7 +779,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -778,7 +779,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
last_operations = [last_operations] last_operations = [last_operations]
outputs = [Output(o) for o in last_operations] outputs = [Output(o) for o in last_operations]
except TypeError: except TypeError:
operation_copy: Operation = self.copy_component() operation_copy: Operation = cast(Operation, self.copy_component())
inputs = [] inputs = []
for i in range(self.input_count): for i in range(self.input_count):
_input = Input() _input = Input()
...@@ -790,7 +791,9 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -790,7 +791,9 @@ class AbstractOperation(Operation, AbstractGraphComponent):
return SFG(inputs=inputs, outputs=outputs) return SFG(inputs=inputs, outputs=outputs)
def copy_component(self, *args, **kwargs) -> GraphComponent: def copy_component(self, *args, **kwargs) -> GraphComponent:
new_component: Operation = super().copy_component(*args, **kwargs) new_component: Operation = cast(
Operation, super().copy_component(*args, **kwargs)
)
for i, inp in enumerate(self.inputs): for i, inp in enumerate(self.inputs):
new_component.input(i).latency_offset = inp.latency_offset new_component.input(i).latency_offset = inp.latency_offset
for i, outp in enumerate(self.outputs): for i, outp in enumerate(self.outputs):
...@@ -885,13 +888,16 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -885,13 +888,16 @@ class AbstractOperation(Operation, AbstractGraphComponent):
return max( return max(
( (
(outp.latency_offset - inp.latency_offset) (
cast(int, outp.latency_offset)
- cast(int, inp.latency_offset)
)
for outp, inp in it.product(self.outputs, self.inputs) for outp, inp in it.product(self.outputs, self.inputs)
) )
) )
@property @property
def latency_offsets(self) -> Dict[str, int]: def latency_offsets(self) -> Dict[str, Optional[int]]:
latency_offsets = {} latency_offsets = {}
for i, inp in enumerate(self.inputs): for i, inp in enumerate(self.inputs):
...@@ -948,13 +954,15 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -948,13 +954,15 @@ class AbstractOperation(Operation, AbstractGraphComponent):
if self._execution_time is not None: if self._execution_time is not None:
self._execution_time *= factor self._execution_time *= factor
for port in [*self.inputs, *self.outputs]: for port in [*self.inputs, *self.outputs]:
port.latency_offset *= factor if port.latency_offset is not None:
port.latency_offset *= factor
def _decrease_time_resolution(self, factor: int) -> None: def _decrease_time_resolution(self, factor: int) -> None:
if self._execution_time is not None: if self._execution_time is not None:
self._execution_time = self._execution_time // factor self._execution_time = self._execution_time // factor
for port in [*self.inputs, *self.outputs]: for port in [*self.inputs, *self.outputs]:
port.latency_offset = port.latency_offset // factor if port.latency_offset is not None:
port.latency_offset = port.latency_offset // factor
def get_plot_coordinates( def get_plot_coordinates(
self, self,
...@@ -978,11 +986,18 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -978,11 +986,18 @@ class AbstractOperation(Operation, AbstractGraphComponent):
[0, 0], [0, 0],
] ]
def _check_all_latencies_set(self):
if any(val is None for _, val in self.latency_offsets.items()):
raise ValueError(
f"All latencies must be set: {self.latency_offsets}"
)
def _get_plot_coordinates_for_latency(self) -> List[List[float]]: def _get_plot_coordinates_for_latency(self) -> List[List[float]]:
self._check_all_latencies_set()
# Points for latency polygon # Points for latency polygon
latency = [] latency = []
# Remember starting point # Remember starting point
start_point = [self.inputs[0].latency_offset, 0] start_point = [self.inputs[0].latency_offset, 0.0]
num_in = self.input_count num_in = self.input_count
latency.append(start_point) latency.append(start_point)
for k in range(1, num_in): for k in range(1, num_in):
...@@ -995,7 +1010,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -995,7 +1010,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
for k in reversed(range(1, num_out)): for k in reversed(range(1, num_out)):
latency.append([self.outputs[k].latency_offset, k / num_out]) latency.append([self.outputs[k].latency_offset, k / num_out])
latency.append([self.outputs[k - 1].latency_offset, k / num_out]) latency.append([self.outputs[k - 1].latency_offset, k / num_out])
latency.append([self.outputs[0].latency_offset, 0]) latency.append([self.outputs[0].latency_offset, 0.0])
# Close the polygon # Close the polygon
latency.append(start_point) latency.append(start_point)
...@@ -1004,6 +1019,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -1004,6 +1019,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
def get_io_coordinates( def get_io_coordinates(
self, self,
) -> Tuple[List[List[float]], List[List[float]]]: ) -> Tuple[List[List[float]], List[List[float]]]:
self._check_all_latencies_set()
# Doc-string inherited # Doc-string inherited
input_coords = [ input_coords = [
[ [
......
...@@ -66,3 +66,9 @@ src_paths = ["b_asic", "test"] ...@@ -66,3 +66,9 @@ src_paths = ["b_asic", "test"]
skip = [ skip = [
"test/test_gui" "test/test_gui"
] ]
[tool.mypy]
packages = ["b_asic", "test"]
no_site_packages = true
ignore_missing_imports = true
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment