Skip to content
Snippets Groups Projects
Commit 04f57ca9 authored by Oscar Gustafsson's avatar Oscar Gustafsson :bicyclist:
Browse files

Typing and general code cleanup

parent 53853ed7
No related branches found
No related tags found
No related merge requests found
Pipeline #89854 failed
......@@ -258,12 +258,12 @@ class AddSub(AbstractOperation):
@property
def is_add(self) -> Num:
"""Get if operation is add."""
"""Get if operation is an addition."""
return self.param("is_add")
@is_add.setter
def is_add(self, is_add: bool) -> None:
"""Set if operation is add."""
"""Set if operation is an addition."""
self.set_param("is_add", is_add)
......@@ -774,7 +774,7 @@ class Reciprocal(AbstractOperation):
latency_offsets: Optional[Dict[str, int]] = None,
execution_time: Optional[int] = None,
):
"""Construct an Reciprocal operation."""
"""Construct a Reciprocal operation."""
super().__init__(
input_count=1,
output_count=1,
......
......@@ -9,7 +9,7 @@ from collections import deque
from copy import copy, deepcopy
from typing import Any, Dict, Generator, Iterable, Mapping, cast
from b_asic.types import GraphID, GraphIDNumber, Name, Num, TypeName
from b_asic.types import GraphID, Name, TypeName
class GraphComponent(ABC):
......
......@@ -11,7 +11,6 @@ from abc import abstractmethod
from numbers import Number
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
......@@ -34,7 +33,7 @@ from b_asic.graph_component import (
)
from b_asic.port import InputPort, OutputPort, SignalSourceProvider
from b_asic.signal import Signal
from b_asic.types import Num, NumRuntime
from b_asic.types import Num
if TYPE_CHECKING:
# Conditionally imported to avoid circular imports
......@@ -593,7 +592,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
# Import here to avoid circular imports.
from b_asic.core_operations import Addition, Constant
if isinstance(src, NumRuntime):
if isinstance(src, Number):
return Addition(self, Constant(src))
else:
return Addition(self, src)
......@@ -603,7 +602,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
from b_asic.core_operations import Addition, Constant
return Addition(
Constant(src) if isinstance(src, NumRuntime) else src, self
Constant(src) if isinstance(src, Number) else src, self
)
def __sub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction":
......@@ -611,7 +610,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
from b_asic.core_operations import Constant, Subtraction
return Subtraction(
self, Constant(src) if isinstance(src, NumRuntime) else src
self, Constant(src) if isinstance(src, Number) else src
)
def __rsub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction":
......@@ -619,7 +618,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
from b_asic.core_operations import Constant, Subtraction
return Subtraction(
Constant(src) if isinstance(src, NumRuntime) else src, self
Constant(src) if isinstance(src, Number) else src, self
)
def __mul__(
......@@ -633,7 +632,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
return (
ConstantMultiplication(src, self)
if isinstance(src, NumRuntime)
if isinstance(src, Number)
else Multiplication(self, src)
)
......@@ -648,7 +647,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
return (
ConstantMultiplication(src, self)
if isinstance(src, NumRuntime)
if isinstance(src, Number)
else Multiplication(src, self)
)
......@@ -657,7 +656,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
from b_asic.core_operations import Constant, Division
return Division(
self, Constant(src) if isinstance(src, NumRuntime) else src
self, Constant(src) if isinstance(src, Number) else src
)
def __rtruediv__(
......@@ -666,7 +665,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
# Import here to avoid circular imports.
from b_asic.core_operations import Constant, Division, Reciprocal
if isinstance(src, NumRuntime):
if isinstance(src, Number):
if src == 1:
return Reciprocal(self)
else:
......
......@@ -35,7 +35,7 @@ def draw_exclusion_graph_coloring(
] = None,
):
"""
Use matplotlib.pyplot and networkx to draw a colored exclusion graph from the memory assigment
Use matplotlib.pyplot and networkx to draw a colored exclusion graph from the memory assignment
.. code-block:: python
......@@ -77,7 +77,6 @@ def draw_exclusion_graph_coloring(
'#aa00aa',
'#00aaaa',
]
node_color_dict = {}
if color_list is None:
node_color_dict = {k: COLOR_LIST[v] for k, v in color_dict.items()}
else:
......@@ -137,8 +136,8 @@ class ProcessCollection:
Parameters
----------
ax : :class:`matplotlib.axes.Axes`, optional
Matplotlib Axes object to draw this lifetime chart onto. If not provided (i.e., set to None), this method will
return a new axes object on return.
Matplotlib Axes object to draw this lifetime chart onto. If not provided (i.e., set to None),
this method will return a new axes object on return.
show_name : bool, default: True
Show name of all processes in the lifetime chart.
......@@ -147,7 +146,7 @@ class ProcessCollection:
ax: Associated Matplotlib Axes (or array of Axes) object
"""
# Setup the Axes object
# Set up the Axes object
if ax is None:
_, _ax = plt.subplots()
else:
......@@ -159,7 +158,7 @@ class ProcessCollection:
process.execution_time for process in self._collection
)
if max_execution_time > self._schedule_time:
# Schedule time needs to be greater than or equal to the maximum process life time
# Schedule time needs to be greater than or equal to the maximum process lifetime
raise KeyError(
f'Error: Schedule time: {self._schedule_time} < Max execution'
f' time: {max_execution_time}'
......@@ -222,7 +221,7 @@ class ProcessCollection:
self, add_name: bool = True
) -> nx.Graph:
"""
Generate exclusion graph based on processes overlaping in time
Generate exclusion graph based on processes overlapping in time
Parameters
----------
......@@ -270,20 +269,20 @@ class ProcessCollection:
Parameters
----------
heuristic : str, default: "graph_color"
The heuristic used when spliting this ProcessCollection.
The heuristic used when splitting this ProcessCollection.
Valid options are:
* "graph_color"
* "..."
read_ports : int, optional
The number of read ports used when spliting process collection based on memory variable access.
The number of read ports used when splitting process collection based on memory variable access.
write_ports : int, optional
The number of write ports used when spliting process collection based on memory variable access.
The number of write ports used when splitting process collection based on memory variable access.
total_ports : int, optional
The total number of ports used when spliting process collection based on memory variable access.
The total number of ports used when splitting process collection based on memory variable access.
Returns
-------
A set of new ProcessColleciton objects with the process spliting.
A set of new ProcessCollection objects with the process splitting.
"""
if total_ports is None:
if read_ports is None or write_ports is None:
......@@ -308,20 +307,20 @@ class ProcessCollection:
Parameters
----------
read_ports : int, optional
The number of read ports used when spliting process collection based on memory variable access.
The number of read ports used when splitting process collection based on memory variable access.
write_ports : int, optional
The number of write ports used when spliting process collection based on memory variable access.
The number of write ports used when splitting process collection based on memory variable access.
total_ports : int, optional
The total number of ports used when spliting process collection based on memory variable access.
The total number of ports used when splitting process collection based on memory variable access.
"""
if read_ports != 1 or write_ports != 1:
raise ValueError(
"Spliting with read and write ports not equal to one with the"
"Splitting with read and write ports not equal to one with the"
" graph coloring heuristic does not make sense."
)
if total_ports not in (1, 2):
raise ValueError(
"Total ports should be either 1 (non-concurent reads/writes)"
"Total ports should be either 1 (non-concurrent reads/writes)"
" or 2 (concurrent read/writes) for graph coloring heuristic."
)
......
......@@ -603,7 +603,7 @@ class Schedule:
def _get_y_position(
self, graph_id, operation_height=1.0, operation_gap=None
):
) -> float:
if operation_gap is None:
operation_gap = OPERATION_GAP
y_location = self._y_locations[graph_id]
......@@ -617,11 +617,15 @@ class Schedule:
self._y_locations[graph_id] = y_location
return operation_gap + y_location * (operation_height + operation_gap)
def _plot_schedule(self, ax, operation_gap: Optional[float] = None):
def _plot_schedule(
self, ax: Axes, operation_gap: Optional[float] = None
) -> None:
"""Draw the schedule."""
line_cache = []
def _draw_arrow(start, end, name="", laps=0):
def _draw_arrow(
start: List[float], end: List[float], name: str = "", laps: int = 0
):
"""Draw an arrow from *start* to *end*."""
if end[0] < start[0] or laps > 0: # Wrap around
if start not in line_cache:
......@@ -848,7 +852,7 @@ class Schedule:
self._plot_schedule(ax, operation_gap=operation_gap)
return fig
def _repr_svg_(self):
def _repr_svg_(self) -> str:
"""
Generate an SVG of the schedule. This is automatically displayed in e.g.
Jupyter Qt console.
......
......@@ -72,9 +72,8 @@ def wdf_allpass(
odd_order = order % 2
if odd_order:
# First-order section
coeff = np_coefficients[0]
adaptor0 = SymmetricTwoportAdaptor(
coeff,
np_coefficients[0],
input_op,
latency=latency,
latency_offsets=latency_offsets,
......@@ -185,10 +184,11 @@ def direct_form_fir(
prev_add = None
for i, coeff in enumerate(np_coefficients):
tmp_mul = ConstantMultiplication(coeff, prev_delay, **mult_properties)
if prev_add is None:
prev_add = tmp_mul
else:
prev_add = Addition(tmp_mul, prev_add, **add_properties)
prev_add = (
tmp_mul
if prev_add is None
else Addition(tmp_mul, prev_add, **add_properties)
)
if i < taps - 1:
prev_delay = Delay(prev_delay)
......@@ -266,10 +266,11 @@ def transposed_direct_form_fir(
prev_add = None
for i, coeff in enumerate(reversed(np_coefficients)):
tmp_mul = ConstantMultiplication(coeff, input_op, **mult_properties)
if prev_delay is None:
tmp_add = tmp_mul
else:
tmp_add = Addition(tmp_mul, prev_delay, **add_properties)
tmp_add = (
tmp_mul
if prev_delay is None
else Addition(tmp_mul, prev_delay, **add_properties)
)
if i < taps - 1:
prev_delay = Delay(tmp_add)
......
......@@ -27,13 +27,7 @@ from typing import (
from graphviz import Digraph
from b_asic.graph_component import (
GraphComponent,
GraphID,
GraphIDNumber,
Name,
TypeName,
)
from b_asic.graph_component import GraphComponent
from b_asic.operation import (
AbstractOperation,
MutableDelayMap,
......@@ -44,6 +38,7 @@ from b_asic.operation import (
from b_asic.port import InputPort, OutputPort, SignalSourceProvider
from b_asic.signal import Signal
from b_asic.special_operations import Delay, Input, Output
from b_asic.types import GraphID, GraphIDNumber, Name, Num, TypeName
DelayQueue = List[Tuple[str, ResultKey, OutputPort]]
......@@ -377,7 +372,7 @@ class SFG(AbstractOperation):
def evaluate_output(
self,
index: int,
input_values: Sequence[Number],
input_values: Sequence[Num],
results: Optional[MutableResultMap] = None,
delays: Optional[MutableDelayMap] = None,
prefix: str = "",
......@@ -465,11 +460,11 @@ class SFG(AbstractOperation):
for input_port, input_operation in zip(
self.inputs, self.input_operations
):
dest = input_operation.output(0).signals[0].destination
if dest is None:
destination = input_operation.output(0).signals[0].destination
if destination is None:
raise ValueError("Missing destination in signal.")
dest.clear()
input_port.signals[0].set_destination(dest)
destination.clear()
input_port.signals[0].set_destination(destination)
# For each output_signal, connect it to the corresponding operation
for output_port, output_operation in zip(
self.outputs, self.output_operations
......@@ -825,12 +820,12 @@ class SFG(AbstractOperation):
with pg.subgraph(name=f"cluster_{i}") as sub:
sub.attr(label=f"N{i}")
for port in ports:
portstr = f"{port.operation.graph_id}.{port.index}"
port_string = f"{port.operation.graph_id}.{port.index}"
if port.operation.output_count > 1:
sub.node(portstr)
sub.node(port_string)
else:
sub.node(
portstr,
port_string,
shape='rectangle',
label=port.operation.graph_id,
height="0.1",
......@@ -841,28 +836,29 @@ class SFG(AbstractOperation):
for i in range(len(p_list)):
ports = p_list[i]
for port in ports:
source_label = port.operation.graph_id
node_node = f"{source_label}.{port.index}"
for signal in port.signals:
destination = cast(InputPort, signal.destination)
if isinstance(destination.operation, Delay):
dest_node = destination.operation.graph_id + "In"
else:
dest_node = destination.operation.graph_id
dest_label = destination.operation.graph_id
node_node = f"{port.operation.graph_id}.{port.index}"
pg.edge(node_node, dest_node)
destination_label = destination.operation.graph_id
destination_node = (
destination_label + "In"
if isinstance(destination.operation, Delay)
else destination_label
)
pg.edge(node_node, destination_node)
pg.node(
dest_node,
label=dest_label,
destination_node,
label=destination_label,
shape=_OPERATION_SHAPE[
destination.operation.type_name()
],
)
if port.operation.type_name() == Delay.type_name():
source_node = port.operation.graph_id + "Out"
else:
source_node = port.operation.graph_id
source_label = port.operation.graph_id
node_node = f"{port.operation.graph_id}.{port.index}"
source_node = (
source_label + "Out"
if port.operation.type_name() == Delay.type_name()
else source_label
)
pg.edge(source_node, node_node)
pg.node(
source_node,
......@@ -887,8 +883,8 @@ class SFG(AbstractOperation):
printed_ops = set()
for iter_num, iter in enumerate(precedence_list, start=1):
for outport_num, outport in enumerate(iter, start=1):
for iter_num, iterable in enumerate(precedence_list, start=1):
for outport_num, outport in enumerate(iterable, start=1):
if outport not in printed_ops:
# Only print once per operation, even if it has multiple outports
out_str.write("\n")
......@@ -1160,7 +1156,7 @@ class SFG(AbstractOperation):
self._components_dfs_order.extend(
[new_signal, source.operation]
)
if not source.operation in self._operations_dfs_order:
if source.operation not in self._operations_dfs_order:
self._operations_dfs_order.append(source.operation)
# Check if the signal has not been added before.
......@@ -1331,7 +1327,7 @@ class SFG(AbstractOperation):
bits_override: Optional[int],
truncate: bool,
deferred_delays: DelayQueue,
) -> Number:
) -> Num:
key_base = (
(prefix + "." + src.operation.graph_id)
if prefix
......@@ -1376,7 +1372,7 @@ class SFG(AbstractOperation):
bits_override: Optional[int],
truncate: bool,
deferred_delays: DelayQueue,
) -> Number:
) -> Num:
input_values = [
self._evaluate_source(
input_port.signals[0].source,
......@@ -1458,13 +1454,13 @@ class SFG(AbstractOperation):
"image/png"
]
def show(self, format=None, show_id=False, engine=None) -> None:
def show(self, fmt=None, show_id=False, engine=None) -> None:
"""
Shows a visual representation of the SFG using the default system viewer.
Parameters
----------
format : string, optional
fmt : string, optional
File format of the generated graph. Output formats can be found at
https://www.graphviz.org/doc/info/output.html
Most common are "pdf", "eps", "png", and "svg". Default is None which
......@@ -1481,8 +1477,8 @@ class SFG(AbstractOperation):
dg = self.sfg_digraph(show_id=show_id)
if engine is not None:
dg.engine = engine
if format is not None:
dg.format = format
if fmt is not None:
dg.format = fmt
dg.view()
def critical_path(self):
......
......@@ -8,6 +8,7 @@ from collections import defaultdict
from numbers import Number
from typing import (
Callable,
List,
Mapping,
MutableMapping,
MutableSequence,
......@@ -20,11 +21,12 @@ import numpy as np
from b_asic.operation import MutableDelayMap, ResultKey
from b_asic.signal_flow_graph import SFG
from b_asic.types import Num
ResultArrayMap = Mapping[ResultKey, Sequence[Number]]
MutableResultArrayMap = MutableMapping[ResultKey, MutableSequence[Number]]
InputFunction = Callable[[int], Number]
InputProvider = Union[Number, Sequence[Number], InputFunction]
ResultArrayMap = Mapping[ResultKey, Sequence[Num]]
MutableResultArrayMap = MutableMapping[ResultKey, MutableSequence[Num]]
InputFunction = Callable[[int], Num]
InputProvider = Union[Num, Sequence[Num], InputFunction]
class Simulation:
......@@ -39,7 +41,7 @@ class Simulation:
_results: MutableResultArrayMap
_delays: MutableDelayMap
_iteration: int
_input_functions: Sequence[InputFunction]
_input_functions: List[InputFunction]
_input_length: Optional[int]
def __init__(
......@@ -105,7 +107,7 @@ class Simulation:
save_results: bool = True,
bits_override: Optional[int] = None,
truncate: bool = True,
) -> Sequence[Number]:
) -> Sequence[Num]:
"""
Run one iteration of the simulation and return the resulting output values.
"""
......@@ -117,12 +119,12 @@ class Simulation:
save_results: bool = True,
bits_override: Optional[int] = None,
truncate: bool = True,
) -> Sequence[Number]:
) -> Sequence[Num]:
"""
Run the simulation until its iteration is greater than or equal to the given
iteration and return the output values of the last iteration.
"""
result: Sequence[Number] = []
result: Sequence[Num] = []
while self._iteration < iteration:
input_values = [
self._input_functions[i](self._iteration)
......@@ -149,7 +151,7 @@ class Simulation:
save_results: bool = True,
bits_override: Optional[int] = None,
truncate: bool = True,
) -> Sequence[Number]:
) -> Sequence[Num]:
"""
Run a given number of iterations of the simulation and return the output
values of the last iteration.
......@@ -163,7 +165,7 @@ class Simulation:
save_results: bool = True,
bits_override: Optional[int] = None,
truncate: bool = True,
) -> Sequence[Number]:
) -> Sequence[Num]:
"""
Run the simulation until the end of its input arrays and return the output
values of the last iteration.
......
......@@ -5,9 +5,8 @@ Contains operations with special purposes that may be treated differently from
normal operations in an SFG.
"""
from typing import List, Optional, Sequence, Tuple
from typing import Optional, Sequence, Tuple
from b_asic.graph_component import Name, TypeName
from b_asic.operation import (
AbstractOperation,
DelayMap,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment