diff --git a/b_asic/schedule.py b/b_asic/schedule.py index 9ff02aa12d4191ec87f21c2976f44ac4728adc16..575759b2a73aed44a52757855641af67827da3c7 100644 --- a/b_asic/schedule.py +++ b/b_asic/schedule.py @@ -617,7 +617,7 @@ class Schedule: self._y_locations[graph_id] = y_location return operation_gap + y_location * (operation_height + operation_gap) - def _plot_schedule(self, ax, operation_gap=None): + def _plot_schedule(self, ax, operation_gap: Optional[float] = None): """Draw the schedule.""" line_cache = [] @@ -719,7 +719,7 @@ class Schedule: ax.grid() for graph_id, op_start_time in self._start_times.items(): y_pos = self._get_y_position(graph_id, operation_gap=operation_gap) - op = self._sfg.find_by_id(graph_id) + op = cast(Operation, self._sfg.find_by_id(graph_id)) # Rewrite to make better use of NumPy ( latency_coordinates, @@ -741,10 +741,12 @@ class Schedule: linewidth=3, ) ytickpositions.append(y_pos + 0.5) - yticklabels.append(self._sfg.find_by_id(graph_id).name) + yticklabels.append( + cast(Operation, self._sfg.find_by_id(graph_id)).name + ) for graph_id, op_start_time in self._start_times.items(): - op = self._sfg.find_by_id(graph_id) + op = cast(Operation, self._sfg.find_by_id(graph_id)) out_coordinates = op.get_output_coordinates() source_y_pos = self._get_y_position( graph_id, operation_gap=operation_gap @@ -752,7 +754,8 @@ class Schedule: for output_port in op.outputs: for output_signal in output_port.signals: - destination_op = output_signal.destination.operation + destination = cast(InputPort, output_signal.destination) + destination_op = destination.operation destination_start_time = self._start_times[ destination_op.graph_id ] @@ -760,13 +763,11 @@ class Schedule: destination_op.graph_id, operation_gap=operation_gap ) destination_in_coordinates = ( - output_signal.destination.operation.get_input_coordinates() + destination.operation.get_input_coordinates() ) _draw_offset_arrow( out_coordinates[output_port.index], - destination_in_coordinates[ - output_signal.destination.index - ], + destination_in_coordinates[destination.index], [op_start_time, source_y_pos], [destination_start_time, destination_y_pos], name=graph_id, @@ -798,11 +799,13 @@ class Schedule: color="black", ) - def _reset_y_locations(self): + def _reset_y_locations(self) -> None: """Reset all the y-locations in the schedule to None""" self._y_locations = self._y_locations = defaultdict(lambda: None) - def plot_in_axes(self, ax: Axes, operation_gap: float = None) -> None: + def plot_in_axes( + self, ax: Axes, operation_gap: Optional[float] = None + ) -> None: """ Plot the schedule in a :class:`matplotlib.axes.Axes` or subclass. @@ -815,7 +818,7 @@ class Schedule: the operation is always 1. """ - def plot(self, operation_gap: float = None) -> None: + def plot(self, operation_gap: Optional[float] = None) -> None: """ Plot the schedule. Will display based on the current Matplotlib backend. @@ -827,7 +830,7 @@ class Schedule: """ self._get_figure(operation_gap=operation_gap).show() - def _get_figure(self, operation_gap: float = None) -> Figure: + def _get_figure(self, operation_gap: Optional[float] = None) -> Figure: """ Create a Figure and an Axes and plot schedule in the Axes. diff --git a/b_asic/scheduler_gui/compile.py b/b_asic/scheduler_gui/compile.py index d8cfe326ddb349e152e51f3c41bee56318a20830..227828aac3e17009020b19a85379f46e5e7ab6c0 100644 --- a/b_asic/scheduler_gui/compile.py +++ b/b_asic/scheduler_gui/compile.py @@ -60,7 +60,7 @@ def compile_rc(*filenames: str) -> None: """ _check_qt_version() - def compile(filename: str = None) -> None: + def compile(filename: str) -> None: outfile = f"{os.path.splitext(filename)[0]}_rc.py" rcc = shutil.which("pyside2-rcc") arguments = f"-g python -o {outfile} {filename}" @@ -191,7 +191,7 @@ def compile_ui(*filenames: str) -> None: compile(filename) -def compile_all(): +def compile_all() -> None: """ The compiler will search for resource (.qrc) files and form (.ui) files and compile accordingly. diff --git a/b_asic/sfg_generator.py b/b_asic/sfg_generator.py index a97f902649d396984702d4424cc5d8dc61ebc7d1..c9440460acb45325a3624121f72a978d82eef8b8 100644 --- a/b_asic/sfg_generator.py +++ b/b_asic/sfg_generator.py @@ -54,9 +54,8 @@ def wdf_allpass( ------- Signal flow graph """ - np.asarray(coefficients) - coefficients = np.squeeze(coefficients) - if coefficients.ndim != 1: + np_coefficients = np.squeeze(np.asarray(coefficients)) + if np_coefficients.ndim != 1: raise TypeError("coefficients must be a 1D-array") if input_op is None: input_op = Input() @@ -64,11 +63,11 @@ def wdf_allpass( output = Output() if name is None: name = "WDF allpass section" - order = len(coefficients) + order = len(np_coefficients) odd_order = order % 2 if odd_order: # First-order section - coeff = coefficients[0] + coeff = np_coefficients[0] adaptor0 = SymmetricTwoportAdaptor( coeff, input_op, @@ -87,7 +86,7 @@ def wdf_allpass( offset1, offset2 = (1, 2) if odd_order else (0, 1) for n in range(sos_count): adaptor1 = SymmetricTwoportAdaptor( - coefficients[2 * n + offset1], + np_coefficients[2 * n + offset1], signal_out, latency=latency, latency_offsets=latency_offsets, @@ -97,7 +96,7 @@ def wdf_allpass( delay1 = Delay(adaptor1.output(1)) delay2 = Delay() adaptor2 = SymmetricTwoportAdaptor( - coefficients[2 * n + offset2], + np_coefficients[2 * n + offset2], delay1, delay2, latency=latency, diff --git a/b_asic/signal.py b/b_asic/signal.py index 06dc2cf05ee79c9af3946d6779b75a5924e3f2be..f2ce69bb8df4930a7a28dad4af4069758fe387ee 100644 --- a/b_asic/signal.py +++ b/b_asic/signal.py @@ -47,8 +47,10 @@ class Signal(AbstractGraphComponent): def __init__( self, - source: Optional["OutputPort"] = None, - destination: Optional["InputPort"] = None, + source: Optional[Union["OutputPort", "Signal", "Operation"]] = None, + destination: Optional[ + Union["InputPort", "Signal", "Operation"] + ] = None, bits: Optional[int] = None, name: Name = Name(""), ): @@ -95,7 +97,7 @@ class Signal(AbstractGraphComponent): Parameters ========== - source : OutputPort, Signal, or Operation, optional + source : OutputPort, Signal, or Operation OutputPort, Signal, or Operation to connect as source to the signal. If Signal, it will connect to the source of the signal, so later on changing the source of the argument Signal will not affect this Signal. @@ -107,13 +109,15 @@ class Signal(AbstractGraphComponent): if isinstance(source, (Signal, Operation)): # Signal or Operation - source = source.source + new_source = source.source + else: + new_source = source - if source is not self._source: + if new_source is not self._source: self.remove_source() - self._source = source - if self not in source.signals: - source.add_signal(self) + self._source = new_source + if self not in new_source.signals: + new_source.add_signal(self) def set_destination( self, destination: Union["InputPort", "Signal", "Operation"] @@ -134,15 +138,20 @@ class Signal(AbstractGraphComponent): is raised. """ - if hasattr(destination, "destination"): + # import here to avoid cyclic imports + from b_asic.operation import Operation + + if isinstance(destination, (Signal, Operation)): # Signal or Operation - destination = destination.destination + new_destination = destination.destination + else: + new_destination = destination - if destination is not self._destination: + if new_destination is not self._destination: self.remove_destination() - self._destination = destination - if self not in destination.signals: - destination.add_signal(self) + self._destination = new_destination + if self not in new_destination.signals: + new_destination.add_signal(self) def remove_source(self) -> None: """