Skip to content
Snippets Groups Projects
Commit 93f159ce authored by Oscar Gustafsson's avatar Oscar Gustafsson :bicyclist:
Browse files

Cleanup code and add typing

parent a8b79691
No related branches found
No related tags found
1 merge request!234Cleanup code and add typing
Pipeline #90367 passed
...@@ -9,7 +9,7 @@ import logging ...@@ -9,7 +9,7 @@ import logging
import os import os
import sys import sys
from pprint import pprint from pprint import pprint
from typing import Optional, Tuple from typing import List, Optional, Tuple
from qtpy.QtCore import QFileInfo, QSize, Qt from qtpy.QtCore import QFileInfo, QSize, Qt
from qtpy.QtGui import QCursor, QIcon, QKeySequence, QPainter from qtpy.QtGui import QCursor, QIcon, QKeySequence, QPainter
...@@ -47,6 +47,7 @@ from b_asic.gui_utils.plot_window import PlotWindow ...@@ -47,6 +47,7 @@ from b_asic.gui_utils.plot_window import PlotWindow
from b_asic.operation import Operation from b_asic.operation import Operation
from b_asic.port import InputPort, OutputPort from b_asic.port import InputPort, OutputPort
from b_asic.save_load_structure import python_to_sfg, sfg_to_python from b_asic.save_load_structure import python_to_sfg, sfg_to_python
from b_asic.signal import Signal
from b_asic.signal_flow_graph import SFG from b_asic.signal_flow_graph import SFG
# from b_asic import FastSimulation # from b_asic import FastSimulation
...@@ -158,13 +159,12 @@ class MainWindow(QMainWindow): ...@@ -158,13 +159,12 @@ class MainWindow(QMainWindow):
self.toolbar.addAction("Clear workspace", self.clear_workspace) self.toolbar.addAction("Clear workspace", self.clear_workspace)
def resizeEvent(self, event) -> None: def resizeEvent(self, event) -> None:
self.ui.operation_box.setGeometry( ui_width = self.ui.operation_box.width()
10, 10, self.ui.operation_box.width(), self.height() self.ui.operation_box.setGeometry(10, 10, ui_width, self.height())
)
self.graphic_view.setGeometry( self.graphic_view.setGeometry(
self.ui.operation_box.width() + 20, ui_width + 20,
60, 60,
self.width() - self.ui.operation_box.width() - 20, self.width() - ui_width - 20,
self.height() - 30, self.height() - 30,
) )
super().resizeEvent(event) super().resizeEvent(event)
...@@ -214,20 +214,20 @@ class MainWindow(QMainWindow): ...@@ -214,20 +214,20 @@ class MainWindow(QMainWindow):
self.logger.info("Saved SFG to path: " + str(module)) self.logger.info("Saved SFG to path: " + str(module))
def save_work(self, event=None): def save_work(self, event=None) -> None:
self.sfg_widget = SelectSFGWindow(self) self.sfg_widget = SelectSFGWindow(self)
self.sfg_widget.show() self.sfg_widget.show()
# Wait for input to dialog. # Wait for input to dialog.
self.sfg_widget.ok.connect(self._save_work) self.sfg_widget.ok.connect(self._save_work)
def load_work(self, event=None): def load_work(self, event=None) -> None:
module, accepted = QFileDialog().getOpenFileName() module, accepted = QFileDialog().getOpenFileName()
if not accepted: if not accepted:
return return
self._load_from_file(module) self._load_from_file(module)
def _load_from_file(self, module): def _load_from_file(self, module) -> None:
self.logger.info("Loading SFG from path: " + str(module)) self.logger.info("Loading SFG from path: " + str(module))
try: try:
sfg, positions = python_to_sfg(module) sfg, positions = python_to_sfg(module)
...@@ -252,7 +252,7 @@ class MainWindow(QMainWindow): ...@@ -252,7 +252,7 @@ class MainWindow(QMainWindow):
self._load_sfg(sfg, positions) self._load_sfg(sfg, positions)
self.logger.info("Loaded SFG from path: " + str(module)) self.logger.info("Loaded SFG from path: " + str(module))
def _load_sfg(self, sfg, positions=None): def _load_sfg(self, sfg, positions=None) -> None:
if positions is None: if positions is None:
positions = {} positions = {}
...@@ -299,11 +299,11 @@ class MainWindow(QMainWindow): ...@@ -299,11 +299,11 @@ class MainWindow(QMainWindow):
self.sfg_dict[sfg.name] = sfg self.sfg_dict[sfg.name] = sfg
self.update() self.update()
def exit_app(self): def exit_app(self) -> None:
self.logger.info("Exiting the application.") self.logger.info("Exiting the application.")
QApplication.quit() QApplication.quit()
def clear_workspace(self): def clear_workspace(self) -> None:
self.logger.info("Clearing workspace from operations and SFGs.") self.logger.info("Clearing workspace from operations and SFGs.")
self.pressed_operations.clear() self.pressed_operations.clear()
self.pressed_ports.clear() self.pressed_ports.clear()
...@@ -341,7 +341,7 @@ class MainWindow(QMainWindow): ...@@ -341,7 +341,7 @@ class MainWindow(QMainWindow):
sfg = SFG(inputs=inputs, outputs=outputs, name=name) sfg = SFG(inputs=inputs, outputs=outputs, name=name)
self.logger.info("Created SFG with name: %s from selected operations." % name) self.logger.info("Created SFG with name: %s from selected operations." % name)
def check_equality(signal, signal_2): def check_equality(signal: Signal, signal_2: Signal) -> bool:
if not ( if not (
signal.source.operation.type_name() signal.source.operation.type_name()
== signal_2.source.operation.type_name() == signal_2.source.operation.type_name()
...@@ -440,11 +440,11 @@ class MainWindow(QMainWindow): ...@@ -440,11 +440,11 @@ class MainWindow(QMainWindow):
self.sfg_dict[sfg.name] = sfg self.sfg_dict[sfg.name] = sfg
def _show_precedence_graph(self, event=None) -> None: def _show_precedence_graph(self, event=None) -> None:
self.dialog = ShowPCWindow(self) self._precedence_graph_dialog = ShowPCWindow(self)
self.dialog.add_sfg_to_dialog() self._precedence_graph_dialog.add_sfg_to_dialog()
self.dialog.show() self._precedence_graph_dialog.show()
def get_operations_from_namespace(self, namespace) -> None: def get_operations_from_namespace(self, namespace) -> List[str]:
self.logger.info( self.logger.info(
"Fetching operations from namespace: " + str(namespace.__name__) "Fetching operations from namespace: " + str(namespace.__name__)
) )
...@@ -667,11 +667,11 @@ class MainWindow(QMainWindow): ...@@ -667,11 +667,11 @@ class MainWindow(QMainWindow):
self.update() self.update()
def paintEvent(self, event): def paintEvent(self, event) -> None:
for signal in self.signalPortDict.keys(): for signal in self.signalPortDict.keys():
signal.moveLine() signal.moveLine()
def _select_operations(self): def _select_operations(self) -> None:
selected = [button.widget() for button in self.scene.selectedItems()] selected = [button.widget() for button in self.scene.selectedItems()]
for button in selected: for button in selected:
button._toggle_button(pressed=False) button._toggle_button(pressed=False)
...@@ -682,8 +682,8 @@ class MainWindow(QMainWindow): ...@@ -682,8 +682,8 @@ class MainWindow(QMainWindow):
self.pressed_operations = selected self.pressed_operations = selected
def _simulate_sfg(self): def _simulate_sfg(self) -> None:
for sfg, properties in self.dialog.properties.items(): for sfg, properties in self._simulation_dialog.properties.items():
self.logger.info("Simulating SFG with name: %s" % str(sfg.name)) self.logger.info("Simulating SFG with name: %s" % str(sfg.name))
simulation = FastSimulation(sfg, input_providers=properties["input_values"]) simulation = FastSimulation(sfg, input_providers=properties["input_values"])
l_result = simulation.run_for( l_result = simulation.run_for(
...@@ -697,36 +697,32 @@ class MainWindow(QMainWindow): ...@@ -697,36 +697,32 @@ class MainWindow(QMainWindow):
if properties["show_plot"]: if properties["show_plot"]:
self.logger.info("Opening plot for SFG with name: " + str(sfg.name)) self.logger.info("Opening plot for SFG with name: " + str(sfg.name))
self.logger.info( self._plot = PlotWindow(simulation.results, sfg_name=sfg.name)
"To save the plot press 'Ctrl+S' when the plot is focused." self._plot.show()
)
# self.plot = Plot(simulation, sfg, self)
self.plot = PlotWindow(simulation.results)
self.plot.show()
def simulate_sfg(self, event=None): def simulate_sfg(self, event=None) -> None:
self.dialog = SimulateSFGWindow(self) self._simulation_dialog = SimulateSFGWindow(self)
for _, sfg in self.sfg_dict.items(): for _, sfg in self.sfg_dict.items():
self.dialog.add_sfg_to_dialog(sfg) self._simulation_dialog.add_sfg_to_dialog(sfg)
self.dialog.show() self._simulation_dialog.show()
# Wait for input to dialog. # Wait for input to dialog.
# Kinda buggy because of the separate window in the same thread. # Kinda buggy because of the separate window in the same thread.
self.dialog.simulate.connect(self._simulate_sfg) self._simulation_dialog.simulate.connect(self._simulate_sfg)
def display_faq_page(self, event=None): def display_faq_page(self, event=None) -> None:
self.faq_page = FaqWindow(self) self._faq_page = FaqWindow(self)
self.faq_page.scroll_area.show() self._faq_page.scroll_area.show()
def display_about_page(self, event=None): def display_about_page(self, event=None) -> None:
self.about_page = AboutWindow(self) self._about_page = AboutWindow(self)
self.about_page.show() self._about_page.show()
def display_keybinds_page(self, event=None): def display_keybinds_page(self, event=None) -> None:
self.keybinds_page = KeybindsWindow(self) self._keybinds_page = KeybindsWindow(self)
self.keybinds_page.show() self._keybinds_page.show()
def start_gui(): def start_gui():
......
...@@ -32,6 +32,19 @@ class SignalGeneratorInput(QGridLayout): ...@@ -32,6 +32,19 @@ class SignalGeneratorInput(QGridLayout):
"""Return the SignalGenerator based on the graphical input.""" """Return the SignalGenerator based on the graphical input."""
raise NotImplementedError raise NotImplementedError
def _parse_number(self, string, _type, name, default):
string = string.strip()
try:
if not string:
return default
return _type(string)
except ValueError:
self._logger.warning(
f"Cannot parse {name}: {string} not a {_type.__name__}, setting to"
f" {default}"
)
return default
class DelayInput(SignalGeneratorInput): class DelayInput(SignalGeneratorInput):
""" """
...@@ -82,6 +95,7 @@ class ZeroPadInput(SignalGeneratorInput): ...@@ -82,6 +95,7 @@ class ZeroPadInput(SignalGeneratorInput):
self.input_label = QLabel("Input") self.input_label = QLabel("Input")
self.addWidget(self.input_label, 0, 0) self.addWidget(self.input_label, 0, 0)
self.input_sequence = QLineEdit() self.input_sequence = QLineEdit()
self.input_sequence.setPlaceholderText("0.1, -0.2, 0.7")
self.addWidget(self.input_sequence, 0, 1) self.addWidget(self.input_sequence, 0, 1)
def get_generator(self) -> SignalGenerator: def get_generator(self) -> SignalGenerator:
...@@ -91,14 +105,11 @@ class ZeroPadInput(SignalGeneratorInput): ...@@ -91,14 +105,11 @@ class ZeroPadInput(SignalGeneratorInput):
try: try:
if not val: if not val:
val = 0 val = 0
val = complex(val) val = complex(val)
except ValueError: except ValueError:
self._logger.warning(f"Skipping value: {val}, not a digit.") self._logger.warning(f"Skipping value: {val}, not a digit.")
continue continue
input_values.append(val) input_values.append(val)
return ZeroPad(input_values) return ZeroPad(input_values)
...@@ -138,34 +149,20 @@ class SinusoidInput(SignalGeneratorInput): ...@@ -138,34 +149,20 @@ class SinusoidInput(SignalGeneratorInput):
self.frequency_label = QLabel("Frequency") self.frequency_label = QLabel("Frequency")
self.addWidget(self.frequency_label, 0, 0) self.addWidget(self.frequency_label, 0, 0)
self.frequency_input = QLineEdit() self.frequency_input = QLineEdit()
self.frequency_input.setText("0.1")
self.addWidget(self.frequency_input, 0, 1) self.addWidget(self.frequency_input, 0, 1)
self.phase_label = QLabel("Phase") self.phase_label = QLabel("Phase")
self.addWidget(self.phase_label, 1, 0) self.addWidget(self.phase_label, 1, 0)
self.phase_input = QLineEdit() self.phase_input = QLineEdit()
self.phase_input.setText("0.0")
self.addWidget(self.phase_input, 1, 1) self.addWidget(self.phase_input, 1, 1)
def get_generator(self) -> SignalGenerator: def get_generator(self) -> SignalGenerator:
frequency = self.frequency_input.text().strip() frequency = self._parse_number(
try: self.frequency_input.text(), float, "Frequency", 0.1
if not frequency: )
frequency = 0.1 phase = self._parse_number(self.phase_input.text(), float, "Phase", 0.0)
frequency = float(frequency)
except ValueError:
self._logger.warning(f"Cannot parse frequency: {frequency} not a number.")
frequency = 0.1
phase = self.phase_input.text().strip()
try:
if not phase:
phase = 0
phase = float(phase)
except ValueError:
self._logger.warning(f"Cannot parse phase: {phase} not a number.")
phase = 0
return Sinusoid(frequency, phase) return Sinusoid(frequency, phase)
...@@ -196,26 +193,10 @@ class GaussianInput(SignalGeneratorInput): ...@@ -196,26 +193,10 @@ class GaussianInput(SignalGeneratorInput):
self.addWidget(self.seed_spin_box, 2, 1) self.addWidget(self.seed_spin_box, 2, 1)
def get_generator(self) -> SignalGenerator: def get_generator(self) -> SignalGenerator:
scale = self.scale_input.text().strip() scale = self._parse_number(
try: self.scale_input.text(), float, "Standard deviation", 1.0
if not scale: )
scale = 1 loc = self._parse_number(self.loc_input.text(), float, "Average value", 0.0)
scale = float(scale)
except ValueError:
self._logger.warning(f"Cannot parse scale: {scale} not a number.")
scale = 1
loc = self.loc_input.text().strip()
try:
if not loc:
loc = 0
loc = float(loc)
except ValueError:
self._logger.warning(f"Cannot parse loc: {loc} not a number.")
loc = 0
return Gaussian(self.seed_spin_box.value(), loc, scale) return Gaussian(self.seed_spin_box.value(), loc, scale)
...@@ -246,26 +227,8 @@ class UniformInput(SignalGeneratorInput): ...@@ -246,26 +227,8 @@ class UniformInput(SignalGeneratorInput):
self.addWidget(self.seed_spin_box, 2, 1) self.addWidget(self.seed_spin_box, 2, 1)
def get_generator(self) -> SignalGenerator: def get_generator(self) -> SignalGenerator:
low = self.low_input.text().strip() low = self._parse_number(self.low_input.text(), float, "Lower bound", -1.0)
try: high = self._parse_number(self.high_input.text(), float, "Upper bound", 1.0)
if not low:
low = -1.0
low = float(low)
except ValueError:
self._logger.warning(f"Cannot parse low: {low} not a number.")
low = -1.0
high = self.high_input.text().strip()
try:
if not high:
high = 1.0
high = float(high)
except ValueError:
self._logger.warning(f"Cannot parse high: {high} not a number.")
high = 1.0
return Uniform(self.seed_spin_box.value(), low, high) return Uniform(self.seed_spin_box.value(), low, high)
...@@ -284,16 +247,9 @@ class ConstantInput(SignalGeneratorInput): ...@@ -284,16 +247,9 @@ class ConstantInput(SignalGeneratorInput):
self.addWidget(self.constant_input, 0, 1) self.addWidget(self.constant_input, 0, 1)
def get_generator(self) -> SignalGenerator: def get_generator(self) -> SignalGenerator:
constant = self.constant_input.text().strip() constant = self._parse_number(
try: self.constant_input.text(), complex, "Constant", 1.0
if not constant: )
constant = 1.0
constant = complex(constant)
except ValueError:
self._logger.warning(f"Cannot parse constant: {constant} not a number.")
constant = 0.0
return Constant(constant) return Constant(constant)
......
""" """
B-ASIC window to simulate an SFG. B-ASIC window to simulate an SFG.
""" """
import numpy as np
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from qtpy.QtCore import Qt, Signal from qtpy.QtCore import Qt, Signal
from qtpy.QtGui import QKeySequence
from qtpy.QtWidgets import ( from qtpy.QtWidgets import (
QCheckBox, QCheckBox,
QComboBox, QComboBox,
QDialog, QDialog,
QFileDialog,
QFormLayout, QFormLayout,
QFrame, QFrame,
QGridLayout, QGridLayout,
QHBoxLayout,
QLabel, QLabel,
QLayout, QLayout,
QLineEdit,
QPushButton, QPushButton,
QShortcut,
QSizePolicy,
QSpinBox, QSpinBox,
QVBoxLayout, QVBoxLayout,
) )
from b_asic.GUI.signal_generator_input import _GENERATOR_MAPPING from b_asic.GUI.signal_generator_input import _GENERATOR_MAPPING
from b_asic.signal_generator import FromFile
class SimulateSFGWindow(QDialog): class SimulateSFGWindow(QDialog):
...@@ -50,7 +40,7 @@ class SimulateSFGWindow(QDialog): ...@@ -50,7 +40,7 @@ class SimulateSFGWindow(QDialog):
self.input_grid = QGridLayout() self.input_grid = QGridLayout()
self.input_files = {} self.input_files = {}
def add_sfg_to_dialog(self, sfg): def add_sfg_to_dialog(self, sfg) -> None:
sfg_layout = QVBoxLayout() sfg_layout = QVBoxLayout()
options_layout = QFormLayout() options_layout = QFormLayout()
...@@ -112,7 +102,7 @@ class SimulateSFGWindow(QDialog): ...@@ -112,7 +102,7 @@ class SimulateSFGWindow(QDialog):
self.sfg_to_layout[sfg] = sfg_layout self.sfg_to_layout[sfg] = sfg_layout
self.dialog_layout.addLayout(sfg_layout) self.dialog_layout.addLayout(sfg_layout)
def change_input_format(self, i, text): def change_input_format(self, i: int, text: str) -> None:
grid = self.input_grid.itemAtPosition(i, 2) grid = self.input_grid.itemAtPosition(i, 2)
if grid: if grid:
for j in reversed(range(grid.count())): for j in reversed(range(grid.count())):
...@@ -127,13 +117,11 @@ class SimulateSFGWindow(QDialog): ...@@ -127,13 +117,11 @@ class SimulateSFGWindow(QDialog):
if text in _GENERATOR_MAPPING: if text in _GENERATOR_MAPPING:
param_grid = _GENERATOR_MAPPING[text](self._window.logger) param_grid = _GENERATOR_MAPPING[text](self._window.logger)
else: else:
raise Exception("Input selection is not implemented") raise ValueError("Input selection is not implemented")
self.input_grid.addLayout(param_grid, i, 2) self.input_grid.addLayout(param_grid, i, 2)
return def save_properties(self) -> None:
def save_properties(self):
for sfg, _properties in self.input_fields.items(): for sfg, _properties in self.input_fields.items():
ic_value = self.input_fields[sfg]["iteration_count"].value() ic_value = self.input_fields[sfg]["iteration_count"].value()
if ic_value == 0: if ic_value == 0:
...@@ -148,7 +136,7 @@ class SimulateSFGWindow(QDialog): ...@@ -148,7 +136,7 @@ class SimulateSFGWindow(QDialog):
if in_format in _GENERATOR_MAPPING: if in_format in _GENERATOR_MAPPING:
tmp2 = in_param.get_generator() tmp2 = in_param.get_generator()
else: else:
raise Exception("Input selection is not implemented") raise ValueError("Input selection is not implemented")
input_values.append(tmp2) input_values.append(tmp2)
...@@ -166,45 +154,3 @@ class SimulateSFGWindow(QDialog): ...@@ -166,45 +154,3 @@ class SimulateSFGWindow(QDialog):
self.accept() self.accept()
self.simulate.emit() self.simulate.emit()
class Plot(FigureCanvas):
def __init__(
self, simulation, sfg, window, parent=None, width=5, height=4, dpi=100
):
self.simulation = simulation
self.sfg = sfg
self.dpi = dpi
self._window = window
fig = Figure(figsize=(width, height), dpi=dpi)
fig.suptitle(sfg.name, fontsize=20)
self.axes = fig.add_subplot(111)
FigureCanvas.__init__(self, fig)
self.setParent(parent)
FigureCanvas.setSizePolicy(self, QSizePolicy.Expanding, QSizePolicy.Expanding)
FigureCanvas.updateGeometry(self)
self.save_figure = QShortcut(QKeySequence("Ctrl+S"), self)
self.save_figure.activated.connect(self._save_plot_figure)
self._plot_values_sfg()
def _save_plot_figure(self):
self._window.logger.info(f"Saving plot of figure: {self.sfg.name}.")
file_choices = "PNG (*.png)|*.png"
path, ext = QFileDialog.getSaveFileName(self, "Save file", "", file_choices)
path = path.encode("utf-8")
if not path[-4:] == file_choices[-4:].encode("utf-8"):
path += file_choices[-4:].encode("utf-8")
if path:
self.print_figure(path.decode(), dpi=self.dpi)
self._window.logger.info(f"Saved plot: {self.sfg.name} to path: {path}.")
def _plot_values_sfg(self):
x_axis = list(range(len(self.simulation.results["0"])))
for _output in range(self.sfg.output_count):
y_axis = self.simulation.results[str(_output)]
self.axes.plot(x_axis, y_axis)
...@@ -141,9 +141,9 @@ def test_help_dialogs(qtbot): ...@@ -141,9 +141,9 @@ def test_help_dialogs(qtbot):
widget.display_about_page() widget.display_about_page()
widget.display_keybinds_page() widget.display_keybinds_page()
qtbot.wait(100) qtbot.wait(100)
widget.faq_page.close() widget._faq_page.close()
widget.about_page.close() widget._about_page.close()
widget.keybinds_page.close() widget._keybinds_page.close()
widget.exit_app() widget.exit_app()
...@@ -159,7 +159,7 @@ def test_simulate(qtbot, datadir): ...@@ -159,7 +159,7 @@ def test_simulate(qtbot, datadir):
qtbot.wait(100) qtbot.wait(100)
# widget.dialog.save_properties() # widget.dialog.save_properties()
# qtbot.wait(100) # qtbot.wait(100)
widget.dialog.close() widget._simulation_dialog.close()
widget.exit_app() widget.exit_app()
......
...@@ -27,8 +27,8 @@ def test_MemoryVariables(secondorder_iir_schedule): ...@@ -27,8 +27,8 @@ def test_MemoryVariables(secondorder_iir_schedule):
pc = secondorder_iir_schedule.get_memory_variables() pc = secondorder_iir_schedule.get_memory_variables()
mem_vars = pc.collection mem_vars = pc.collection
pattern = re.compile( pattern = re.compile(
"MemoryVariable\\(3, <b_asic.port.OutputPort object at 0x[a-f0-9]+>," "MemoryVariable\\(3, <b_asic.port.OutputPort object at 0x[a-fA-F0-9]+>,"
" {<b_asic.port.InputPort object at 0x[a-f0-9]+>: 4}, 'cmul1.0'\\)" " {<b_asic.port.InputPort object at 0x[a-fA-F0-9]+>: 4}, 'cmul1.0'\\)"
) )
mem_var = [m for m in mem_vars if m.name == 'cmul1.0'][0] mem_var = [m for m in mem_vars if m.name == 'cmul1.0'][0]
assert pattern.match(repr(mem_var)) assert pattern.match(repr(mem_var))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment