Skip to content
Snippets Groups Projects
Commit aa3a5975 authored by Oscar Gustafsson's avatar Oscar Gustafsson :bicyclist:
Browse files

Improve typing

parent a8edfcd9
No related branches found
No related tags found
1 merge request!173Improve typing
Pipeline #89223 passed
...@@ -617,7 +617,7 @@ class Schedule: ...@@ -617,7 +617,7 @@ class Schedule:
self._y_locations[graph_id] = y_location self._y_locations[graph_id] = y_location
return operation_gap + y_location * (operation_height + operation_gap) return operation_gap + y_location * (operation_height + operation_gap)
def _plot_schedule(self, ax, operation_gap=None): def _plot_schedule(self, ax, operation_gap: Optional[float] = None):
"""Draw the schedule.""" """Draw the schedule."""
line_cache = [] line_cache = []
...@@ -719,7 +719,7 @@ class Schedule: ...@@ -719,7 +719,7 @@ class Schedule:
ax.grid() ax.grid()
for graph_id, op_start_time in self._start_times.items(): for graph_id, op_start_time in self._start_times.items():
y_pos = self._get_y_position(graph_id, operation_gap=operation_gap) y_pos = self._get_y_position(graph_id, operation_gap=operation_gap)
op = self._sfg.find_by_id(graph_id) op = cast(Operation, self._sfg.find_by_id(graph_id))
# Rewrite to make better use of NumPy # Rewrite to make better use of NumPy
( (
latency_coordinates, latency_coordinates,
...@@ -741,10 +741,12 @@ class Schedule: ...@@ -741,10 +741,12 @@ class Schedule:
linewidth=3, linewidth=3,
) )
ytickpositions.append(y_pos + 0.5) ytickpositions.append(y_pos + 0.5)
yticklabels.append(self._sfg.find_by_id(graph_id).name) yticklabels.append(
cast(Operation, self._sfg.find_by_id(graph_id)).name
)
for graph_id, op_start_time in self._start_times.items(): for graph_id, op_start_time in self._start_times.items():
op = self._sfg.find_by_id(graph_id) op = cast(Operation, self._sfg.find_by_id(graph_id))
out_coordinates = op.get_output_coordinates() out_coordinates = op.get_output_coordinates()
source_y_pos = self._get_y_position( source_y_pos = self._get_y_position(
graph_id, operation_gap=operation_gap graph_id, operation_gap=operation_gap
...@@ -752,7 +754,8 @@ class Schedule: ...@@ -752,7 +754,8 @@ class Schedule:
for output_port in op.outputs: for output_port in op.outputs:
for output_signal in output_port.signals: for output_signal in output_port.signals:
destination_op = output_signal.destination.operation destination = cast(InputPort, output_signal.destination)
destination_op = destination.operation
destination_start_time = self._start_times[ destination_start_time = self._start_times[
destination_op.graph_id destination_op.graph_id
] ]
...@@ -760,13 +763,11 @@ class Schedule: ...@@ -760,13 +763,11 @@ class Schedule:
destination_op.graph_id, operation_gap=operation_gap destination_op.graph_id, operation_gap=operation_gap
) )
destination_in_coordinates = ( destination_in_coordinates = (
output_signal.destination.operation.get_input_coordinates() destination.operation.get_input_coordinates()
) )
_draw_offset_arrow( _draw_offset_arrow(
out_coordinates[output_port.index], out_coordinates[output_port.index],
destination_in_coordinates[ destination_in_coordinates[destination.index],
output_signal.destination.index
],
[op_start_time, source_y_pos], [op_start_time, source_y_pos],
[destination_start_time, destination_y_pos], [destination_start_time, destination_y_pos],
name=graph_id, name=graph_id,
...@@ -798,11 +799,13 @@ class Schedule: ...@@ -798,11 +799,13 @@ class Schedule:
color="black", color="black",
) )
def _reset_y_locations(self): def _reset_y_locations(self) -> None:
"""Reset all the y-locations in the schedule to None""" """Reset all the y-locations in the schedule to None"""
self._y_locations = self._y_locations = defaultdict(lambda: None) self._y_locations = self._y_locations = defaultdict(lambda: None)
def plot_in_axes(self, ax: Axes, operation_gap: float = None) -> None: def plot_in_axes(
self, ax: Axes, operation_gap: Optional[float] = None
) -> None:
""" """
Plot the schedule in a :class:`matplotlib.axes.Axes` or subclass. Plot the schedule in a :class:`matplotlib.axes.Axes` or subclass.
...@@ -815,7 +818,7 @@ class Schedule: ...@@ -815,7 +818,7 @@ class Schedule:
the operation is always 1. the operation is always 1.
""" """
def plot(self, operation_gap: float = None) -> None: def plot(self, operation_gap: Optional[float] = None) -> None:
""" """
Plot the schedule. Will display based on the current Matplotlib backend. Plot the schedule. Will display based on the current Matplotlib backend.
...@@ -827,7 +830,7 @@ class Schedule: ...@@ -827,7 +830,7 @@ class Schedule:
""" """
self._get_figure(operation_gap=operation_gap).show() self._get_figure(operation_gap=operation_gap).show()
def _get_figure(self, operation_gap: float = None) -> Figure: def _get_figure(self, operation_gap: Optional[float] = None) -> Figure:
""" """
Create a Figure and an Axes and plot schedule in the Axes. Create a Figure and an Axes and plot schedule in the Axes.
......
...@@ -60,7 +60,7 @@ def compile_rc(*filenames: str) -> None: ...@@ -60,7 +60,7 @@ def compile_rc(*filenames: str) -> None:
""" """
_check_qt_version() _check_qt_version()
def compile(filename: str = None) -> None: def compile(filename: str) -> None:
outfile = f"{os.path.splitext(filename)[0]}_rc.py" outfile = f"{os.path.splitext(filename)[0]}_rc.py"
rcc = shutil.which("pyside2-rcc") rcc = shutil.which("pyside2-rcc")
arguments = f"-g python -o {outfile} {filename}" arguments = f"-g python -o {outfile} {filename}"
...@@ -191,7 +191,7 @@ def compile_ui(*filenames: str) -> None: ...@@ -191,7 +191,7 @@ def compile_ui(*filenames: str) -> None:
compile(filename) compile(filename)
def compile_all(): def compile_all() -> None:
""" """
The compiler will search for resource (.qrc) files and form (.ui) files The compiler will search for resource (.qrc) files and form (.ui) files
and compile accordingly. and compile accordingly.
......
...@@ -54,9 +54,8 @@ def wdf_allpass( ...@@ -54,9 +54,8 @@ def wdf_allpass(
------- -------
Signal flow graph Signal flow graph
""" """
np.asarray(coefficients) np_coefficients = np.squeeze(np.asarray(coefficients))
coefficients = np.squeeze(coefficients) if np_coefficients.ndim != 1:
if coefficients.ndim != 1:
raise TypeError("coefficients must be a 1D-array") raise TypeError("coefficients must be a 1D-array")
if input_op is None: if input_op is None:
input_op = Input() input_op = Input()
...@@ -64,11 +63,11 @@ def wdf_allpass( ...@@ -64,11 +63,11 @@ def wdf_allpass(
output = Output() output = Output()
if name is None: if name is None:
name = "WDF allpass section" name = "WDF allpass section"
order = len(coefficients) order = len(np_coefficients)
odd_order = order % 2 odd_order = order % 2
if odd_order: if odd_order:
# First-order section # First-order section
coeff = coefficients[0] coeff = np_coefficients[0]
adaptor0 = SymmetricTwoportAdaptor( adaptor0 = SymmetricTwoportAdaptor(
coeff, coeff,
input_op, input_op,
...@@ -87,7 +86,7 @@ def wdf_allpass( ...@@ -87,7 +86,7 @@ def wdf_allpass(
offset1, offset2 = (1, 2) if odd_order else (0, 1) offset1, offset2 = (1, 2) if odd_order else (0, 1)
for n in range(sos_count): for n in range(sos_count):
adaptor1 = SymmetricTwoportAdaptor( adaptor1 = SymmetricTwoportAdaptor(
coefficients[2 * n + offset1], np_coefficients[2 * n + offset1],
signal_out, signal_out,
latency=latency, latency=latency,
latency_offsets=latency_offsets, latency_offsets=latency_offsets,
...@@ -97,7 +96,7 @@ def wdf_allpass( ...@@ -97,7 +96,7 @@ def wdf_allpass(
delay1 = Delay(adaptor1.output(1)) delay1 = Delay(adaptor1.output(1))
delay2 = Delay() delay2 = Delay()
adaptor2 = SymmetricTwoportAdaptor( adaptor2 = SymmetricTwoportAdaptor(
coefficients[2 * n + offset2], np_coefficients[2 * n + offset2],
delay1, delay1,
delay2, delay2,
latency=latency, latency=latency,
......
...@@ -47,8 +47,10 @@ class Signal(AbstractGraphComponent): ...@@ -47,8 +47,10 @@ class Signal(AbstractGraphComponent):
def __init__( def __init__(
self, self,
source: Optional["OutputPort"] = None, source: Optional[Union["OutputPort", "Signal", "Operation"]] = None,
destination: Optional["InputPort"] = None, destination: Optional[
Union["InputPort", "Signal", "Operation"]
] = None,
bits: Optional[int] = None, bits: Optional[int] = None,
name: Name = Name(""), name: Name = Name(""),
): ):
...@@ -95,7 +97,7 @@ class Signal(AbstractGraphComponent): ...@@ -95,7 +97,7 @@ class Signal(AbstractGraphComponent):
Parameters Parameters
========== ==========
source : OutputPort, Signal, or Operation, optional source : OutputPort, Signal, or Operation
OutputPort, Signal, or Operation to connect as source to the signal. OutputPort, Signal, or Operation to connect as source to the signal.
If Signal, it will connect to the source of the signal, so later on If Signal, it will connect to the source of the signal, so later on
changing the source of the argument Signal will not affect this Signal. changing the source of the argument Signal will not affect this Signal.
...@@ -107,13 +109,15 @@ class Signal(AbstractGraphComponent): ...@@ -107,13 +109,15 @@ class Signal(AbstractGraphComponent):
if isinstance(source, (Signal, Operation)): if isinstance(source, (Signal, Operation)):
# Signal or Operation # Signal or Operation
source = source.source new_source = source.source
else:
new_source = source
if source is not self._source: if new_source is not self._source:
self.remove_source() self.remove_source()
self._source = source self._source = new_source
if self not in source.signals: if self not in new_source.signals:
source.add_signal(self) new_source.add_signal(self)
def set_destination( def set_destination(
self, destination: Union["InputPort", "Signal", "Operation"] self, destination: Union["InputPort", "Signal", "Operation"]
...@@ -134,15 +138,20 @@ class Signal(AbstractGraphComponent): ...@@ -134,15 +138,20 @@ class Signal(AbstractGraphComponent):
is raised. is raised.
""" """
if hasattr(destination, "destination"): # import here to avoid cyclic imports
from b_asic.operation import Operation
if isinstance(destination, (Signal, Operation)):
# Signal or Operation # Signal or Operation
destination = destination.destination new_destination = destination.destination
else:
new_destination = destination
if destination is not self._destination: if new_destination is not self._destination:
self.remove_destination() self.remove_destination()
self._destination = destination self._destination = new_destination
if self not in destination.signals: if self not in new_destination.signals:
destination.add_signal(self) new_destination.add_signal(self)
def remove_source(self) -> None: def remove_source(self) -> None:
""" """
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment