From aa3a5975889e179899fc7a43d16d3df28f878752 Mon Sep 17 00:00:00 2001 From: Oscar Gustafsson <oscar.gustafsson@gmail.com> Date: Fri, 10 Feb 2023 09:08:03 +0100 Subject: [PATCH] Improve typing --- b_asic/schedule.py | 29 ++++++++++++++------------ b_asic/scheduler_gui/compile.py | 4 ++-- b_asic/sfg_generator.py | 13 ++++++------ b_asic/signal.py | 37 ++++++++++++++++++++------------- 4 files changed, 47 insertions(+), 36 deletions(-) diff --git a/b_asic/schedule.py b/b_asic/schedule.py index 9ff02aa1..575759b2 100644 --- a/b_asic/schedule.py +++ b/b_asic/schedule.py @@ -617,7 +617,7 @@ class Schedule: self._y_locations[graph_id] = y_location return operation_gap + y_location * (operation_height + operation_gap) - def _plot_schedule(self, ax, operation_gap=None): + def _plot_schedule(self, ax, operation_gap: Optional[float] = None): """Draw the schedule.""" line_cache = [] @@ -719,7 +719,7 @@ class Schedule: ax.grid() for graph_id, op_start_time in self._start_times.items(): y_pos = self._get_y_position(graph_id, operation_gap=operation_gap) - op = self._sfg.find_by_id(graph_id) + op = cast(Operation, self._sfg.find_by_id(graph_id)) # Rewrite to make better use of NumPy ( latency_coordinates, @@ -741,10 +741,12 @@ class Schedule: linewidth=3, ) ytickpositions.append(y_pos + 0.5) - yticklabels.append(self._sfg.find_by_id(graph_id).name) + yticklabels.append( + cast(Operation, self._sfg.find_by_id(graph_id)).name + ) for graph_id, op_start_time in self._start_times.items(): - op = self._sfg.find_by_id(graph_id) + op = cast(Operation, self._sfg.find_by_id(graph_id)) out_coordinates = op.get_output_coordinates() source_y_pos = self._get_y_position( graph_id, operation_gap=operation_gap @@ -752,7 +754,8 @@ class Schedule: for output_port in op.outputs: for output_signal in output_port.signals: - destination_op = output_signal.destination.operation + destination = cast(InputPort, output_signal.destination) + destination_op = destination.operation destination_start_time = self._start_times[ destination_op.graph_id ] @@ -760,13 +763,11 @@ class Schedule: destination_op.graph_id, operation_gap=operation_gap ) destination_in_coordinates = ( - output_signal.destination.operation.get_input_coordinates() + destination.operation.get_input_coordinates() ) _draw_offset_arrow( out_coordinates[output_port.index], - destination_in_coordinates[ - output_signal.destination.index - ], + destination_in_coordinates[destination.index], [op_start_time, source_y_pos], [destination_start_time, destination_y_pos], name=graph_id, @@ -798,11 +799,13 @@ class Schedule: color="black", ) - def _reset_y_locations(self): + def _reset_y_locations(self) -> None: """Reset all the y-locations in the schedule to None""" self._y_locations = self._y_locations = defaultdict(lambda: None) - def plot_in_axes(self, ax: Axes, operation_gap: float = None) -> None: + def plot_in_axes( + self, ax: Axes, operation_gap: Optional[float] = None + ) -> None: """ Plot the schedule in a :class:`matplotlib.axes.Axes` or subclass. @@ -815,7 +818,7 @@ class Schedule: the operation is always 1. """ - def plot(self, operation_gap: float = None) -> None: + def plot(self, operation_gap: Optional[float] = None) -> None: """ Plot the schedule. Will display based on the current Matplotlib backend. @@ -827,7 +830,7 @@ class Schedule: """ self._get_figure(operation_gap=operation_gap).show() - def _get_figure(self, operation_gap: float = None) -> Figure: + def _get_figure(self, operation_gap: Optional[float] = None) -> Figure: """ Create a Figure and an Axes and plot schedule in the Axes. diff --git a/b_asic/scheduler_gui/compile.py b/b_asic/scheduler_gui/compile.py index d8cfe326..227828aa 100644 --- a/b_asic/scheduler_gui/compile.py +++ b/b_asic/scheduler_gui/compile.py @@ -60,7 +60,7 @@ def compile_rc(*filenames: str) -> None: """ _check_qt_version() - def compile(filename: str = None) -> None: + def compile(filename: str) -> None: outfile = f"{os.path.splitext(filename)[0]}_rc.py" rcc = shutil.which("pyside2-rcc") arguments = f"-g python -o {outfile} {filename}" @@ -191,7 +191,7 @@ def compile_ui(*filenames: str) -> None: compile(filename) -def compile_all(): +def compile_all() -> None: """ The compiler will search for resource (.qrc) files and form (.ui) files and compile accordingly. diff --git a/b_asic/sfg_generator.py b/b_asic/sfg_generator.py index a97f9026..c9440460 100644 --- a/b_asic/sfg_generator.py +++ b/b_asic/sfg_generator.py @@ -54,9 +54,8 @@ def wdf_allpass( ------- Signal flow graph """ - np.asarray(coefficients) - coefficients = np.squeeze(coefficients) - if coefficients.ndim != 1: + np_coefficients = np.squeeze(np.asarray(coefficients)) + if np_coefficients.ndim != 1: raise TypeError("coefficients must be a 1D-array") if input_op is None: input_op = Input() @@ -64,11 +63,11 @@ def wdf_allpass( output = Output() if name is None: name = "WDF allpass section" - order = len(coefficients) + order = len(np_coefficients) odd_order = order % 2 if odd_order: # First-order section - coeff = coefficients[0] + coeff = np_coefficients[0] adaptor0 = SymmetricTwoportAdaptor( coeff, input_op, @@ -87,7 +86,7 @@ def wdf_allpass( offset1, offset2 = (1, 2) if odd_order else (0, 1) for n in range(sos_count): adaptor1 = SymmetricTwoportAdaptor( - coefficients[2 * n + offset1], + np_coefficients[2 * n + offset1], signal_out, latency=latency, latency_offsets=latency_offsets, @@ -97,7 +96,7 @@ def wdf_allpass( delay1 = Delay(adaptor1.output(1)) delay2 = Delay() adaptor2 = SymmetricTwoportAdaptor( - coefficients[2 * n + offset2], + np_coefficients[2 * n + offset2], delay1, delay2, latency=latency, diff --git a/b_asic/signal.py b/b_asic/signal.py index 06dc2cf0..f2ce69bb 100644 --- a/b_asic/signal.py +++ b/b_asic/signal.py @@ -47,8 +47,10 @@ class Signal(AbstractGraphComponent): def __init__( self, - source: Optional["OutputPort"] = None, - destination: Optional["InputPort"] = None, + source: Optional[Union["OutputPort", "Signal", "Operation"]] = None, + destination: Optional[ + Union["InputPort", "Signal", "Operation"] + ] = None, bits: Optional[int] = None, name: Name = Name(""), ): @@ -95,7 +97,7 @@ class Signal(AbstractGraphComponent): Parameters ========== - source : OutputPort, Signal, or Operation, optional + source : OutputPort, Signal, or Operation OutputPort, Signal, or Operation to connect as source to the signal. If Signal, it will connect to the source of the signal, so later on changing the source of the argument Signal will not affect this Signal. @@ -107,13 +109,15 @@ class Signal(AbstractGraphComponent): if isinstance(source, (Signal, Operation)): # Signal or Operation - source = source.source + new_source = source.source + else: + new_source = source - if source is not self._source: + if new_source is not self._source: self.remove_source() - self._source = source - if self not in source.signals: - source.add_signal(self) + self._source = new_source + if self not in new_source.signals: + new_source.add_signal(self) def set_destination( self, destination: Union["InputPort", "Signal", "Operation"] @@ -134,15 +138,20 @@ class Signal(AbstractGraphComponent): is raised. """ - if hasattr(destination, "destination"): + # import here to avoid cyclic imports + from b_asic.operation import Operation + + if isinstance(destination, (Signal, Operation)): # Signal or Operation - destination = destination.destination + new_destination = destination.destination + else: + new_destination = destination - if destination is not self._destination: + if new_destination is not self._destination: self.remove_destination() - self._destination = destination - if self not in destination.signals: - destination.add_signal(self) + self._destination = new_destination + if self not in new_destination.signals: + new_destination.add_signal(self) def remove_source(self) -> None: """ -- GitLab