From fc3fa7649e441b009e8ee4888db919315de3600a Mon Sep 17 00:00:00 2001 From: Oscar Gustafsson <oscar.gustafsson@gmail.com> Date: Thu, 16 Feb 2023 19:36:15 +0100 Subject: [PATCH] Typing and general code cleanup --- b_asic/core_operations.py | 6 +-- b_asic/graph_component.py | 2 +- b_asic/operation.py | 19 +++++---- b_asic/resources.py | 33 ++++++++-------- b_asic/schedule.py | 12 ++++-- b_asic/sfg_generators.py | 21 +++++----- b_asic/signal_flow_graph.py | 74 +++++++++++++++++------------------- b_asic/simulation.py | 22 ++++++----- b_asic/special_operations.py | 3 +- test/test_sfg.py | 2 +- 10 files changed, 97 insertions(+), 97 deletions(-) diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index aa9e33d4..d67b0242 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -258,12 +258,12 @@ class AddSub(AbstractOperation): @property def is_add(self) -> Num: - """Get if operation is add.""" + """Get if operation is an addition.""" return self.param("is_add") @is_add.setter def is_add(self, is_add: bool) -> None: - """Set if operation is add.""" + """Set if operation is an addition.""" self.set_param("is_add", is_add) @@ -774,7 +774,7 @@ class Reciprocal(AbstractOperation): latency_offsets: Optional[Dict[str, int]] = None, execution_time: Optional[int] = None, ): - """Construct an Reciprocal operation.""" + """Construct a Reciprocal operation.""" super().__init__( input_count=1, output_count=1, diff --git a/b_asic/graph_component.py b/b_asic/graph_component.py index 1f910c30..f4725892 100644 --- a/b_asic/graph_component.py +++ b/b_asic/graph_component.py @@ -9,7 +9,7 @@ from collections import deque from copy import copy, deepcopy from typing import Any, Dict, Generator, Iterable, Mapping, cast -from b_asic.types import GraphID, GraphIDNumber, Name, Num, TypeName +from b_asic.types import GraphID, Name, TypeName class GraphComponent(ABC): diff --git a/b_asic/operation.py b/b_asic/operation.py index 8b8251dc..6548cb45 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -11,7 +11,6 @@ from abc import abstractmethod from numbers import Number from typing import ( TYPE_CHECKING, - Any, Dict, Iterable, List, @@ -34,7 +33,7 @@ from b_asic.graph_component import ( ) from b_asic.port import InputPort, OutputPort, SignalSourceProvider from b_asic.signal import Signal -from b_asic.types import Num, NumRuntime +from b_asic.types import Num if TYPE_CHECKING: # Conditionally imported to avoid circular imports @@ -593,7 +592,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): # Import here to avoid circular imports. from b_asic.core_operations import Addition, Constant - if isinstance(src, NumRuntime): + if isinstance(src, Number): return Addition(self, Constant(src)) else: return Addition(self, src) @@ -603,7 +602,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): from b_asic.core_operations import Addition, Constant return Addition( - Constant(src) if isinstance(src, NumRuntime) else src, self + Constant(src) if isinstance(src, Number) else src, self ) def __sub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction": @@ -611,7 +610,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): from b_asic.core_operations import Constant, Subtraction return Subtraction( - self, Constant(src) if isinstance(src, NumRuntime) else src + self, Constant(src) if isinstance(src, Number) else src ) def __rsub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction": @@ -619,7 +618,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): from b_asic.core_operations import Constant, Subtraction return Subtraction( - Constant(src) if isinstance(src, NumRuntime) else src, self + Constant(src) if isinstance(src, Number) else src, self ) def __mul__( @@ -633,7 +632,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): return ( ConstantMultiplication(src, self) - if isinstance(src, NumRuntime) + if isinstance(src, Number) else Multiplication(self, src) ) @@ -648,7 +647,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): return ( ConstantMultiplication(src, self) - if isinstance(src, NumRuntime) + if isinstance(src, Number) else Multiplication(src, self) ) @@ -657,7 +656,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): from b_asic.core_operations import Constant, Division return Division( - self, Constant(src) if isinstance(src, NumRuntime) else src + self, Constant(src) if isinstance(src, Number) else src ) def __rtruediv__( @@ -666,7 +665,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): # Import here to avoid circular imports. from b_asic.core_operations import Constant, Division, Reciprocal - if isinstance(src, NumRuntime): + if isinstance(src, Number): if src == 1: return Reciprocal(self) else: diff --git a/b_asic/resources.py b/b_asic/resources.py index 6ad988ff..667eba29 100644 --- a/b_asic/resources.py +++ b/b_asic/resources.py @@ -35,7 +35,7 @@ def draw_exclusion_graph_coloring( ] = None, ): """ - Use matplotlib.pyplot and networkx to draw a colored exclusion graph from the memory assigment + Use matplotlib.pyplot and networkx to draw a colored exclusion graph from the memory assignment .. code-block:: python @@ -77,7 +77,6 @@ def draw_exclusion_graph_coloring( '#aa00aa', '#00aaaa', ] - node_color_dict = {} if color_list is None: node_color_dict = {k: COLOR_LIST[v] for k, v in color_dict.items()} else: @@ -137,8 +136,8 @@ class ProcessCollection: Parameters ---------- ax : :class:`matplotlib.axes.Axes`, optional - Matplotlib Axes object to draw this lifetime chart onto. If not provided (i.e., set to None), this method will - return a new axes object on return. + Matplotlib Axes object to draw this lifetime chart onto. If not provided (i.e., set to None), + this method will return a new axes object on return. show_name : bool, default: True Show name of all processes in the lifetime chart. @@ -147,7 +146,7 @@ class ProcessCollection: ax: Associated Matplotlib Axes (or array of Axes) object """ - # Setup the Axes object + # Set up the Axes object if ax is None: _, _ax = plt.subplots() else: @@ -159,7 +158,7 @@ class ProcessCollection: process.execution_time for process in self._collection ) if max_execution_time > self._schedule_time: - # Schedule time needs to be greater than or equal to the maximum process life time + # Schedule time needs to be greater than or equal to the maximum process lifetime raise KeyError( f'Error: Schedule time: {self._schedule_time} < Max execution' f' time: {max_execution_time}' @@ -222,7 +221,7 @@ class ProcessCollection: self, add_name: bool = True ) -> nx.Graph: """ - Generate exclusion graph based on processes overlaping in time + Generate exclusion graph based on processes overlapping in time Parameters ---------- @@ -270,20 +269,20 @@ class ProcessCollection: Parameters ---------- heuristic : str, default: "graph_color" - The heuristic used when spliting this ProcessCollection. + The heuristic used when splitting this ProcessCollection. Valid options are: * "graph_color" * "..." read_ports : int, optional - The number of read ports used when spliting process collection based on memory variable access. + The number of read ports used when splitting process collection based on memory variable access. write_ports : int, optional - The number of write ports used when spliting process collection based on memory variable access. + The number of write ports used when splitting process collection based on memory variable access. total_ports : int, optional - The total number of ports used when spliting process collection based on memory variable access. + The total number of ports used when splitting process collection based on memory variable access. Returns ------- - A set of new ProcessColleciton objects with the process spliting. + A set of new ProcessCollection objects with the process splitting. """ if total_ports is None: if read_ports is None or write_ports is None: @@ -308,20 +307,20 @@ class ProcessCollection: Parameters ---------- read_ports : int, optional - The number of read ports used when spliting process collection based on memory variable access. + The number of read ports used when splitting process collection based on memory variable access. write_ports : int, optional - The number of write ports used when spliting process collection based on memory variable access. + The number of write ports used when splitting process collection based on memory variable access. total_ports : int, optional - The total number of ports used when spliting process collection based on memory variable access. + The total number of ports used when splitting process collection based on memory variable access. """ if read_ports != 1 or write_ports != 1: raise ValueError( - "Spliting with read and write ports not equal to one with the" + "Splitting with read and write ports not equal to one with the" " graph coloring heuristic does not make sense." ) if total_ports not in (1, 2): raise ValueError( - "Total ports should be either 1 (non-concurent reads/writes)" + "Total ports should be either 1 (non-concurrent reads/writes)" " or 2 (concurrent read/writes) for graph coloring heuristic." ) diff --git a/b_asic/schedule.py b/b_asic/schedule.py index 575759b2..bab48e36 100644 --- a/b_asic/schedule.py +++ b/b_asic/schedule.py @@ -603,7 +603,7 @@ class Schedule: def _get_y_position( self, graph_id, operation_height=1.0, operation_gap=None - ): + ) -> float: if operation_gap is None: operation_gap = OPERATION_GAP y_location = self._y_locations[graph_id] @@ -617,11 +617,15 @@ class Schedule: self._y_locations[graph_id] = y_location return operation_gap + y_location * (operation_height + operation_gap) - def _plot_schedule(self, ax, operation_gap: Optional[float] = None): + def _plot_schedule( + self, ax: Axes, operation_gap: Optional[float] = None + ) -> None: """Draw the schedule.""" line_cache = [] - def _draw_arrow(start, end, name="", laps=0): + def _draw_arrow( + start: List[float], end: List[float], name: str = "", laps: int = 0 + ): """Draw an arrow from *start* to *end*.""" if end[0] < start[0] or laps > 0: # Wrap around if start not in line_cache: @@ -848,7 +852,7 @@ class Schedule: self._plot_schedule(ax, operation_gap=operation_gap) return fig - def _repr_svg_(self): + def _repr_svg_(self) -> str: """ Generate an SVG of the schedule. This is automatically displayed in e.g. Jupyter Qt console. diff --git a/b_asic/sfg_generators.py b/b_asic/sfg_generators.py index 48edf226..37d011f8 100644 --- a/b_asic/sfg_generators.py +++ b/b_asic/sfg_generators.py @@ -72,9 +72,8 @@ def wdf_allpass( odd_order = order % 2 if odd_order: # First-order section - coeff = np_coefficients[0] adaptor0 = SymmetricTwoportAdaptor( - coeff, + np_coefficients[0], input_op, latency=latency, latency_offsets=latency_offsets, @@ -185,10 +184,11 @@ def direct_form_fir( prev_add = None for i, coeff in enumerate(np_coefficients): tmp_mul = ConstantMultiplication(coeff, prev_delay, **mult_properties) - if prev_add is None: - prev_add = tmp_mul - else: - prev_add = Addition(tmp_mul, prev_add, **add_properties) + prev_add = ( + tmp_mul + if prev_add is None + else Addition(tmp_mul, prev_add, **add_properties) + ) if i < taps - 1: prev_delay = Delay(prev_delay) @@ -266,10 +266,11 @@ def transposed_direct_form_fir( prev_add = None for i, coeff in enumerate(reversed(np_coefficients)): tmp_mul = ConstantMultiplication(coeff, input_op, **mult_properties) - if prev_delay is None: - tmp_add = tmp_mul - else: - tmp_add = Addition(tmp_mul, prev_delay, **add_properties) + tmp_add = ( + tmp_mul + if prev_delay is None + else Addition(tmp_mul, prev_delay, **add_properties) + ) if i < taps - 1: prev_delay = Delay(tmp_add) diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 9e0c33f3..417baa24 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -27,13 +27,7 @@ from typing import ( from graphviz import Digraph -from b_asic.graph_component import ( - GraphComponent, - GraphID, - GraphIDNumber, - Name, - TypeName, -) +from b_asic.graph_component import GraphComponent from b_asic.operation import ( AbstractOperation, MutableDelayMap, @@ -44,6 +38,7 @@ from b_asic.operation import ( from b_asic.port import InputPort, OutputPort, SignalSourceProvider from b_asic.signal import Signal from b_asic.special_operations import Delay, Input, Output +from b_asic.types import GraphID, GraphIDNumber, Name, Num, TypeName DelayQueue = List[Tuple[str, ResultKey, OutputPort]] @@ -377,7 +372,7 @@ class SFG(AbstractOperation): def evaluate_output( self, index: int, - input_values: Sequence[Number], + input_values: Sequence[Num], results: Optional[MutableResultMap] = None, delays: Optional[MutableDelayMap] = None, prefix: str = "", @@ -465,11 +460,11 @@ class SFG(AbstractOperation): for input_port, input_operation in zip( self.inputs, self.input_operations ): - dest = input_operation.output(0).signals[0].destination - if dest is None: + destination = input_operation.output(0).signals[0].destination + if destination is None: raise ValueError("Missing destination in signal.") - dest.clear() - input_port.signals[0].set_destination(dest) + destination.clear() + input_port.signals[0].set_destination(destination) # For each output_signal, connect it to the corresponding operation for output_port, output_operation in zip( self.outputs, self.output_operations @@ -825,12 +820,12 @@ class SFG(AbstractOperation): with pg.subgraph(name=f"cluster_{i}") as sub: sub.attr(label=f"N{i}") for port in ports: - portstr = f"{port.operation.graph_id}.{port.index}" + port_string = f"{port.operation.graph_id}.{port.index}" if port.operation.output_count > 1: - sub.node(portstr) + sub.node(port_string) else: sub.node( - portstr, + port_string, shape='rectangle', label=port.operation.graph_id, height="0.1", @@ -841,28 +836,29 @@ class SFG(AbstractOperation): for i in range(len(p_list)): ports = p_list[i] for port in ports: + source_label = port.operation.graph_id + node_node = f"{source_label}.{port.index}" for signal in port.signals: destination = cast(InputPort, signal.destination) - if isinstance(destination.operation, Delay): - dest_node = destination.operation.graph_id + "In" - else: - dest_node = destination.operation.graph_id - dest_label = destination.operation.graph_id - node_node = f"{port.operation.graph_id}.{port.index}" - pg.edge(node_node, dest_node) + destination_label = destination.operation.graph_id + destination_node = ( + destination_label + "In" + if isinstance(destination.operation, Delay) + else destination_label + ) + pg.edge(node_node, destination_node) pg.node( - dest_node, - label=dest_label, + destination_node, + label=destination_label, shape=_OPERATION_SHAPE[ destination.operation.type_name() ], ) - if port.operation.type_name() == Delay.type_name(): - source_node = port.operation.graph_id + "Out" - else: - source_node = port.operation.graph_id - source_label = port.operation.graph_id - node_node = f"{port.operation.graph_id}.{port.index}" + source_node = ( + source_label + "Out" + if port.operation.type_name() == Delay.type_name() + else source_label + ) pg.edge(source_node, node_node) pg.node( source_node, @@ -887,8 +883,8 @@ class SFG(AbstractOperation): printed_ops = set() - for iter_num, iter in enumerate(precedence_list, start=1): - for outport_num, outport in enumerate(iter, start=1): + for iter_num, iterable in enumerate(precedence_list, start=1): + for outport_num, outport in enumerate(iterable, start=1): if outport not in printed_ops: # Only print once per operation, even if it has multiple outports out_str.write("\n") @@ -1160,7 +1156,7 @@ class SFG(AbstractOperation): self._components_dfs_order.extend( [new_signal, source.operation] ) - if not source.operation in self._operations_dfs_order: + if source.operation not in self._operations_dfs_order: self._operations_dfs_order.append(source.operation) # Check if the signal has not been added before. @@ -1331,7 +1327,7 @@ class SFG(AbstractOperation): bits_override: Optional[int], truncate: bool, deferred_delays: DelayQueue, - ) -> Number: + ) -> Num: key_base = ( (prefix + "." + src.operation.graph_id) if prefix @@ -1376,7 +1372,7 @@ class SFG(AbstractOperation): bits_override: Optional[int], truncate: bool, deferred_delays: DelayQueue, - ) -> Number: + ) -> Num: input_values = [ self._evaluate_source( input_port.signals[0].source, @@ -1458,13 +1454,13 @@ class SFG(AbstractOperation): "image/png" ] - def show(self, format=None, show_id=False, engine=None) -> None: + def show(self, fmt=None, show_id=False, engine=None) -> None: """ Shows a visual representation of the SFG using the default system viewer. Parameters ---------- - format : string, optional + fmt : string, optional File format of the generated graph. Output formats can be found at https://www.graphviz.org/doc/info/output.html Most common are "pdf", "eps", "png", and "svg". Default is None which @@ -1481,8 +1477,8 @@ class SFG(AbstractOperation): dg = self.sfg_digraph(show_id=show_id) if engine is not None: dg.engine = engine - if format is not None: - dg.format = format + if fmt is not None: + dg.format = fmt dg.view() def critical_path(self): diff --git a/b_asic/simulation.py b/b_asic/simulation.py index 75bf31af..c9fa14eb 100644 --- a/b_asic/simulation.py +++ b/b_asic/simulation.py @@ -8,6 +8,7 @@ from collections import defaultdict from numbers import Number from typing import ( Callable, + List, Mapping, MutableMapping, MutableSequence, @@ -20,11 +21,12 @@ import numpy as np from b_asic.operation import MutableDelayMap, ResultKey from b_asic.signal_flow_graph import SFG +from b_asic.types import Num -ResultArrayMap = Mapping[ResultKey, Sequence[Number]] -MutableResultArrayMap = MutableMapping[ResultKey, MutableSequence[Number]] -InputFunction = Callable[[int], Number] -InputProvider = Union[Number, Sequence[Number], InputFunction] +ResultArrayMap = Mapping[ResultKey, Sequence[Num]] +MutableResultArrayMap = MutableMapping[ResultKey, MutableSequence[Num]] +InputFunction = Callable[[int], Num] +InputProvider = Union[Num, Sequence[Num], InputFunction] class Simulation: @@ -39,7 +41,7 @@ class Simulation: _results: MutableResultArrayMap _delays: MutableDelayMap _iteration: int - _input_functions: Sequence[InputFunction] + _input_functions: List[InputFunction] _input_length: Optional[int] def __init__( @@ -105,7 +107,7 @@ class Simulation: save_results: bool = True, bits_override: Optional[int] = None, truncate: bool = True, - ) -> Sequence[Number]: + ) -> Sequence[Num]: """ Run one iteration of the simulation and return the resulting output values. """ @@ -117,12 +119,12 @@ class Simulation: save_results: bool = True, bits_override: Optional[int] = None, truncate: bool = True, - ) -> Sequence[Number]: + ) -> Sequence[Num]: """ Run the simulation until its iteration is greater than or equal to the given iteration and return the output values of the last iteration. """ - result: Sequence[Number] = [] + result: Sequence[Num] = [] while self._iteration < iteration: input_values = [ self._input_functions[i](self._iteration) @@ -149,7 +151,7 @@ class Simulation: save_results: bool = True, bits_override: Optional[int] = None, truncate: bool = True, - ) -> Sequence[Number]: + ) -> Sequence[Num]: """ Run a given number of iterations of the simulation and return the output values of the last iteration. @@ -163,7 +165,7 @@ class Simulation: save_results: bool = True, bits_override: Optional[int] = None, truncate: bool = True, - ) -> Sequence[Number]: + ) -> Sequence[Num]: """ Run the simulation until the end of its input arrays and return the output values of the last iteration. diff --git a/b_asic/special_operations.py b/b_asic/special_operations.py index 16edaaea..35b78efa 100644 --- a/b_asic/special_operations.py +++ b/b_asic/special_operations.py @@ -5,9 +5,8 @@ Contains operations with special purposes that may be treated differently from normal operations in an SFG. """ -from typing import List, Optional, Sequence, Tuple +from typing import Optional, Sequence, Tuple -from b_asic.graph_component import Name, TypeName from b_asic.operation import ( AbstractOperation, DelayMap, diff --git a/test/test_sfg.py b/test/test_sfg.py index 433c7b17..a4276fa6 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -1424,7 +1424,7 @@ class TestSFGGraph: def test_show_sfg_invalid_format(self, sfg_simple_filter): with pytest.raises(ValueError): - sfg_simple_filter.show(format="ppddff") + sfg_simple_filter.show(fmt="ppddff") def test_show_sfg_invalid_engine(self, sfg_simple_filter): with pytest.raises(ValueError): -- GitLab