From fc3fa7649e441b009e8ee4888db919315de3600a Mon Sep 17 00:00:00 2001
From: Oscar Gustafsson <oscar.gustafsson@gmail.com>
Date: Thu, 16 Feb 2023 19:36:15 +0100
Subject: [PATCH] Typing and general code cleanup

---
 b_asic/core_operations.py    |  6 +--
 b_asic/graph_component.py    |  2 +-
 b_asic/operation.py          | 19 +++++----
 b_asic/resources.py          | 33 ++++++++--------
 b_asic/schedule.py           | 12 ++++--
 b_asic/sfg_generators.py     | 21 +++++-----
 b_asic/signal_flow_graph.py  | 74 +++++++++++++++++-------------------
 b_asic/simulation.py         | 22 ++++++-----
 b_asic/special_operations.py |  3 +-
 test/test_sfg.py             |  2 +-
 10 files changed, 97 insertions(+), 97 deletions(-)

diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py
index aa9e33d4..d67b0242 100644
--- a/b_asic/core_operations.py
+++ b/b_asic/core_operations.py
@@ -258,12 +258,12 @@ class AddSub(AbstractOperation):
 
     @property
     def is_add(self) -> Num:
-        """Get if operation is add."""
+        """Get if operation is an addition."""
         return self.param("is_add")
 
     @is_add.setter
     def is_add(self, is_add: bool) -> None:
-        """Set if operation is add."""
+        """Set if operation is an addition."""
         self.set_param("is_add", is_add)
 
 
@@ -774,7 +774,7 @@ class Reciprocal(AbstractOperation):
         latency_offsets: Optional[Dict[str, int]] = None,
         execution_time: Optional[int] = None,
     ):
-        """Construct an Reciprocal operation."""
+        """Construct a Reciprocal operation."""
         super().__init__(
             input_count=1,
             output_count=1,
diff --git a/b_asic/graph_component.py b/b_asic/graph_component.py
index 1f910c30..f4725892 100644
--- a/b_asic/graph_component.py
+++ b/b_asic/graph_component.py
@@ -9,7 +9,7 @@ from collections import deque
 from copy import copy, deepcopy
 from typing import Any, Dict, Generator, Iterable, Mapping, cast
 
-from b_asic.types import GraphID, GraphIDNumber, Name, Num, TypeName
+from b_asic.types import GraphID, Name, TypeName
 
 
 class GraphComponent(ABC):
diff --git a/b_asic/operation.py b/b_asic/operation.py
index 8b8251dc..6548cb45 100644
--- a/b_asic/operation.py
+++ b/b_asic/operation.py
@@ -11,7 +11,6 @@ from abc import abstractmethod
 from numbers import Number
 from typing import (
     TYPE_CHECKING,
-    Any,
     Dict,
     Iterable,
     List,
@@ -34,7 +33,7 @@ from b_asic.graph_component import (
 )
 from b_asic.port import InputPort, OutputPort, SignalSourceProvider
 from b_asic.signal import Signal
-from b_asic.types import Num, NumRuntime
+from b_asic.types import Num
 
 if TYPE_CHECKING:
     # Conditionally imported to avoid circular imports
@@ -593,7 +592,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
         # Import here to avoid circular imports.
         from b_asic.core_operations import Addition, Constant
 
-        if isinstance(src, NumRuntime):
+        if isinstance(src, Number):
             return Addition(self, Constant(src))
         else:
             return Addition(self, src)
@@ -603,7 +602,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
         from b_asic.core_operations import Addition, Constant
 
         return Addition(
-            Constant(src) if isinstance(src, NumRuntime) else src, self
+            Constant(src) if isinstance(src, Number) else src, self
         )
 
     def __sub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction":
@@ -611,7 +610,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
         from b_asic.core_operations import Constant, Subtraction
 
         return Subtraction(
-            self, Constant(src) if isinstance(src, NumRuntime) else src
+            self, Constant(src) if isinstance(src, Number) else src
         )
 
     def __rsub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction":
@@ -619,7 +618,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
         from b_asic.core_operations import Constant, Subtraction
 
         return Subtraction(
-            Constant(src) if isinstance(src, NumRuntime) else src, self
+            Constant(src) if isinstance(src, Number) else src, self
         )
 
     def __mul__(
@@ -633,7 +632,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
 
         return (
             ConstantMultiplication(src, self)
-            if isinstance(src, NumRuntime)
+            if isinstance(src, Number)
             else Multiplication(self, src)
         )
 
@@ -648,7 +647,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
 
         return (
             ConstantMultiplication(src, self)
-            if isinstance(src, NumRuntime)
+            if isinstance(src, Number)
             else Multiplication(src, self)
         )
 
@@ -657,7 +656,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
         from b_asic.core_operations import Constant, Division
 
         return Division(
-            self, Constant(src) if isinstance(src, NumRuntime) else src
+            self, Constant(src) if isinstance(src, Number) else src
         )
 
     def __rtruediv__(
@@ -666,7 +665,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
         # Import here to avoid circular imports.
         from b_asic.core_operations import Constant, Division, Reciprocal
 
-        if isinstance(src, NumRuntime):
+        if isinstance(src, Number):
             if src == 1:
                 return Reciprocal(self)
             else:
diff --git a/b_asic/resources.py b/b_asic/resources.py
index 6ad988ff..667eba29 100644
--- a/b_asic/resources.py
+++ b/b_asic/resources.py
@@ -35,7 +35,7 @@ def draw_exclusion_graph_coloring(
     ] = None,
 ):
     """
-    Use matplotlib.pyplot and networkx to draw a colored exclusion graph from the memory assigment
+    Use matplotlib.pyplot and networkx to draw a colored exclusion graph from the memory assignment
 
     .. code-block:: python
 
@@ -77,7 +77,6 @@ def draw_exclusion_graph_coloring(
         '#aa00aa',
         '#00aaaa',
     ]
-    node_color_dict = {}
     if color_list is None:
         node_color_dict = {k: COLOR_LIST[v] for k, v in color_dict.items()}
     else:
@@ -137,8 +136,8 @@ class ProcessCollection:
         Parameters
         ----------
         ax : :class:`matplotlib.axes.Axes`, optional
-            Matplotlib Axes object to draw this lifetime chart onto. If not provided (i.e., set to None), this method will
-            return a new axes object on return.
+            Matplotlib Axes object to draw this lifetime chart onto. If not provided (i.e., set to None),
+            this method will return a new axes object on return.
         show_name : bool, default: True
             Show name of all processes in the lifetime chart.
 
@@ -147,7 +146,7 @@ class ProcessCollection:
             ax: Associated Matplotlib Axes (or array of Axes) object
         """
 
-        # Setup the Axes object
+        # Set up the Axes object
         if ax is None:
             _, _ax = plt.subplots()
         else:
@@ -159,7 +158,7 @@ class ProcessCollection:
             process.execution_time for process in self._collection
         )
         if max_execution_time > self._schedule_time:
-            # Schedule time needs to be greater than or equal to the maximum process life time
+            # Schedule time needs to be greater than or equal to the maximum process lifetime
             raise KeyError(
                 f'Error: Schedule time: {self._schedule_time} < Max execution'
                 f' time: {max_execution_time}'
@@ -222,7 +221,7 @@ class ProcessCollection:
         self, add_name: bool = True
     ) -> nx.Graph:
         """
-        Generate exclusion graph based on processes overlaping in time
+        Generate exclusion graph based on processes overlapping in time
 
         Parameters
         ----------
@@ -270,20 +269,20 @@ class ProcessCollection:
         Parameters
         ----------
         heuristic : str, default: "graph_color"
-            The heuristic used when spliting this ProcessCollection.
+            The heuristic used when splitting this ProcessCollection.
             Valid options are:
                 * "graph_color"
                 * "..."
         read_ports : int, optional
-            The number of read ports used when spliting process collection based on memory variable access.
+            The number of read ports used when splitting process collection based on memory variable access.
         write_ports : int, optional
-            The number of write ports used when spliting process collection based on memory variable access.
+            The number of write ports used when splitting process collection based on memory variable access.
         total_ports : int, optional
-            The total number of ports used when spliting process collection based on memory variable access.
+            The total number of ports used when splitting process collection based on memory variable access.
 
         Returns
         -------
-        A set of new ProcessColleciton objects with the process spliting.
+        A set of new ProcessCollection objects with the process splitting.
         """
         if total_ports is None:
             if read_ports is None or write_ports is None:
@@ -308,20 +307,20 @@ class ProcessCollection:
         Parameters
         ----------
         read_ports : int, optional
-            The number of read ports used when spliting process collection based on memory variable access.
+            The number of read ports used when splitting process collection based on memory variable access.
         write_ports : int, optional
-            The number of write ports used when spliting process collection based on memory variable access.
+            The number of write ports used when splitting process collection based on memory variable access.
         total_ports : int, optional
-            The total number of ports used when spliting process collection based on memory variable access.
+            The total number of ports used when splitting process collection based on memory variable access.
         """
         if read_ports != 1 or write_ports != 1:
             raise ValueError(
-                "Spliting with read and write ports not equal to one with the"
+                "Splitting with read and write ports not equal to one with the"
                 " graph coloring heuristic does not make sense."
             )
         if total_ports not in (1, 2):
             raise ValueError(
-                "Total ports should be either 1 (non-concurent reads/writes)"
+                "Total ports should be either 1 (non-concurrent reads/writes)"
                 " or 2 (concurrent read/writes) for graph coloring heuristic."
             )
 
diff --git a/b_asic/schedule.py b/b_asic/schedule.py
index 575759b2..bab48e36 100644
--- a/b_asic/schedule.py
+++ b/b_asic/schedule.py
@@ -603,7 +603,7 @@ class Schedule:
 
     def _get_y_position(
         self, graph_id, operation_height=1.0, operation_gap=None
-    ):
+    ) -> float:
         if operation_gap is None:
             operation_gap = OPERATION_GAP
         y_location = self._y_locations[graph_id]
@@ -617,11 +617,15 @@ class Schedule:
             self._y_locations[graph_id] = y_location
         return operation_gap + y_location * (operation_height + operation_gap)
 
-    def _plot_schedule(self, ax, operation_gap: Optional[float] = None):
+    def _plot_schedule(
+        self, ax: Axes, operation_gap: Optional[float] = None
+    ) -> None:
         """Draw the schedule."""
         line_cache = []
 
-        def _draw_arrow(start, end, name="", laps=0):
+        def _draw_arrow(
+            start: List[float], end: List[float], name: str = "", laps: int = 0
+        ):
             """Draw an arrow from *start* to *end*."""
             if end[0] < start[0] or laps > 0:  # Wrap around
                 if start not in line_cache:
@@ -848,7 +852,7 @@ class Schedule:
         self._plot_schedule(ax, operation_gap=operation_gap)
         return fig
 
-    def _repr_svg_(self):
+    def _repr_svg_(self) -> str:
         """
         Generate an SVG of the schedule. This is automatically displayed in e.g.
         Jupyter Qt console.
diff --git a/b_asic/sfg_generators.py b/b_asic/sfg_generators.py
index 48edf226..37d011f8 100644
--- a/b_asic/sfg_generators.py
+++ b/b_asic/sfg_generators.py
@@ -72,9 +72,8 @@ def wdf_allpass(
     odd_order = order % 2
     if odd_order:
         # First-order section
-        coeff = np_coefficients[0]
         adaptor0 = SymmetricTwoportAdaptor(
-            coeff,
+            np_coefficients[0],
             input_op,
             latency=latency,
             latency_offsets=latency_offsets,
@@ -185,10 +184,11 @@ def direct_form_fir(
     prev_add = None
     for i, coeff in enumerate(np_coefficients):
         tmp_mul = ConstantMultiplication(coeff, prev_delay, **mult_properties)
-        if prev_add is None:
-            prev_add = tmp_mul
-        else:
-            prev_add = Addition(tmp_mul, prev_add, **add_properties)
+        prev_add = (
+            tmp_mul
+            if prev_add is None
+            else Addition(tmp_mul, prev_add, **add_properties)
+        )
         if i < taps - 1:
             prev_delay = Delay(prev_delay)
 
@@ -266,10 +266,11 @@ def transposed_direct_form_fir(
     prev_add = None
     for i, coeff in enumerate(reversed(np_coefficients)):
         tmp_mul = ConstantMultiplication(coeff, input_op, **mult_properties)
-        if prev_delay is None:
-            tmp_add = tmp_mul
-        else:
-            tmp_add = Addition(tmp_mul, prev_delay, **add_properties)
+        tmp_add = (
+            tmp_mul
+            if prev_delay is None
+            else Addition(tmp_mul, prev_delay, **add_properties)
+        )
         if i < taps - 1:
             prev_delay = Delay(tmp_add)
 
diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py
index 9e0c33f3..417baa24 100644
--- a/b_asic/signal_flow_graph.py
+++ b/b_asic/signal_flow_graph.py
@@ -27,13 +27,7 @@ from typing import (
 
 from graphviz import Digraph
 
-from b_asic.graph_component import (
-    GraphComponent,
-    GraphID,
-    GraphIDNumber,
-    Name,
-    TypeName,
-)
+from b_asic.graph_component import GraphComponent
 from b_asic.operation import (
     AbstractOperation,
     MutableDelayMap,
@@ -44,6 +38,7 @@ from b_asic.operation import (
 from b_asic.port import InputPort, OutputPort, SignalSourceProvider
 from b_asic.signal import Signal
 from b_asic.special_operations import Delay, Input, Output
+from b_asic.types import GraphID, GraphIDNumber, Name, Num, TypeName
 
 DelayQueue = List[Tuple[str, ResultKey, OutputPort]]
 
@@ -377,7 +372,7 @@ class SFG(AbstractOperation):
     def evaluate_output(
         self,
         index: int,
-        input_values: Sequence[Number],
+        input_values: Sequence[Num],
         results: Optional[MutableResultMap] = None,
         delays: Optional[MutableDelayMap] = None,
         prefix: str = "",
@@ -465,11 +460,11 @@ class SFG(AbstractOperation):
         for input_port, input_operation in zip(
             self.inputs, self.input_operations
         ):
-            dest = input_operation.output(0).signals[0].destination
-            if dest is None:
+            destination = input_operation.output(0).signals[0].destination
+            if destination is None:
                 raise ValueError("Missing destination in signal.")
-            dest.clear()
-            input_port.signals[0].set_destination(dest)
+            destination.clear()
+            input_port.signals[0].set_destination(destination)
         # For each output_signal, connect it to the corresponding operation
         for output_port, output_operation in zip(
             self.outputs, self.output_operations
@@ -825,12 +820,12 @@ class SFG(AbstractOperation):
             with pg.subgraph(name=f"cluster_{i}") as sub:
                 sub.attr(label=f"N{i}")
                 for port in ports:
-                    portstr = f"{port.operation.graph_id}.{port.index}"
+                    port_string = f"{port.operation.graph_id}.{port.index}"
                     if port.operation.output_count > 1:
-                        sub.node(portstr)
+                        sub.node(port_string)
                     else:
                         sub.node(
-                            portstr,
+                            port_string,
                             shape='rectangle',
                             label=port.operation.graph_id,
                             height="0.1",
@@ -841,28 +836,29 @@ class SFG(AbstractOperation):
         for i in range(len(p_list)):
             ports = p_list[i]
             for port in ports:
+                source_label = port.operation.graph_id
+                node_node = f"{source_label}.{port.index}"
                 for signal in port.signals:
                     destination = cast(InputPort, signal.destination)
-                    if isinstance(destination.operation, Delay):
-                        dest_node = destination.operation.graph_id + "In"
-                    else:
-                        dest_node = destination.operation.graph_id
-                    dest_label = destination.operation.graph_id
-                    node_node = f"{port.operation.graph_id}.{port.index}"
-                    pg.edge(node_node, dest_node)
+                    destination_label = destination.operation.graph_id
+                    destination_node = (
+                        destination_label + "In"
+                        if isinstance(destination.operation, Delay)
+                        else destination_label
+                    )
+                    pg.edge(node_node, destination_node)
                     pg.node(
-                        dest_node,
-                        label=dest_label,
+                        destination_node,
+                        label=destination_label,
                         shape=_OPERATION_SHAPE[
                             destination.operation.type_name()
                         ],
                     )
-                if port.operation.type_name() == Delay.type_name():
-                    source_node = port.operation.graph_id + "Out"
-                else:
-                    source_node = port.operation.graph_id
-                source_label = port.operation.graph_id
-                node_node = f"{port.operation.graph_id}.{port.index}"
+                source_node = (
+                    source_label + "Out"
+                    if port.operation.type_name() == Delay.type_name()
+                    else source_label
+                )
                 pg.edge(source_node, node_node)
                 pg.node(
                     source_node,
@@ -887,8 +883,8 @@ class SFG(AbstractOperation):
 
         printed_ops = set()
 
-        for iter_num, iter in enumerate(precedence_list, start=1):
-            for outport_num, outport in enumerate(iter, start=1):
+        for iter_num, iterable in enumerate(precedence_list, start=1):
+            for outport_num, outport in enumerate(iterable, start=1):
                 if outport not in printed_ops:
                     # Only print once per operation, even if it has multiple outports
                     out_str.write("\n")
@@ -1160,7 +1156,7 @@ class SFG(AbstractOperation):
                         self._components_dfs_order.extend(
                             [new_signal, source.operation]
                         )
-                        if not source.operation in self._operations_dfs_order:
+                        if source.operation not in self._operations_dfs_order:
                             self._operations_dfs_order.append(source.operation)
 
                     # Check if the signal has not been added before.
@@ -1331,7 +1327,7 @@ class SFG(AbstractOperation):
         bits_override: Optional[int],
         truncate: bool,
         deferred_delays: DelayQueue,
-    ) -> Number:
+    ) -> Num:
         key_base = (
             (prefix + "." + src.operation.graph_id)
             if prefix
@@ -1376,7 +1372,7 @@ class SFG(AbstractOperation):
         bits_override: Optional[int],
         truncate: bool,
         deferred_delays: DelayQueue,
-    ) -> Number:
+    ) -> Num:
         input_values = [
             self._evaluate_source(
                 input_port.signals[0].source,
@@ -1458,13 +1454,13 @@ class SFG(AbstractOperation):
             "image/png"
         ]
 
-    def show(self, format=None, show_id=False, engine=None) -> None:
+    def show(self, fmt=None, show_id=False, engine=None) -> None:
         """
         Shows a visual representation of the SFG using the default system viewer.
 
         Parameters
         ----------
-        format : string, optional
+        fmt : string, optional
             File format of the generated graph. Output formats can be found at
             https://www.graphviz.org/doc/info/output.html
             Most common are "pdf", "eps", "png", and "svg". Default is None which
@@ -1481,8 +1477,8 @@ class SFG(AbstractOperation):
         dg = self.sfg_digraph(show_id=show_id)
         if engine is not None:
             dg.engine = engine
-        if format is not None:
-            dg.format = format
+        if fmt is not None:
+            dg.format = fmt
         dg.view()
 
     def critical_path(self):
diff --git a/b_asic/simulation.py b/b_asic/simulation.py
index 75bf31af..c9fa14eb 100644
--- a/b_asic/simulation.py
+++ b/b_asic/simulation.py
@@ -8,6 +8,7 @@ from collections import defaultdict
 from numbers import Number
 from typing import (
     Callable,
+    List,
     Mapping,
     MutableMapping,
     MutableSequence,
@@ -20,11 +21,12 @@ import numpy as np
 
 from b_asic.operation import MutableDelayMap, ResultKey
 from b_asic.signal_flow_graph import SFG
+from b_asic.types import Num
 
-ResultArrayMap = Mapping[ResultKey, Sequence[Number]]
-MutableResultArrayMap = MutableMapping[ResultKey, MutableSequence[Number]]
-InputFunction = Callable[[int], Number]
-InputProvider = Union[Number, Sequence[Number], InputFunction]
+ResultArrayMap = Mapping[ResultKey, Sequence[Num]]
+MutableResultArrayMap = MutableMapping[ResultKey, MutableSequence[Num]]
+InputFunction = Callable[[int], Num]
+InputProvider = Union[Num, Sequence[Num], InputFunction]
 
 
 class Simulation:
@@ -39,7 +41,7 @@ class Simulation:
     _results: MutableResultArrayMap
     _delays: MutableDelayMap
     _iteration: int
-    _input_functions: Sequence[InputFunction]
+    _input_functions: List[InputFunction]
     _input_length: Optional[int]
 
     def __init__(
@@ -105,7 +107,7 @@ class Simulation:
         save_results: bool = True,
         bits_override: Optional[int] = None,
         truncate: bool = True,
-    ) -> Sequence[Number]:
+    ) -> Sequence[Num]:
         """
         Run one iteration of the simulation and return the resulting output values.
         """
@@ -117,12 +119,12 @@ class Simulation:
         save_results: bool = True,
         bits_override: Optional[int] = None,
         truncate: bool = True,
-    ) -> Sequence[Number]:
+    ) -> Sequence[Num]:
         """
         Run the simulation until its iteration is greater than or equal to the given
         iteration and return the output values of the last iteration.
         """
-        result: Sequence[Number] = []
+        result: Sequence[Num] = []
         while self._iteration < iteration:
             input_values = [
                 self._input_functions[i](self._iteration)
@@ -149,7 +151,7 @@ class Simulation:
         save_results: bool = True,
         bits_override: Optional[int] = None,
         truncate: bool = True,
-    ) -> Sequence[Number]:
+    ) -> Sequence[Num]:
         """
         Run a given number of iterations of the simulation and return the output
         values of the last iteration.
@@ -163,7 +165,7 @@ class Simulation:
         save_results: bool = True,
         bits_override: Optional[int] = None,
         truncate: bool = True,
-    ) -> Sequence[Number]:
+    ) -> Sequence[Num]:
         """
         Run the simulation until the end of its input arrays and return the output
         values of the last iteration.
diff --git a/b_asic/special_operations.py b/b_asic/special_operations.py
index 16edaaea..35b78efa 100644
--- a/b_asic/special_operations.py
+++ b/b_asic/special_operations.py
@@ -5,9 +5,8 @@ Contains operations with special purposes that may be treated differently from
 normal operations in an SFG.
 """
 
-from typing import List, Optional, Sequence, Tuple
+from typing import Optional, Sequence, Tuple
 
-from b_asic.graph_component import Name, TypeName
 from b_asic.operation import (
     AbstractOperation,
     DelayMap,
diff --git a/test/test_sfg.py b/test/test_sfg.py
index 433c7b17..a4276fa6 100644
--- a/test/test_sfg.py
+++ b/test/test_sfg.py
@@ -1424,7 +1424,7 @@ class TestSFGGraph:
 
     def test_show_sfg_invalid_format(self, sfg_simple_filter):
         with pytest.raises(ValueError):
-            sfg_simple_filter.show(format="ppddff")
+            sfg_simple_filter.show(fmt="ppddff")
 
     def test_show_sfg_invalid_engine(self, sfg_simple_filter):
         with pytest.raises(ValueError):
-- 
GitLab