From 82b75b20b2eeb0bb0aa324cadf9c8317b1daffa0 Mon Sep 17 00:00:00 2001
From: Oscar Gustafsson <oscar.gustafsson@gmail.com>
Date: Wed, 16 Apr 2025 13:37:27 +0200
Subject: [PATCH] More docs and typing fixes

---
 b_asic/architecture.py                 |  7 +++---
 b_asic/codegen/vhdl/__init__.py        |  2 +-
 b_asic/core_operations.py              |  1 -
 b_asic/operation.py                    | 32 +++++++++++---------------
 b_asic/resources.py                    | 29 ++++++++++++-----------
 b_asic/schedule.py                     | 19 +++++++--------
 b_asic/scheduler.py                    | 12 +++++-----
 b_asic/scheduler_gui/main_window.py    | 16 +++++++------
 b_asic/scheduler_gui/operation_item.py | 12 ++++++----
 docs_sphinx/codegen/vhdl.rst           |  4 ++++
 10 files changed, 70 insertions(+), 64 deletions(-)

diff --git a/b_asic/architecture.py b/b_asic/architecture.py
index a843a903..51109516 100644
--- a/b_asic/architecture.py
+++ b/b_asic/architecture.py
@@ -279,7 +279,7 @@ class Resource(HardwareBlock):
     def is_assigned(self) -> bool:
         return self._assignment is not None
 
-    def assign(self, strategy: str = "left_edge"):
+    def assign(self, strategy: Literal["left_edge"] = "left_edge"):
         """
         Perform assignment of processes to resource.
 
@@ -851,7 +851,6 @@ of :class:`~b_asic.architecture.ProcessingElement`
         Returns
         -------
         :class:`Resource`
-
         """
         re = {p.entity_name: p for p in chain(self.memories, self.processing_elements)}
         return re[name]
@@ -881,7 +880,7 @@ of :class:`~b_asic.architecture.ProcessingElement`
         else:
             raise ValueError("Resource not in architecture")
 
-    def assign_resources(self, strategy: str = "left_edge") -> None:
+    def assign_resources(self, strategy: Literal["left_edge"] = "left_edge") -> None:
         """
         Convenience method to assign all resources in the architecture.
 
@@ -894,7 +893,6 @@ of :class:`~b_asic.architecture.ProcessingElement`
         --------
         Memory.assign
         ProcessingElement.assign
-
         """
         for resource in chain(self.memories, self.processing_elements):
             resource.assign(strategy=strategy)
@@ -938,6 +936,7 @@ of :class:`~b_asic.architecture.ProcessingElement`
         if isinstance(proc, str):
             proc = source.collection.from_name(proc)
 
+        proc = cast(Process, proc)
         # Move the process
         if proc in source:
             destination.add_process(proc, assign=assign)
diff --git a/b_asic/codegen/vhdl/__init__.py b/b_asic/codegen/vhdl/__init__.py
index f34b60d0..0b584689 100644
--- a/b_asic/codegen/vhdl/__init__.py
+++ b/b_asic/codegen/vhdl/__init__.py
@@ -28,7 +28,7 @@ def write(
     f : TextIO
         The file object to emit VHDL code to.
     indent_level : int
-        Indentation level to use. Exactly ``f'{VHDL_TAB*indent_level}`` is written
+        Indentation level to use. Exactly ``f'{VHDL_TAB*indent_level}'`` is written
         before the text is written.
     text : str
         The text to write to.
diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py
index 0b39ba9d..495bc666 100644
--- a/b_asic/core_operations.py
+++ b/b_asic/core_operations.py
@@ -1683,7 +1683,6 @@ class DontCare(AbstractOperation):
     ----------
     name : Name, optional
         Operation name.
-
     """
 
     __slots__ = "_name"
diff --git a/b_asic/operation.py b/b_asic/operation.py
index 2ca298e9..725f627a 100644
--- a/b_asic/operation.py
+++ b/b_asic/operation.py
@@ -558,15 +558,12 @@ class AbstractOperation(Operation, AbstractGraphComponent):
             dict_ele = []
             for signal in current_input.signals:
                 if signal.source:
-                    if signal.source_operation.graph_id:
-                        dict_ele.append(signal.source_operation.graph_id)
-                    else:
-                        dict_ele.append(GraphID("no_id"))
+                    dict_ele.append(
+                        cast(Operation, signal.source_operation).graph_id
+                        or GraphID("no_id")
+                    )
                 else:
-                    if signal.graph_id:
-                        dict_ele.append(signal.graph_id)
-                    else:
-                        dict_ele.append(GraphID("no_id"))
+                    dict_ele.append(signal.graph_id or GraphID("no_id"))
             inputs_dict[i] = dict_ele
 
         outputs_dict: dict[int, list[GraphID] | str] = {}
@@ -577,15 +574,12 @@ class AbstractOperation(Operation, AbstractGraphComponent):
             dict_ele = []
             for signal in outport.signals:
                 if signal.destination:
-                    if signal.destination_operation.graph_id:
-                        dict_ele.append(signal.destination_operation.graph_id)
-                    else:
-                        dict_ele.append(GraphID("no_id"))
+                    dict_ele.append(
+                        cast(Operation, signal.destination_operation).graph_id
+                        or GraphID("no_id")
+                    )
                 else:
-                    if signal.graph_id:
-                        dict_ele.append(signal.graph_id)
-                    else:
-                        dict_ele.append(GraphID("no_id"))
+                    dict_ele.append(signal.graph_id or GraphID("no_id"))
             outputs_dict[i] = dict_ele
 
         return (
@@ -781,11 +775,11 @@ class AbstractOperation(Operation, AbstractGraphComponent):
         return list(range(self.input_count))
 
     @property
-    def neighbors(self) -> Iterable[GraphComponent]:
+    def neighbors(self) -> Sequence[GraphComponent]:
         return list(self.input_signals) + list(self.output_signals)
 
     @property
-    def preceding_operations(self) -> Iterable[Operation]:
+    def preceding_operations(self) -> list[Operation]:
         """
         Return an Iterable of all Operations that are connected to this
         Operations input ports.
@@ -795,7 +789,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
         ]
 
     @property
-    def subsequent_operations(self) -> Iterable[Operation]:
+    def subsequent_operations(self) -> list[Operation]:
         """
         Return an Iterable of all Operations that are connected to this
         Operations output ports.
diff --git a/b_asic/resources.py b/b_asic/resources.py
index 26a36336..0e06b295 100644
--- a/b_asic/resources.py
+++ b/b_asic/resources.py
@@ -6,7 +6,7 @@ from collections import Counter, defaultdict
 from collections.abc import Iterable
 from functools import reduce
 from math import floor, log2
-from typing import TYPE_CHECKING, Literal, TypeVar
+from typing import TYPE_CHECKING, Literal, TypeVar, Union
 
 import matplotlib.pyplot as plt
 import networkx as nx
@@ -204,6 +204,7 @@ def draw_exclusion_graph_coloring(
         "#00aaaa",
         "#666666",
     ]
+    node_color_dict: dict[Process, str | tuple[float, float, float]]
     if color_list is None:
         node_color_dict = {k: COLOR_LIST[v] for k, v in color_dict.items()}
     else:
@@ -1160,7 +1161,7 @@ class ProcessCollection:
         This method takes the processes from `sequence`, in order, and assigns them to
         to multiple new `ProcessCollection` based on port collisions in a first-come
         first-served manner. The first :class:`Process` in `sequence` is assigned first, and
-        the last :class:`Process` in `sequence is assigned last.
+        the last :class:`Process` in `sequence` is assigned last.
 
         Parameters
         ----------
@@ -1264,8 +1265,8 @@ class ProcessCollection:
                     process_fits_in_collection = self._get_process_fits_in_collection(
                         process, collections, read_ports, write_ports, total_ports
                     )
-
-            best_collection.add_process(process)
+            if best_collection is not None:
+                best_collection.add_process(process)
 
         collections = [
             collection for collection in collections if collection.collection
@@ -1331,7 +1332,8 @@ class ProcessCollection:
                     process_fits_in_collection = self._get_process_fits_in_collection(
                         process, collections, read_ports, write_ports, total_ports
                     )
-            best_collection.add_process(process)
+            if best_collection is not None:
+                best_collection.add_process(process)
 
         collections = [
             collection for collection in collections if collection.collection
@@ -1379,7 +1381,7 @@ class ProcessCollection:
     @staticmethod
     def _count_number_of_pes_read_from(
         processing_elements: list["ProcessingElement"],
-        collection: "ProcessCollection",
+        collection: Union["ProcessCollection", list["Process"]],
     ) -> int:
         collection_process_names = {proc.name.split(".")[0] for proc in collection}
         count = 0
@@ -1394,7 +1396,7 @@ class ProcessCollection:
     @staticmethod
     def _count_number_of_pes_written_to(
         processing_elements: list["ProcessingElement"],
-        collection: "ProcessCollection",
+        collection: Union["ProcessCollection", list["Process"]],
     ) -> int:
         collection_process_names = {proc.name.split(".")[0] for proc in collection}
         count = 0
@@ -1853,7 +1855,9 @@ class ProcessCollection:
         -------
         A set of new ProcessCollections.
         """
-        process_collection_set_list = [[] for _ in range(max(coloring.values()) + 1)]
+        process_collection_set_list: list[list[Process]] = [
+            [] for _ in range(max(coloring.values()) + 1)
+        ]
         for process, color in coloring.items():
             process_collection_set_list[color].append(process)
         return [
@@ -2093,7 +2097,7 @@ class ProcessCollection:
         adr_pipe_depth: int | None = None,
     ):
         """
-        Generate VHDL code for memory based storage of processes (MemoryVariables).
+        Generate VHDL code for memory-based storage of processes (MemoryVariables).
 
         Parameters
         ----------
@@ -2243,8 +2247,8 @@ class ProcessCollection:
             A tuple of two ProcessCollections, one with shorter than or equal execution
             times and one with longer execution times.
         """
-        short = []
-        long = []
+        short: list[Process] = []
+        long: list[Process] = []
         for process in self.collection:
             if process.execution_time <= length:
                 short.append(process)
@@ -2274,7 +2278,7 @@ class ProcessCollection:
         total_ports: int = 2,
     ):
         """
-        Generate VHDL code for register based storage.
+        Generate VHDL code for register-based storage of processes (MemoryVariables).
 
         This is based on Forward-Backward Register Allocation.
 
@@ -2302,7 +2306,6 @@ class ProcessCollection:
         ----------
         - K. Parhi: VLSI Digital Signal Processing Systems: Design and
           Implementation, Ch. 6.3.2
-
         """
         # Check that entity name is a valid VHDL identifier
         if not is_valid_vhdl_identifier(entity_name):
diff --git a/b_asic/schedule.py b/b_asic/schedule.py
index a3d8aaa1..dbd6b512 100644
--- a/b_asic/schedule.py
+++ b/b_asic/schedule.py
@@ -696,9 +696,9 @@ class Schedule:
         graph_id : GraphID
             The GraphID of the operation to move.
         new_y : int
-            The new y-position of the operation.
+            The new y-location of the operation.
         insert : bool, optional
-            If True, all operations on that y-position will be moved one position.
+            If True, all operations on that y-location will be moved one position.
             The default is False.
         """
         if insert:
@@ -723,7 +723,7 @@ class Schedule:
 
     def get_y_location(self, graph_id: GraphID) -> int:
         """
-        Get the y-position of the Operation with GraphID *graph_id*.
+        Get the y-location of the Operation with GraphID *graph_id*.
 
         Parameters
         ----------
@@ -733,20 +733,20 @@ class Schedule:
         Returns
         -------
         int
-            The y-position of the operation.
+            The y-location of the operation.
         """
         return self._y_locations[graph_id]
 
     def set_y_location(self, graph_id: GraphID, y_location: int) -> None:
         """
-        Set the y-position of the Operation with GraphID *graph_id* to *y_location*.
+        Set the y-location of the Operation with GraphID *graph_id* to *y_location*.
 
         Parameters
         ----------
         graph_id : GraphID
             The GraphID of the operation to move.
         y_location : int
-            The new y-position of the operation.
+            The new y-location of the operation.
         """
         self._y_locations[graph_id] = y_location
 
@@ -761,7 +761,8 @@ class Schedule:
     def place_operation(
         self, op: Operation, time: int, op_laps: dict[GraphID, int]
     ) -> None:
-        """Schedule the given operation in given time.
+        """
+        Schedule *op* at *time*.
 
         Parameters
         ----------
@@ -784,7 +785,7 @@ class Schedule:
             laps = 0
             if self._schedule_time is not None:
                 current_lap = time // self._schedule_time
-                source_port = source_op = input_port.signals[0].source
+                source_port = cast(OutputPort, input_port.signals[0].source)
                 source_op = source_port.operation
 
                 if not isinstance(source_op, (Delay, DontCare)):
@@ -1356,7 +1357,7 @@ class Schedule:
         )
 
     def reset_y_locations(self) -> None:
-        """Reset all the y-locations in the schedule to None"""
+        """Reset all the y-locations in the schedule to None."""
         self._y_locations = defaultdict(_y_locations_default)
 
     def plot(self, ax: Axes, operation_gap: float = OPERATION_GAP) -> None:
diff --git a/b_asic/scheduler.py b/b_asic/scheduler.py
index 0bae94f3..b49ea990 100644
--- a/b_asic/scheduler.py
+++ b/b_asic/scheduler.py
@@ -22,7 +22,7 @@ PriorityTableType = list[tuple["GraphID", int, int, int, int]]
 
 class Scheduler(ABC):
     """
-    Scheduler base class
+    Scheduler base class.
 
     Parameters
     ----------
@@ -31,7 +31,7 @@ class Scheduler(ABC):
     output_delta_times : dict(GraphID, int), optional
         The relative times when outputs should be produced.
     sort_y_location : bool, default: True
-        If the y-position should be sorted based on start time of operations.
+        If the y-location should be sorted based on start time of operations.
     """
 
     __slots__ = (
@@ -233,7 +233,7 @@ class ASAPScheduler(Scheduler):
     output_delta_times : dict(GraphID, int), optional
         The relative times when outputs should be produced.
     sort_y_location : bool, default: True
-        If the y-position should be sorted based on start time of operations.
+        If the y-location should be sorted based on start time of operations.
     """
 
     def apply_scheduling(self, schedule: "Schedule") -> None:
@@ -328,7 +328,7 @@ class ALAPScheduler(Scheduler):
     output_delta_times : dict(GraphID, int), optional
         The relative times when outputs should be produced.
     sort_y_location : bool, default: True
-        If the y-position should be sorted based on start time of operations.
+        If the y-location should be sorted based on start time of operations.
     """
 
     def apply_scheduling(self, schedule: "Schedule") -> None:
@@ -408,7 +408,7 @@ class ListScheduler(Scheduler):
     output_delta_times : dict(GraphID, int) | None, optional
         The relative times when outputs should be produced.
     sort_y_location : bool, default: True
-        If the y-position should be sorted based on start time of operations.
+        If the y-location should be sorted based on start time of operations.
     """
 
     __slots__ = (
@@ -929,7 +929,7 @@ class RecursiveListScheduler(ListScheduler):
     output_delta_times : dict(GraphID, int) | None, optional
         The relative times when outputs should be produced.
     sort_y_location : bool, default: True
-        If the y-position should be sorted based on start time of operations.
+        If the y-location should be sorted based on start time of operations.
     """
 
     __slots__ = ('_recursive_ops', '_recursive_ops_set', '_remaining_recursive_ops')
diff --git a/b_asic/scheduler_gui/main_window.py b/b_asic/scheduler_gui/main_window.py
index 21283440..453042cf 100644
--- a/b_asic/scheduler_gui/main_window.py
+++ b/b_asic/scheduler_gui/main_window.py
@@ -218,7 +218,7 @@ class ScheduleMainWindow(QMainWindow, Ui_MainWindow):
         self.splitter.setCollapsible(1, True)
 
     def _init_graphics(self) -> None:
-        """Initialize the QGraphics framework"""
+        """Initialize the QGraphics framework."""
         self._scene = QGraphicsScene()
         self._scene.addRect(0, 0, 0, 0)  # dummy rect to be able to setPos() graph
         self.view.setScene(self._scene)
@@ -277,7 +277,9 @@ class ScheduleMainWindow(QMainWindow, Ui_MainWindow):
     def _decrease_time_resolution(self) -> None:
         """Callback for decreasing time resolution."""
         # Get possible factors
-        vals = [str(v) for v in self.schedule.get_possible_time_resolution_decrements()]
+        vals = [
+            str(v) for v in self._schedule.get_possible_time_resolution_decrements()
+        ]
         # Create dialog
         factor, ok = QInputDialog.getItem(
             self, "Decrease time resolution", "Factor", vals, editable=False
@@ -285,8 +287,8 @@ class ScheduleMainWindow(QMainWindow, Ui_MainWindow):
         # Check return value
         if ok:
             if int(factor) > 1:
-                self.schedule.decrease_time_resolution(int(factor))
-                self.open(self.schedule)
+                self._schedule.decrease_time_resolution(int(factor))
+                self.open(self._schedule)
                 print(f"schedule.decrease_time_resolution({factor})")
                 self.update_statusbar(f"Time resolution decreased by a factor {factor}")
         else:  # Cancelled
@@ -703,7 +705,7 @@ class ScheduleMainWindow(QMainWindow, Ui_MainWindow):
 
     @Slot()
     def _reopen_schedule(self) -> None:
-        self.open(self._schedule)
+        self.open(cast(Schedule, self._schedule))
 
     def update_statusbar(self, msg: str) -> None:
         """
@@ -945,7 +947,7 @@ class ScheduleMainWindow(QMainWindow, Ui_MainWindow):
         self.update_statusbar("Saved Preferences Loaded")
 
     def update_color_preferences(self) -> None:
-        """Update preferences of Latency color per type"""
+        """Update preferences of Latency color per type."""
         used_type_names = self._schedule.get_used_type_names()
         match (LATENCY_COLOR_TYPE.changed, self._color_changed_per_type):
             case (True, False):
@@ -975,7 +977,7 @@ class ScheduleMainWindow(QMainWindow, Ui_MainWindow):
         self.save_colortype()
 
     def save_colortype(self) -> None:
-        """Save preferences of Latency color per type in settings"""
+        """Save preferences of Latency color per type in settings."""
         settings = QSettings()
         for key, color in self._color_per_type.items():
             if self._graph:
diff --git a/b_asic/scheduler_gui/operation_item.py b/b_asic/scheduler_gui/operation_item.py
index 63778af3..c2550994 100644
--- a/b_asic/scheduler_gui/operation_item.py
+++ b/b_asic/scheduler_gui/operation_item.py
@@ -216,7 +216,7 @@ class OperationItem(QGraphicsItemGroup):
         self.setCursor(QCursor(Qt.CursorShape.OpenHandCursor))
 
     def set_font(self, font: QFont) -> None:
-        """Set the items font settings according to a give QFont."""
+        """Set the items font settings to *font*."""
         self._label_item.prepareGeometryChange()
         self._label_item.setFont(font)
         center = self._latency_item.boundingRect().center()
@@ -224,16 +224,18 @@ class OperationItem(QGraphicsItemGroup):
         self._label_item.setPos(self._latency_item.pos() + center)
 
     def set_font_color(self, color: QColor) -> None:
-        """Set the items font color settings according to a give QColor"""
+        """Set the items font color settings to *color*."""
         self._label_item.prepareGeometryChange()
         self._label_item.setBrush(color)
 
     def set_show_port_numbers(self, port_number: bool = True):
+        """Set if port numbers are shown."""
         for item in self._port_number_items:
             item.setVisible(port_number)
 
     def set_port_active(self, key: str):
-        item = self._ports[key]["item"]
+        """Set the port as active, i.e., draw it in special colors."""
+        item = cast(QPointF, self._ports[key]["item"])
         if ACTIVE_COLOR_TYPE.changed:
             self._port_filling_brush_active = QBrush(ACTIVE_COLOR_TYPE.current_color)
             self._port_outline_pen_active = QPen(ACTIVE_COLOR_TYPE.current_color)
@@ -246,7 +248,8 @@ class OperationItem(QGraphicsItemGroup):
         item.setPen(self._port_outline_pen_active)
 
     def set_port_inactive(self, key: str, warning: bool = False):
-        item = self._ports[key]["item"]
+        """Set the port as inactive, i.e., draw it in standard colors."""
+        item = cast(QPointF, self._ports[key]["item"])
         item.setBrush(
             self._port_filling_brush_warning if warning else self._port_filling_brush
         )
@@ -347,6 +350,7 @@ class OperationItem(QGraphicsItemGroup):
         self.set_inactive()
 
     def _open_context_menu(self):
+        """Create and open context menu."""
         menu = QMenu()
         swap = QAction(get_icon("swap"), "Swap")
         menu.addAction(swap)
diff --git a/docs_sphinx/codegen/vhdl.rst b/docs_sphinx/codegen/vhdl.rst
index 4dc935e6..ea3e15e4 100644
--- a/docs_sphinx/codegen/vhdl.rst
+++ b/docs_sphinx/codegen/vhdl.rst
@@ -2,6 +2,10 @@
 ``b_asic.codegen.vhdl``
 ***********************
 
+.. automodule:: b_asic.codegen.vhdl
+   :members:
+   :undoc-members:
+
 ``common`` module
 -----------------
 
-- 
GitLab