diff --git a/b_asic/GUI/drag_button.py b/b_asic/GUI/drag_button.py index 9f10c1cebc9c8590a5a671063a15739954fb4a49..b81ea9d4a6cfcabcaea071154fcaaec452a8e84a 100644 --- a/b_asic/GUI/drag_button.py +++ b/b_asic/GUI/drag_button.py @@ -3,73 +3,88 @@ Drag button class. This class creates a dragbutton which can be clicked, dragged and dropped. """ - import os.path -from PyQt5.QtWidgets import QPushButton -from PyQt5.QtCore import Qt, QSize +from properties_window import PropertiesWindow + +from PyQt5.QtWidgets import QPushButton, QMenu, QAction +from PyQt5.QtCore import Qt, QSize, pyqtSignal from PyQt5.QtGui import QIcon +from utils import decorate_class, handle_error + +@decorate_class(handle_error) class DragButton(QPushButton): - def __init__(self, name, operation, operation_path_name, window, parent = None): + connectionRequested = pyqtSignal(QPushButton) + moved = pyqtSignal() + def __init__(self, name, operation, operation_path_name, is_show_name, window, parent = None): self.name = name - self.__window = window + self.is_show_name = is_show_name + self._window = window self.operation = operation self.operation_path_name = operation_path_name self.clicked = 0 self.pressed = False - super(DragButton, self).__init__(self.__window) - - def mousePressEvent(self, event): self._mouse_press_pos = None self._mouse_move_pos = None + super(DragButton, self).__init__(self._window) + + def contextMenuEvent(self, event): + menu = QMenu() + properties = QAction("Properties") + menu.addAction(properties) + properties.triggered.connect(self.show_properties_window) + menu.exec_(self.cursor().pos()) + + def show_properties_window(self, event): + self.properties_window = PropertiesWindow(self, self._window) + self.properties_window.show() + + def add_label(self, label): + self.label = label + + def mousePressEvent(self, event): if event.button() == Qt.LeftButton: - self._mouse_press_pos = event.globalPos() - self._mouse_move_pos = event.globalPos() + self._mouse_press_pos = event.pos() + self._mouse_move_pos = event.pos() - for signal in self.__window.signalList: + for signal in self._window.signalList: signal.update() self.clicked += 1 if self.clicked == 1: self.pressed = True self.setStyleSheet("background-color: grey; border-style: solid;\ - border-color: black; border-width: 2px; border-radius: 10px") + border-color: black; border-width: 2px") path_to_image = os.path.join('operation_icons', self.operation_path_name + '_grey.png') self.setIcon(QIcon(path_to_image)) - self.setIconSize(QSize(50, 50)) - self.__window.pressed_button.append(self) + self.setIconSize(QSize(55, 55)) + self._window.pressed_button.append(self) elif self.clicked == 2: self.clicked = 0 self.pressed = False self.setStyleSheet("background-color: white; border-style: solid;\ - border-color: black; border-width: 2px; border-radius: 10px") + border-color: black; border-width: 2px") path_to_image = os.path.join('operation_icons', self.operation_path_name + '.png') self.setIcon(QIcon(path_to_image)) - self.setIconSize(QSize(50, 50)) - self.__window.pressed_button.remove(self) + self.setIconSize(QSize(55, 55)) + self._window.pressed_button.remove(self) super(DragButton, self).mousePressEvent(event) def mouseMoveEvent(self, event): if event.buttons() == Qt.LeftButton: - cur_pos = self.mapToGlobal(self.pos()) - global_pos = event.globalPos() - diff = global_pos - self._mouse_move_pos - new_pos = self.mapFromGlobal(cur_pos + diff) - self.move(new_pos) - - self._mouse_move_pos = global_pos - - self.__window.update() + self.move(self.mapToParent(event.pos() - self._mouse_press_pos)) + + self._window.update() super(DragButton, self).mouseMoveEvent(event) def mouseReleaseEvent(self, event): if self._mouse_press_pos is not None: - moved = event.globalPos() - self._mouse_press_pos + moved = event.pos() - self._mouse_press_pos if moved.manhattanLength() > 3: event.ignore() return diff --git a/b_asic/GUI/improved_main_window.py b/b_asic/GUI/improved_main_window.py deleted file mode 100644 index fb580eb9a0cf67305e3b8a322538cc30b75a7569..0000000000000000000000000000000000000000 --- a/b_asic/GUI/improved_main_window.py +++ /dev/null @@ -1,189 +0,0 @@ -"""@package docstring -B-ASIC GUI Module. -This python file is the main window of the GUI for B-ASIC. -""" - -from os import getcwd, path -import sys - -from drag_button import DragButton -from gui_interface import Ui_main_window -from arrow import Arrow -from port_button import PortButton - -from b_asic import Operation -import b_asic.core_operations as c_oper -import b_asic.special_operations as s_oper - -from numpy import linspace - -from PyQt5.QtWidgets import QApplication, QWidget, QMainWindow, QLabel, QAction,\ -QStatusBar, QMenuBar, QLineEdit, QPushButton, QSlider, QScrollArea, QVBoxLayout,\ -QHBoxLayout, QDockWidget, QToolBar, QMenu, QLayout, QSizePolicy, QListWidget,\ -QListWidgetItem, QGraphicsView, QGraphicsScene, QShortcut -from PyQt5.QtCore import Qt, QSize -from PyQt5.QtGui import QIcon, QFont, QPainter, QPen, QBrush, QKeySequence - - -class MainWindow(QMainWindow): - def __init__(self): - super(MainWindow, self).__init__() - self.ui = Ui_main_window() - self.ui.setupUi(self) - self.setWindowTitle(" ") - self.setWindowIcon(QIcon('small_logo.png')) - self.scene = None - self._operations_from_name = dict() - self.zoom = 1 - self.operationList = [] - self.signalList = [] - self.pressed_button = [] - self.portList = [] - self.pressed_ports = [] - self.source = None - - self.init_ui() - self.add_operations_from_namespace(c_oper, self.ui.core_operations_list) - self.add_operations_from_namespace(s_oper, self.ui.special_operations_list) - - self.shortcut_core = QShortcut(QKeySequence("Ctrl+R"), self.ui.operation_box) - self.shortcut_core.activated.connect(self._refresh_operations_list_from_namespace) - - def init_ui(self): - self.ui.core_operations_list.itemClicked.connect(self.on_list_widget_item_clicked) - self.ui.special_operations_list.itemClicked.connect(self.on_list_widget_item_clicked) - self.ui.exit_menu.triggered.connect(self.exit_app) - self.create_graphics_view() - - def create_graphics_view(self): - self.scene = QGraphicsScene() - self.graphic_view = QGraphicsView(self.scene, self) - self.graphic_view.setRenderHint(QPainter.Antialiasing) - self.graphic_view.setGeometry(250, 40, 600, 520) - self.graphic_view.setDragMode(QGraphicsView.ScrollHandDrag) - - def wheelEvent(self, event): - old_zoom = self.zoom - self.zoom += event.angleDelta().y()/2500 - self.graphic_view.scale(self.zoom, self.zoom) - self.zoom = old_zoom - - def exit_app(self, checked): - QApplication.quit() - - def _determine_port_distance(self, length, ports): - """Determine the distance between each port on the side of an operation. - The method returns the distance that each port should have from 0. - """ - return [length / 2] if ports == 1 else linspace(0, length, ports) - - def _create_port(self, operation, output_port=True): - text = ">" if output_port else "<" - button = PortButton(text, operation, self) - button.setStyleSheet("background-color: white") - button.connectionRequested.connect(self.connectButton) - return button - - def add_ports(self, operation): - _output_ports_dist = self._determine_port_distance(50 - 15, operation.operation.output_count) - _input_ports_dist = self._determine_port_distance(50 - 15, operation.operation.input_count) - - for dist in _input_ports_dist: - port = self._create_port(operation) - port.move(0, dist) - port.show() - - for dist in _output_ports_dist: - port = self._create_port(operation) - port.move(50 - 15, dist) - port.show() - - def get_operations_from_namespace(self, namespace): - return [comp for comp in dir(namespace) if hasattr(getattr(namespace, comp), "type_name")] - - def add_operations_from_namespace(self, namespace, _list): - for attr_name in self.get_operations_from_namespace(namespace): - attr = getattr(namespace, attr_name) - try: - attr.type_name() - item = QListWidgetItem(attr_name) - _list.addItem(item) - self._operations_from_name[attr_name] = attr - except NotImplementedError: - pass - - def _create_operation(self, item): - try: - attr_oper = self._operations_from_name[item.text()]() - attr_button = DragButton(attr_oper.graph_id, attr_oper, attr_oper.type_name().lower(), self) - attr_button.move(250, 100) - attr_button.setFixedSize(50, 50) - attr_button.setStyleSheet("background-color: white; border-style: solid;\ - border-color: black; border-width: 2px; border-radius: 10px") - self.add_ports(attr_button) - - icon_path = path.join("operation_icons", f"{attr_oper.type_name().lower()}.png") - if not path.exists(icon_path): - icon_path = path.join("operation_icons", f"custom_operation.png") - attr_button.setIcon(QIcon(icon_path)) - attr_button.setIconSize(QSize(50, 50)) - - attr_button.setParent(None) - self.scene.addWidget(attr_button) - self.operationList.append(attr_button) - except Exception as e: - print("Unexpected error occured: ", e) - - def _refresh_operations_list_from_namespace(self): - self.ui.core_operations_list.clear() - self.ui.special_operations_list.clear() - - self.add_operations_from_namespace(c_oper, self.ui.core_operations_list) - self.add_operations_from_namespace(s_oper, self.ui.special_operations_list) - - - def print_input_port_1(self): - print("Input port 1") - - def print_input_port_2(self): - print("Input port 2") - - def print_output_port_1(self): - print("Output port 1") - - def print_output_port_2(self): - print("Output port 2") - - def on_list_widget_item_clicked(self, item): - self._create_operation(item) - - def keyPressEvent(self, event): - pressed_buttons = [] - for op in self.operationList: - if op.pressed: - pressed_buttons.append(op) - if event.key() == Qt.Key_Delete: - for pressed_op in pressed_buttons: - self.operationList.remove(pressed_op) - pressed_op.remove() - super().keyPressEvent(event) - - def connectButton(self, button): - if len(self.pressed_ports) < 2: - return - for i in range(len(self.pressed_ports) - 1): - line = Arrow(self.pressed_ports[i], self.pressed_ports[i + 1], self) - self.scene.addItem(line) - self.signalList.append(line) - - self.update() - - def paintEvent(self, event): - for signal in self.signalList: - signal.moveLine() - -if __name__ == "__main__": - app = QApplication(sys.argv) - window = MainWindow() - window.show() - sys.exit(app.exec_()) diff --git a/b_asic/GUI/main_window.py b/b_asic/GUI/main_window.py index b65a118438b64ee9b6540bac5177b77cc9c2cf44..aa30dc2a4d913acfbe3398750a105dd8bb9e2c91 100644 --- a/b_asic/GUI/main_window.py +++ b/b_asic/GUI/main_window.py @@ -1,414 +1,215 @@ """@package docstring B-ASIC GUI Module. -This python file is an example of how a GUI can be implemented -using buttons and textboxes. +This python file is the main window of the GUI for B-ASIC. """ +from os import getcwd, path import sys +from drag_button import DragButton +from gui_interface import Ui_main_window +from arrow import Arrow +from port_button import PortButton + +from b_asic import Operation +import b_asic.core_operations as c_oper +import b_asic.special_operations as s_oper +from utils import decorate_class, handle_error + +from numpy import linspace + from PyQt5.QtWidgets import QApplication, QWidget, QMainWindow, QLabel, QAction,\ QStatusBar, QMenuBar, QLineEdit, QPushButton, QSlider, QScrollArea, QVBoxLayout,\ -QHBoxLayout, QDockWidget, QToolBar, QMenu -from PyQt5.QtCore import Qt, QSize, pyqtSlot -from PyQt5.QtGui import QIcon, QFont, QPainter, QPen, QColor - -from b_asic.core_operations import Addition - - -class DragButton(QPushButton): - def __init__(self, name, window, parent = None): - self.name = name - self.__menu = None - self.__window = window - self.counter = 0 - self.clicked = 0 - self.pressed = False - print("Constructor" + self.name) - super(DragButton, self).__init__(self.__window) - - self.__window.setContextMenuPolicy(Qt.CustomContextMenu) - self.__window.customContextMenuRequested.connect(self.create_menu) - - - @pyqtSlot(QAction) - def actionClicked(self, action): - print("Triggern "+ self.name, self.__menu.name) - #self.__window.check_for_remove_op(self.name) - - #def show_context_menu(self, point): - # show context menu - - - def create_menu(self, point): - self.counter += 1 - # create context menu - popMenu = MyMenu('Menu' + str(self.counter)) - popMenu.addAction(QAction('Add a signal', self)) - popMenu.addAction(QAction('Remove a signal', self)) - #action = QAction('Remove operation', self) - popMenu.addAction('Remove operation', lambda:self.removeAction(self)) - popMenu.addSeparator() - popMenu.addAction(QAction('Remove all signals', self)) - self.__window.menuList.append(popMenu) - #self.__window.actionList.append(action) - self.__menu = popMenu - self.pressed = False - self.__menu.exec_(self.__window.sender().mapToGlobal(point)) - self.__menu.triggered.connect(self.actionClicked) - - - def removeAction(self, op): - print(op.__menu.name, op.name) - op.remove() - - """This class is made to create a draggable button""" - - def mousePressEvent(self, event): - self._mouse_press_pos = None - self._mouse_move_pos = None - self.clicked += 1 - if self.clicked == 1: - self.pressed = True - self.setStyleSheet("background-color: grey; border-style: solid;\ - border-color: black; border-width: 2px; border-radius: 10px") - elif self.clicked == 2: - self.clicked = 0 - self.presseed = False - self.setStyleSheet("background-color: white; border-style: solid;\ - border-color: black; border-width: 2px; border-radius: 10px") - - if event.button() == Qt.LeftButton: - self._mouse_press_pos = event.globalPos() - self._mouse_move_pos = event.globalPos() - - super(DragButton, self).mousePressEvent(event) - - def mouseMoveEvent(self, event): - if event.buttons() == Qt.LeftButton: - cur_pos = self.mapToGlobal(self.pos()) - global_pos = event.globalPos() - diff = global_pos - self._mouse_move_pos - new_pos = self.mapFromGlobal(cur_pos + diff) - self.move(new_pos) - self.pressed = False - - self._mouse_move_pos = global_pos - - super(DragButton, self).mouseMoveEvent(event) - - def mouseReleaseEvent(self, event): - if self._mouse_press_pos is not None: - moved = event.globalPos() - self._mouse_press_pos - if moved.manhattanLength() > 3: - event.ignore() - return - - super(DragButton, self).mouseReleaseEvent(event) - - def remove(self): - self.deleteLater() - -class SubWindow(QWidget): - """Creates a sub window """ - def create_window(self, window_width, window_height): - """Creates a window - """ - parent = None - super(SubWindow, self).__init__(parent) - self.setWindowFlags(Qt.WindowStaysOnTopHint) - self.resize(window_width, window_height) +QHBoxLayout, QDockWidget, QToolBar, QMenu, QLayout, QSizePolicy, QListWidget,\ +QListWidgetItem, QGraphicsView, QGraphicsScene, QShortcut, QGraphicsTextItem,\ +QGraphicsProxyWidget +from PyQt5.QtCore import Qt, QSize +from PyQt5.QtGui import QIcon, QFont, QPainter, QPen, QBrush, QKeySequence -class MyMenu(QMenu): - def __init__(self, name, parent = None): - self.name = name - super(MyMenu, self).__init__() - +MIN_WIDTH_SCENE = 600 +MIN_HEIGHT_SCENE = 520 +@decorate_class(handle_error) class MainWindow(QMainWindow): - """Main window for the program""" - # pylint: disable=too-many-instance-attributes - # Eight is reasonable in this case. - def __init__(self, *args, **kwargs): - super(MainWindow, self).__init__(*args, **kwargs) - self.init_ui() - self.counter = 0 - self.operations = [] - self.menuList = [] - self.actionList = [] - - def init_ui(self): + def __init__(self): + super(MainWindow, self).__init__() + self.ui = Ui_main_window() + self.ui.setupUi(self) self.setWindowTitle(" ") self.setWindowIcon(QIcon('small_logo.png')) - self.create_operation_menu() - self.create_menu_bar() - self.setStatusBar(QStatusBar(self)) - - def create_operation_menu(self): - self.operation_box = QDockWidget("Operation Box", self) - self.operation_box.setAllowedAreas(Qt.LeftDockWidgetArea) - self.test = QToolBar(self) - self.operation_list = QMenuBar(self) - self.test.addWidget(self.operation_list) - self.test.setOrientation(Qt.Vertical) - self.operation_list.setStyleSheet("background-color:rgb(222,222,222); vertical") - basic_operations = self.operation_list.addMenu('Basic operations') - special_operations = self.operation_list.addMenu('Special operations') - - self.addition_menu_item = QAction('&Addition', self) - self.addition_menu_item.setStatusTip("Add addition operation to workspace") - self.addition_menu_item.triggered.connect(self.create_addition_operation) - basic_operations.addAction(self.addition_menu_item) - - self.subtraction_menu_item = QAction('&Subtraction', self) - self.subtraction_menu_item.setStatusTip("Add subtraction operation to workspace") - self.subtraction_menu_item.triggered.connect(self.create_subtraction_operation) - basic_operations.addAction(self.subtraction_menu_item) - - self.multiplication_menu_item = QAction('&Multiplication', self) - self.multiplication_menu_item.setStatusTip("Add multiplication operation to workspace") - self.multiplication_menu_item.triggered.connect(self.create_multiplication_operation) - basic_operations.addAction(self.multiplication_menu_item) - - self.division_menu_item = QAction('&Division', self) - self.division_menu_item.setStatusTip("Add division operation to workspace") - #self.division_menu_item.triggered.connect(self.create_division_operation) - basic_operations.addAction(self.division_menu_item) - - self.constant_menu_item = QAction('&Constant', self) - self.constant_menu_item.setStatusTip("Add constant operation to workspace") - #self.constant_menu_item.triggered.connect(self.create_constant_operation) - basic_operations.addAction(self.constant_menu_item) - - self.square_root_menu_item = QAction('&Square root', self) - self.square_root_menu_item.setStatusTip("Add square root operation to workspace") - #self.square_root_menu_item.triggered.connect(self.create_square_root_operation) - basic_operations.addAction(self.square_root_menu_item) - - self.complex_conjugate_menu_item = QAction('&Complex conjugate', self) - self.complex_conjugate_menu_item.setStatusTip("Add complex conjugate operation to workspace") - #self.complex_conjugate_menu_item.triggered.connect(self.create_complex_conjugate_operation) - basic_operations.addAction(self.complex_conjugate_menu_item) - - self.max_menu_item = QAction('&Max', self) - self.max_menu_item.setStatusTip("Add max operation to workspace") - #self.max_menu_item.triggered.connect(self.create_max_operation) - basic_operations.addAction(self.max_menu_item) - - self.min_menu_item = QAction('&Min', self) - self.min_menu_item.setStatusTip("Add min operation to workspace") - #self.min_menu_item.triggered.connect(self.create_min_operation) - basic_operations.addAction(self.min_menu_item) - - self.absolute_menu_item = QAction('&Absolute', self) - self.absolute_menu_item.setStatusTip("Add absolute operation to workspace") - #self.absolute_menu_item.triggered.connect(self.create_absolute_operation) - basic_operations.addAction(self.absolute_menu_item) - - self.constant_addition_menu_item = QAction('&Constant addition', self) - self.constant_addition_menu_item.setStatusTip("Add constant addition operation to workspace") - #self.constant_addition_menu_item.triggered.connect(self.create_constant_addition_operation) - basic_operations.addAction(self.constant_addition_menu_item) - - self.constant_subtraction_menu_item = QAction('&Constant subtraction', self) - self.constant_subtraction_menu_item.setStatusTip("Add constant subtraction operation to workspace") - #self.constant_subtraction_menu_item.triggered.connect(self.create_constant_subtraction_operation) - basic_operations.addAction(self.constant_subtraction_menu_item) - - self.constant_multiplication_menu_item = QAction('&Constant multiplication', self) - self.constant_multiplication_menu_item.setStatusTip("Add constant multiplication operation to workspace") - #self.constant_multiplication_menu_item.triggered.connect(self.create_constant_multiplication_operation) - basic_operations.addAction(self.constant_multiplication_menu_item) - - self.constant_division_menu_item = QAction('&Constant division', self) - self.constant_division_menu_item.setStatusTip("Add constant division operation to workspace") - #self.constant_division_menu_item.triggered.connect(self.create_constant_division_operation) - basic_operations.addAction(self.constant_division_menu_item) - - self.butterfly_menu_item = QAction('&Butterfly', self) - self.butterfly_menu_item.setStatusTip("Add butterfly operation to workspace") - #self.butterfly_menu_item.triggered.connect(self.create_butterfly_operation) - basic_operations.addAction(self.butterfly_menu_item) - - self.operation_box.setWidget(self.operation_list) - self.operation_box.setMaximumSize(240, 400) - self.operation_box.setFeatures(QDockWidget.NoDockWidgetFeatures) - self.operation_box.setFixedSize(300, 500) - self.operation_box.setStyleSheet("background-color: white; border-style: solid;\ - border-color: black; border-width: 2px") - self.addDockWidget(Qt.LeftDockWidgetArea, self.operation_box) - - def create_addition_operation(self): - self.counter += 1 - - # Create drag button - addition_operation = DragButton("OP" + str(self.counter), self) - addition_operation.move(250, 100) - addition_operation.setFixedSize(50, 50) - addition_operation.setStyleSheet("background-color: white; border-style: solid;\ - border-color: black; border-width: 2px; border-radius: 10px") - addition_operation.clicked.connect(self.create_sub_window) - #self.addition_operation.setIcon(QIcon("GUI'\'operation_icons'\'plus.png")) - addition_operation.setText("OP" + str(self.counter)) - addition_operation.setIconSize(QSize(50, 50)) - addition_operation.show() - self.operations.append(addition_operation) - - # set context menu policies - #self.addition_operation.setContextMenuPolicy(Qt.CustomContextMenu) - #self.addition_operation.customContextMenuRequested.connect(self.show_context_menu) - - #self.action.triggered.connect(lambda checked: self.remove(self.addition_operation.name)) - - def check_for_remove_op(self, name): - self.remove(name) - - - def remove(self, name): - for op in self.operations: - print(name, op.name) - if op.name == name: - self.operations.remove(op) - op.remove() - - def create_subtraction_operation(self): - self.subtraction_operation = DragButton("sub" + str(self.counter), self) - self.subtraction_operation.move(250, 100) - self.subtraction_operation.setFixedSize(50, 50) - self.subtraction_operation.setStyleSheet("background-color: white; border-style: solid;\ - border-color: black; border-width: 2px; border-radius: 10px") - self.subtraction_operation.setIcon(QIcon("GUI'\'operation_icons'\'minus.png")) - self.subtraction_operation.setIconSize(QSize(50, 50)) - self.subtraction_operation.clicked.connect(self.create_sub_window) - self.subtraction_operation.show() - - # set context menu policies - self.subtraction_operation.setContextMenuPolicy(Qt.CustomContextMenu) - self.subtraction_operation.customContextMenuRequested.connect(self.show_context_menu) - - # create context menu - self.button_context_menu = QMenu(self) - self.button_context_menu.addAction(QAction('Add a signal', self)) - self.button_context_menu.addAction(QAction('Remove a signal', self)) - self.button_context_menu.addSeparator() - self.button_context_menu.addAction(QAction('Remove all signals', self)) - - def create_multiplication_operation(self): - self.multiplication_operation = DragButton(self) - self.multiplication_operation.move(250, 100) - self.multiplication_operation.setFixedSize(50, 50) - self.multiplication_operation.setStyleSheet("background-color: white; border-style: solid;\ - border-color: black; border-width: 2px; border-radius: 10px") - self.multiplication_operation.clicked.connect(self.create_sub_window) - self.multiplication_operation.setIcon(QIcon(r"GUI\operation_icons\plus.png")) - self.multiplication_operation.setIconSize(QSize(50, 50)) - self.multiplication_operation.show() - - # set context menu policies - self.multiplication_operation.setContextMenuPolicy(Qt.CustomContextMenu) - self.multiplication_operation.customContextMenuRequested.connect(self.show_context_menu) - - # create context menu - self.button_context_menu = QMenu(self) - self.button_context_menu.addAction(QAction('Add a signal', self)) - self.button_context_menu.addAction(QAction('Remove a signal', self)) - self.button_context_menu.addSeparator() - self.button_context_menu.addAction(QAction('Remove all signals', self)) - - - def create_menu_bar(self): - # Menu buttons - load_button = QAction("Load", self) - save_button = QAction("Save", self) - - exit_button = QAction("Exit", self) - exit_button.setShortcut("Ctrl+Q") - exit_button.triggered.connect(self.exit_app) - - edit_button = QAction("Edit", self) - edit_button.setStatusTip("Open edit menu") - edit_button.triggered.connect(self.on_edit_button_click) - - view_button = QAction("View", self) - view_button.setStatusTip("Open view menu") - view_button.triggered.connect(self.on_view_button_click) - - menu_bar = QMenuBar() - menu_bar.setStyleSheet("background-color:rgb(222, 222, 222)") - self.setMenuBar(menu_bar) - - file_menu = menu_bar.addMenu("&File") - file_menu.addAction(save_button) - file_menu.addSeparator() - file_menu.addAction(exit_button) - - edit_menu = menu_bar.addMenu("&Edit") - edit_menu.addAction(edit_button) - - edit_menu.addSeparator() - - view_menu = menu_bar.addMenu("&View") - view_menu.addAction(view_button) - - - def create_sub_window(self): - """ Example of how to create a sub window - """ - self.sub_window = SubWindow() - self.sub_window.create_window(400, 300) - self.sub_window.setWindowTitle("Properties") - - self.sub_window.properties_label = QLabel(self.sub_window) - self.sub_window.properties_label.setText('Properties') - self.sub_window.properties_label.setFixedWidth(400) - self.sub_window.properties_label.setFont(QFont('SansSerif', 14, QFont.Bold)) - self.sub_window.properties_label.setAlignment(Qt.AlignCenter) + self.scene = None + self._operations_from_name = dict() + self.zoom = 1 + self.operationList = [] + self.signalList = [] + self.pressed_button = [] + self.portList = [] + self.pressed_ports = [] + self.source = None + self._window = self - self.sub_window.name_label = QLabel(self.sub_window) - self.sub_window.name_label.setText('Name:') - self.sub_window.name_label.move(20, 40) - - self.sub_window.name_line = QLineEdit(self.sub_window) - self.sub_window.name_line.setPlaceholderText("Write a name here") - self.sub_window.name_line.move(70, 40) - self.sub_window.name_line.resize(100, 20) - - self.sub_window.id_label = QLabel(self.sub_window) - self.sub_window.id_label.setText('Id:') - self.sub_window.id_label.move(20, 70) - - self.sub_window.id_line = QLineEdit(self.sub_window) - self.sub_window.id_line.setPlaceholderText("Write an id here") - self.sub_window.id_line.move(70, 70) - self.sub_window.id_line.resize(100, 20) + self.init_ui() + self.add_operations_from_namespace(c_oper, self.ui.core_operations_list) + self.add_operations_from_namespace(s_oper, self.ui.special_operations_list) - self.sub_window.show() + self.shortcut_core = QShortcut(QKeySequence("Ctrl+R"), self.ui.operation_box) + self.shortcut_core.activated.connect(self._refresh_operations_list_from_namespace) - def keyPressEvent(self, event): - for op in self.operations: - if event.key() == Qt.Key_Delete and op.pressed: - self.operations.remove(op) - op.remove() - - def on_file_button_click(self): - print("File") + self.move_button_index = 0 + self.is_show_names = True - def on_edit_button_click(self): - print("Edit") + self.check_show_names = QAction("Show operation names") + self.check_show_names.triggered.connect(self.view_operation_names) + self.check_show_names.setCheckable(True) + self.check_show_names.setChecked(1) + self.ui.view_menu.addAction(self.check_show_names) - def on_view_button_click(self): - print("View") + def init_ui(self): + self.ui.core_operations_list.itemClicked.connect(self.on_list_widget_item_clicked) + self.ui.special_operations_list.itemClicked.connect(self.on_list_widget_item_clicked) + self.ui.exit_menu.triggered.connect(self.exit_app) + self.create_graphics_view() + + def create_graphics_view(self): + self.scene = QGraphicsScene(self) + self.graphic_view = QGraphicsView(self.scene, self) + self.graphic_view.setRenderHint(QPainter.Antialiasing) + self.graphic_view.setGeometry(self.ui.operation_box.width(), 0, self.width(), self.height()) + self.graphic_view.setDragMode(QGraphicsView.ScrollHandDrag) + + def resizeEvent(self, event): + self.ui.operation_box.setGeometry(10, 10, self.ui.operation_box.width(), self.height()) + self.graphic_view.setGeometry(self.ui.operation_box.width() + 20, 0, self.width() - self.ui.operation_box.width() - 20, self.height()) + super(MainWindow, self).resizeEvent(event) + + def wheelEvent(self, event): + if event.modifiers() == Qt.ControlModifier: + old_zoom = self.zoom + self.zoom += event.angleDelta().y()/2500 + self.graphic_view.scale(self.zoom, self.zoom) + self.zoom = old_zoom + + def view_operation_names(self, event): + if self.check_show_names.isChecked(): + self.is_show_names = True + else: + self.is_show_names = False + for operation in self.operationList: + operation.label.setOpacity(self.is_show_names) + operation.is_show_name = self.is_show_names def exit_app(self, checked): QApplication.quit() - def clicked(self): - print("Drag button clicked") - + def _determine_port_distance(self, length, ports): + """Determine the distance between each port on the side of an operation. + The method returns the distance that each port should have from 0. + """ + return [length / 2] if ports == 1 else linspace(0, length, ports) + + def _create_port(self, operation, output_port=True): + text = ">" if output_port else "<" + button = PortButton(text, operation, self) + button.setStyleSheet("background-color: white") + button.connectionRequested.connect(self.connectButton) + return button + + def add_ports(self, operation): + _output_ports_dist = self._determine_port_distance(55 - 17, operation.operation.output_count) + _input_ports_dist = self._determine_port_distance(55 - 17, operation.operation.input_count) + + for dist in _input_ports_dist: + port = self._create_port(operation) + port.move(0, dist) + port.show() + + for dist in _output_ports_dist: + port = self._create_port(operation) + port.move(55 - 12, dist) + port.show() + + def get_operations_from_namespace(self, namespace): + return [comp for comp in dir(namespace) if hasattr(getattr(namespace, comp), "type_name")] + + def add_operations_from_namespace(self, namespace, _list): + for attr_name in self.get_operations_from_namespace(namespace): + attr = getattr(namespace, attr_name) + try: + attr.type_name() + item = QListWidgetItem(attr_name) + _list.addItem(item) + self._operations_from_name[attr_name] = attr + except NotImplementedError: + pass + + def _create_operation(self, item): + try: + attr_oper = self._operations_from_name[item.text()]() + attr_button = DragButton(attr_oper.graph_id, attr_oper, attr_oper.type_name().lower(), True, self) + attr_button.move(250, 100) + attr_button.setFixedSize(55, 55) + attr_button.setStyleSheet("background-color: white; border-style: solid;\ + border-color: black; border-width: 2px") + self.add_ports(attr_button) + + icon_path = path.join("operation_icons", f"{attr_oper.type_name().lower()}.png") + if not path.exists(icon_path): + icon_path = path.join("operation_icons", f"custom_operation.png") + attr_button.setIcon(QIcon(icon_path)) + attr_button.setIconSize(QSize(55, 55)) + attr_button.setParent(None) + attr_button_scene = self.scene.addWidget(attr_button) + attr_button_scene.moveBy(self.move_button_index * 100, 0) + self.move_button_index += 1 + operation_label = QGraphicsTextItem(attr_oper.type_name(), attr_button_scene) + if not self.is_show_names: + operation_label.setOpacity(0) + operation_label.setTransformOriginPoint(operation_label.boundingRect().center()) + operation_label.moveBy(10, -20) + attr_button.add_label(operation_label) + self.operationList.append(attr_button) + except Exception as e: + print("Unexpected error occured: ", e) + + def _refresh_operations_list_from_namespace(self): + self.ui.core_operations_list.clear() + self.ui.special_operations_list.clear() + + self.add_operations_from_namespace(c_oper, self.ui.core_operations_list) + self.add_operations_from_namespace(s_oper, self.ui.special_operations_list) + + def on_list_widget_item_clicked(self, item): + self._create_operation(item) + + def keyPressEvent(self, event): + pressed_buttons = [] + for op in self.operationList: + if op.pressed: + pressed_buttons.append(op) + if event.key() == Qt.Key_Delete: + for pressed_op in pressed_buttons: + self.operationList.remove(pressed_op) + pressed_op.remove() + self.move_button_index -= 1 + super().keyPressEvent(event) + + def connectButton(self, button): + if len(self.pressed_ports) < 2: + return + for i in range(len(self.pressed_ports) - 1): + line = Arrow(self.pressed_ports[i], self.pressed_ports[i + 1], self) + self.scene.addItem(line) + self.signalList.append(line) + + self.update() + + def paintEvent(self, event): + for signal in self.signalList: + signal.moveLine() if __name__ == "__main__": app = QApplication(sys.argv) window = MainWindow() - window.resize(960, 720) window.show() - app.exec_() + sys.exit(app.exec_()) diff --git a/b_asic/GUI/port_button.py b/b_asic/GUI/port_button.py index d9fd2b135d1b11c9677be8ee814d95fe01d8007a..af2e7ef6de2df343334305f0fce63cc313c574d2 100644 --- a/b_asic/GUI/port_button.py +++ b/b_asic/GUI/port_button.py @@ -20,7 +20,6 @@ class PortButton(QPushButton): menu.exec_(self.cursor().pos()) def mousePressEvent(self, event): - if event.button() == Qt.LeftButton: self.clicked += 1 if self.clicked == 1: @@ -29,11 +28,11 @@ class PortButton(QPushButton): self.window.pressed_ports.append(self) elif self.clicked == 2: self.setStyleSheet("background-color: white") - self.pressed = False + self.pressed = False self.clicked = 0 self.window.pressed_ports.remove(self) super(PortButton, self).mousePressEvent(event) - + def mouseReleaseEvent(self, event): super(PortButton, self).mouseReleaseEvent(event) diff --git a/b_asic/GUI/properties_window.py b/b_asic/GUI/properties_window.py new file mode 100644 index 0000000000000000000000000000000000000000..5aaee05c820f5d58a9b41d5fee506b176b6826f6 --- /dev/null +++ b/b_asic/GUI/properties_window.py @@ -0,0 +1,62 @@ +from PyQt5.QtWidgets import QDialog, QLineEdit, QPushButton, QVBoxLayout, QHBoxLayout,\ +QLabel, QCheckBox +from PyQt5.QtCore import Qt +from PyQt5.QtGui import QIntValidator + +class PropertiesWindow(QDialog): + def __init__(self, operation, main_window): + super(PropertiesWindow, self).__init__() + self.operation = operation + self.main_window = main_window + self.setWindowFlags(Qt.WindowTitleHint | Qt.WindowCloseButtonHint) + self.setWindowTitle("Properties") + + self.name_layout = QHBoxLayout() + self.name_layout.setSpacing(50) + self.name_label = QLabel("Name:") + self.edit_name = QLineEdit(self.operation.operation_path_name) + self.name_layout.addWidget(self.name_label) + self.name_layout.addWidget(self.edit_name) + + self.vertical_layout = QVBoxLayout() + self.vertical_layout.addLayout(self.name_layout) + + if self.operation.operation_path_name == "c": + self.constant_layout = QHBoxLayout() + self.constant_layout.setSpacing(50) + self.constant_value = QLabel("Constant:") + self.edit_constant = QLineEdit(str(self.operation.operation.value)) + self.only_accept_int = QIntValidator() + self.edit_constant.setValidator(self.only_accept_int) + self.constant_layout.addWidget(self.constant_value) + self.constant_layout.addWidget(self.edit_constant) + self.vertical_layout.addLayout(self.constant_layout) + + self.show_name_layout = QHBoxLayout() + self.check_show_name = QCheckBox("Show name?") + if self.operation.is_show_name: + self.check_show_name.setChecked(1) + else: + self.check_show_name.setChecked(0) + self.check_show_name.setLayoutDirection(Qt.RightToLeft) + self.check_show_name.setStyleSheet("spacing: 170px") + self.show_name_layout.addWidget(self.check_show_name) + self.vertical_layout.addLayout(self.show_name_layout) + + self.ok = QPushButton("OK") + self.ok.clicked.connect(self.save_properties) + self.vertical_layout.addWidget(self.ok) + self.setLayout(self.vertical_layout) + + def save_properties(self): + self.operation.name = self.edit_name.text() + self.operation.label.setPlainText(self.operation.name) + if self.operation.operation_path_name == "c": + self.operation.operation.value = self.edit_constant.text() + if self.check_show_name.isChecked(): + self.operation.label.setOpacity(1) + self.operation.is_show_name = True + else: + self.operation.label.setOpacity(0) + self.operation.is_show_name = False + self.reject() \ No newline at end of file diff --git a/b_asic/GUI/utils.py b/b_asic/GUI/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..721496c7db7d0259d6335681f6a4c95c6c5930b6 --- /dev/null +++ b/b_asic/GUI/utils.py @@ -0,0 +1,19 @@ +from PyQt5.QtWidgets import QErrorMessage +from traceback import format_exc + +def handle_error(fn): + def wrapper(self, *args, **kwargs): + try: + return fn(self, *args, **kwargs) + except Exception as e: + QErrorMessage(self._window).showMessage(f"Unexpected error: {format_exc()}") + + return wrapper + +def decorate_class(decorator): + def decorate(cls): + for attr in cls.__dict__: + if callable(getattr(cls, attr)): + setattr(cls, attr, decorator(getattr(cls, attr))) + return cls + return decorate \ No newline at end of file diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index 3e6cd78755f727d034723e2f02e707225cbe9611..ec7306c6f4c97b5c0377794e48524d09c7ed159b 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -240,3 +240,18 @@ class Butterfly(AbstractOperation): def evaluate(self, a, b): return a + b, a - b + +class MAD(AbstractOperation): + """Multiply-and-add operation. + TODO: More info. + """ + + def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, src2: Optional[SignalSourceProvider] = None, name: Name = ""): + super().__init__(input_count = 3, output_count = 1, name = name, input_sources = [src0, src1, src2]) + + @property + def type_name(self) -> TypeName: + return "mad" + + def evaluate(self, a, b, c): + return a * b + c diff --git a/b_asic/graph_component.py b/b_asic/graph_component.py index 5efa8038e6c914d0e8f2023ebf6f05eb58664e5b..e08422a842a84d08dcab58ab03d7f581cb1bc664 100644 --- a/b_asic/graph_component.py +++ b/b_asic/graph_component.py @@ -105,6 +105,10 @@ class AbstractGraphComponent(GraphComponent): self._graph_id = "" self._parameters = {} + def __str__(self): + return f"id: {self.graph_id if self.graph_id else 'no_id'}, \tname: {self.name if self.name else 'no_name'}" + \ + "".join((f", \t{key}: {str(param)}" for key, param in self._parameters.items())) + @property def name(self) -> Name: return self._name diff --git a/b_asic/operation.py b/b_asic/operation.py index f8ac22e2a1d26e13365d0d742775de6f1f020057..21e7012eaf7a333c5db8a7f8a6c741b3220030b8 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -186,6 +186,12 @@ class Operation(GraphComponent, SignalSourceProvider): """Get the input indices of all inputs in this operation whose values are required in order to evalueate the output at the given output index.""" raise NotImplementedError + @abstractmethod + def to_sfg(self) -> "SFG": + """Convert the operation into its corresponding SFG. + If the operation is composed by multiple operations, the operation will be split. + """ + raise NotImplementedError class AbstractOperation(Operation, AbstractGraphComponent): """Generic abstract operation class which most implementations will derive from. @@ -256,6 +262,47 @@ class AbstractOperation(Operation, AbstractGraphComponent): from b_asic.core_operations import Constant, Division return Division(Constant(src) if isinstance(src, Number) else src, self) + def __str__(self): + inputs_dict = dict() + for i, port in enumerate(self.inputs): + if port.signal_count == 0: + inputs_dict[i] = '-' + break + dict_ele = [] + for signal in port.signals: + if signal.source: + if signal.source.operation.graph_id: + dict_ele.append(signal.source.operation.graph_id) + else: + dict_ele.append("no_id") + else: + if signal.graph_id: + dict_ele.append(signal.graph_id) + else: + dict_ele.append("no_id") + inputs_dict[i] = dict_ele + + outputs_dict = dict() + for i, port in enumerate(self.outputs): + if port.signal_count == 0: + outputs_dict[i] = '-' + break + dict_ele = [] + for signal in port.signals: + if signal.destination: + if signal.destination.operation.graph_id: + dict_ele.append(signal.destination.operation.graph_id) + else: + dict_ele.append("no_id") + else: + if signal.graph_id: + dict_ele.append(signal.graph_id) + else: + dict_ele.append("no_id") + outputs_dict[i] = dict_ele + + return super().__str__() + f", \tinputs: {str(inputs_dict)}, \toutputs: {str(outputs_dict)}" + @property def input_count(self) -> int: return len(self._input_ports) @@ -361,6 +408,30 @@ class AbstractOperation(Operation, AbstractGraphComponent): pass return [self] + def to_sfg(self) -> "SFG": + # Import here to avoid circular imports. + from b_asic.special_operations import Input, Output + from b_asic.signal_flow_graph import SFG + + inputs = [Input() for i in range(self.input_count)] + + try: + last_operations = self.evaluate(*inputs) + if isinstance(last_operations, Operation): + last_operations = [last_operations] + outputs = [Output(o) for o in last_operations] + except TypeError: + operation_copy = self.copy_component() + inputs = [] + for i in range(self.input_count): + _input = Input() + operation_copy.input(i).connect(_input) + inputs.append(_input) + + outputs = [Output(operation_copy)] + + return SFG(inputs=inputs, outputs=outputs) + def inputs_required_for_output(self, output_index: int) -> Iterable[int]: if output_index < 0 or output_index >= self.output_count: raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {output_index})") @@ -370,6 +441,16 @@ class AbstractOperation(Operation, AbstractGraphComponent): def neighbors(self) -> Iterable[GraphComponent]: return list(self.input_signals) + list(self.output_signals) + @property + def preceding_operations(self) -> Iterable[Operation]: + """Returns an Iterable of all Operations that are connected to this Operations input ports.""" + return [signal.source.operation for signal in self.input_signals if signal.source] + + @property + def subsequent_operations(self) -> Iterable[Operation]: + """Returns an Iterable of all Operations that are connected to this Operations output ports.""" + return [signal.destination.operation for signal in self.output_signals if signal.destination] + @property def source(self) -> OutputPort: if self.output_count != 1: diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 6306a78c6590f3f73e8b6c0b533f869a37a51275..22667aa8b6c754e1119d80d9753e0afc99400c5d 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -6,14 +6,16 @@ TODO: More info. from typing import List, Iterable, Sequence, Dict, Optional, DefaultDict, MutableSet from numbers import Number from collections import defaultdict, deque +from io import StringIO +from queue import PriorityQueue +import itertools from graphviz import Digraph -from b_asic.port import SignalSourceProvider, OutputPort, InputPort +from b_asic.port import SignalSourceProvider, OutputPort from b_asic.operation import Operation, AbstractOperation, MutableOutputMap, MutableRegisterMap from b_asic.signal import Signal from b_asic.graph_component import GraphID, GraphIDNumber, GraphComponent, Name, TypeName from b_asic.special_operations import Input, Output, Register -from b_asic.core_operations import Constant class GraphIDGenerator: @@ -42,8 +44,9 @@ class SFG(AbstractOperation): _components_by_id: Dict[GraphID, GraphComponent] _components_by_name: DefaultDict[Name, List[GraphComponent]] - _components_ordered: List[GraphComponent] - _operations_ordered: List[Operation] + _components_dfs_order: List[GraphComponent] + _operations_dfs_order: List[Operation] + _operations_topological_order: List[Operation] _graph_id_generator: GraphIDGenerator _input_operations: List[Input] _output_operations: List[Output] @@ -68,8 +71,9 @@ class SFG(AbstractOperation): self._components_by_id = dict() self._components_by_name = defaultdict(list) - self._components_ordered = [] - self._operations_ordered = [] + self._components_dfs_order = [] + self._operations_dfs_order = [] + self._operations_topological_order = [] self._graph_id_generator = GraphIDGenerator(id_number_offset) self._input_operations = [] self._output_operations = [] @@ -152,9 +156,9 @@ class SFG(AbstractOperation): signal.destination.operation) elif new_signal.destination.operation in output_operations_set: # Add directly connected input to output to ordered list. - self._components_ordered.extend( + self._components_dfs_order.extend( [new_signal.source.operation, new_signal, new_signal.destination.operation]) - self._operations_ordered.extend( + self._operations_dfs_order.extend( [new_signal.source.operation, new_signal.destination.operation]) # Search the graph inwards from each output signal. @@ -171,47 +175,18 @@ class SFG(AbstractOperation): def __str__(self) -> str: """Get a string representation of this SFG.""" - output_string = "" - for component in self._components_ordered: - if isinstance(component, Operation): - for key, value in self._components_by_id.items(): - if value is component: - output_string += "id: " + key + ", name: " - - if component.name != None: - output_string += component.name + ", " - else: - output_string += "-, " + string_io = StringIO() + string_io.write(super().__str__() + "\n") + string_io.write("Internal Operations:\n") + line = "-" * 100 + "\n" + string_io.write(line) - if isinstance(component, Constant): - output_string += "value: " + \ - str(component.value) + ", input: [" - else: - output_string += "input: [" - - counter_input = 0 - for input in component.inputs: - counter_input += 1 - for signal in input.signals: - for key, value in self._components_by_id.items(): - if value is signal: - output_string += key + ", " - - if counter_input > 0: - output_string = output_string[:-2] - output_string += "], output: [" - counter_output = 0 - for output in component.outputs: - counter_output += 1 - for signal in output.signals: - for key, value in self._components_by_id.items(): - if value is signal: - output_string += key + ", " - if counter_output > 0: - output_string = output_string[:-2] - output_string += "]\n" - - return output_string + for operation in self.get_operations_topological_order(): + string_io.write(str(operation) + "\n") + + string_io.write(line) + + return string_io.getvalue() def __call__(self, *src: Optional[SignalSourceProvider], name: Name = "") -> "SFG": """Get a new independent SFG instance that is identical to this SFG except without any of its external connections.""" @@ -249,7 +224,7 @@ class SFG(AbstractOperation): return value def connect_external_signals_to_components(self) -> bool: - """ Connects any external signals to this SFG's internal operations. This SFG becomes unconnected to the SFG + """ Connects any external signals to this SFG's internal operations. This SFG becomes unconnected to the SFG it is a component off, causing it to become invalid afterwards. Returns True if succesful, False otherwise. """ if len(self.inputs) != len(self.input_operations): raise IndexError(f"Number of inputs does not match the number of input_operations in SFG.") @@ -265,7 +240,7 @@ class SFG(AbstractOperation): dest = input_operation.output(0).signals[0].destination dest.clear() port.signals[0].set_destination(dest) - # For each output_signal, connect it to the corresponding operation + # For each output_signal, connect it to the corresponding operation for port, output_operation in zip(self.outputs, self.output_operations): src = output_operation.input(0).signals[0].source src.clear() @@ -284,6 +259,9 @@ class SFG(AbstractOperation): def split(self) -> Iterable[Operation]: return self.operations + + def to_sfg(self) -> 'SFG': + return self def inputs_required_for_output(self, output_index: int) -> Iterable[int]: if output_index < 0 or output_index >= self.output_count: @@ -326,12 +304,12 @@ class SFG(AbstractOperation): @property def components(self) -> Iterable[GraphComponent]: """Get all components of this graph in depth-first order.""" - return self._components_ordered + return self._components_dfs_order @property def operations(self) -> Iterable[Operation]: """Get all operations of this graph in depth-first order.""" - return self._operations_ordered + return self._operations_dfs_order def get_components_with_type_name(self, type_name: TypeName) -> List[GraphComponent]: """Get a list with all components in this graph with the specified type_name. @@ -385,8 +363,8 @@ class SFG(AbstractOperation): new_op = None if original_op not in self._original_components_to_new: new_op = self._add_component_unconnected_copy(original_op) - self._components_ordered.append(new_op) - self._operations_ordered.append(new_op) + self._components_dfs_order.append(new_op) + self._operations_dfs_order.append(new_op) else: new_op = self._original_components_to_new[original_op] @@ -400,24 +378,20 @@ class SFG(AbstractOperation): if original_signal in self._original_input_signals_to_indices: # New signal already created during first step of constructor. new_signal = self._original_components_to_new[original_signal] - new_signal.set_destination( - new_op.input(original_input_port.index)) - self._components_ordered.extend( - [new_signal, new_signal.source.operation]) - self._operations_ordered.append( - new_signal.source.operation) + new_signal.set_destination(new_op.input(original_input_port.index)) + + self._components_dfs_order.extend([new_signal, new_signal.source.operation]) + self._operations_dfs_order.append(new_signal.source.operation) # Check if the signal has not been added before. elif original_signal not in self._original_components_to_new: if original_signal.source is None: - raise ValueError( - "Dangling signal without source in SFG") + raise ValueError("Dangling signal without source in SFG") - new_signal = self._add_component_unconnected_copy( - original_signal) - new_signal.set_destination( - new_op.input(original_input_port.index)) - self._components_ordered.append(new_signal) + new_signal = self._add_component_unconnected_copy(original_signal) + new_signal.set_destination(new_op.input(original_input_port.index)) + + self._components_dfs_order.append(new_signal) original_connected_op = original_signal.source.operation # Check if connected Operation has been added before. @@ -427,12 +401,11 @@ class SFG(AbstractOperation): original_signal.source.index)) else: # Create new operation, set signal source to it. - new_connected_op = self._add_component_unconnected_copy( - original_connected_op) - new_signal.set_source(new_connected_op.output( - original_signal.source.index)) - self._components_ordered.append(new_connected_op) - self._operations_ordered.append(new_connected_op) + new_connected_op = self._add_component_unconnected_copy(original_connected_op) + new_signal.set_source(new_connected_op.output(original_signal.source.index)) + + self._components_dfs_order.append(new_connected_op) + self._operations_dfs_order.append(new_connected_op) # Add connected operation to queue of operations to visit. op_stack.append(original_connected_op) @@ -444,24 +417,20 @@ class SFG(AbstractOperation): if original_signal in self._original_output_signals_to_indices: # New signal already created during first step of constructor. new_signal = self._original_components_to_new[original_signal] - new_signal.set_source( - new_op.output(original_output_port.index)) - self._components_ordered.extend( - [new_signal, new_signal.destination.operation]) - self._operations_ordered.append( - new_signal.destination.operation) + new_signal.set_source(new_op.output(original_output_port.index)) + + self._components_dfs_order.extend([new_signal, new_signal.destination.operation]) + self._operations_dfs_order.append(new_signal.destination.operation) # Check if signal has not been added before. elif original_signal not in self._original_components_to_new: if original_signal.source is None: - raise ValueError( - "Dangling signal without source in SFG") + raise ValueError("Dangling signal without source in SFG") - new_signal = self._add_component_unconnected_copy( - original_signal) - new_signal.set_source( - new_op.output(original_output_port.index)) - self._components_ordered.append(new_signal) + new_signal = self._add_component_unconnected_copy(original_signal) + new_signal.set_source(new_op.output(original_output_port.index)) + + self._components_dfs_order.append(new_signal) original_connected_op = original_signal.destination.operation # Check if connected operation has been added. @@ -471,12 +440,11 @@ class SFG(AbstractOperation): original_signal.destination.index)) else: # Create new operation, set destination to it. - new_connected_op = self._add_component_unconnected_copy( - original_connected_op) - new_signal.set_destination(new_connected_op.input( - original_signal.destination.index)) - self._components_ordered.append(new_connected_op) - self._operations_ordered.append(new_connected_op) + new_connected_op = self._add_component_unconnected_copy(original_connected_op) + new_signal.set_destination(new_connected_op.input(original_signal.destination.index)) + + self._components_dfs_order.append(new_connected_op) + self._operations_dfs_order.append(new_connected_op) # Add connected operation to the queue of operations to visit. op_stack.append(original_connected_op) @@ -554,16 +522,13 @@ class SFG(AbstractOperation): if key in results: value = results[key] if value is None: - raise RuntimeError( - f"Direct feedback loop detected when evaluating operation.") + raise RuntimeError(f"Direct feedback loop detected when evaluating operation.") return value - results[key] = src.operation.current_output( - src.index, registers, src_prefix) + results[key] = src.operation.current_output(src.index, registers, src_prefix) input_values = [self._evaluate_source( input_port.signals[0].source, results, registers, prefix) for input_port in src.operation.inputs] - value = src.operation.evaluate_output( - src.index, input_values, results, registers, src_prefix) + value = src.operation.evaluate_output(src.index, input_values, results, registers, src_prefix) results[key] = value return value @@ -571,7 +536,7 @@ class SFG(AbstractOperation): """Returns a Precedence list of the SFG where each element in n:th the list consists of elements that are executed in the n:th step. If the precedence list already has been calculated for the current SFG then returns the cached version.""" - if self._precedence_list is not None: + if self._precedence_list: return self._precedence_list # Find all operations with only outputs and no inputs. @@ -585,17 +550,9 @@ class SFG(AbstractOperation): return self._precedence_list - def _traverse_for_precedence_list(self, first_iter_ports): + def _traverse_for_precedence_list(self, first_iter_ports: List[OutputPort]) -> List[List[OutputPort]]: # Find dependencies of output ports and input ports. - outports_per_inport = defaultdict(list) - remaining_inports_per_outport = dict() - for op in self.operations: - op_inputs = op.inputs - for out_i, outport in enumerate(op.outputs): - dependendent_indexes = op.inputs_required_for_output(out_i) - remaining_inports_per_outport[outport] = len(dependendent_indexes) - for in_i in dependendent_indexes: - outports_per_inport[op_inputs[in_i]].append(outport) + remaining_inports_per_operation = {op: op.input_count for op in self.operations} # Traverse output ports for precedence curr_iter_ports = first_iter_ports @@ -612,10 +569,10 @@ class SFG(AbstractOperation): new_inport = signal.destination # Don't traverse over Registers if new_inport is not None and not isinstance(new_inport.operation, Register): - for new_outport in outports_per_inport[new_inport]: - remaining_inports_per_outport[new_outport] -= 1 - if remaining_inports_per_outport[new_outport] == 0: - next_iter_ports.append(new_outport) + new_op = new_inport.operation + remaining_inports_per_operation[new_op] -= 1 + if remaining_inports_per_operation[new_op] == 0: + next_iter_ports.extend(new_op.outputs) curr_iter_ports = next_iter_ports @@ -641,3 +598,105 @@ class SFG(AbstractOperation): pg.edge(port.operation.graph_id + '.' + str(port.index), signal.destination.operation.graph_id) pg.edge(port.operation.graph_id, port.operation.graph_id + '.' + str(port.index)) pg.view() + + def print_precedence_graph(self) -> None: + """Prints a representation of the SFG's precedence list to the standard out. + If the precedence list already has been calculated then it uses the cached version, + otherwise it calculates the precedence list and then prints it.""" + precedence_list = self.get_precedence_list() + + line = "-" * 120 + out_str = StringIO() + out_str.write(line) + + printed_ops = set() + + for iter_num, iter in enumerate(precedence_list, start=1): + for outport_num, outport in enumerate(iter, start=1): + if outport not in printed_ops: + # Only print once per operation, even if it has multiple outports + out_str.write("\n") + out_str.write(str(iter_num)) + out_str.write(".") + out_str.write(str(outport_num)) + out_str.write(" \t") + out_str.write(str(outport.operation)) + printed_ops.add(outport) + + out_str.write("\n") + out_str.write(line) + + print(out_str.getvalue()) + + def get_operations_topological_order(self) -> Iterable[Operation]: + """Returns an Iterable of the Operations in the SFG in Topological Order. + Feedback loops makes an absolutely correct Topological order impossible, so an + approximative Topological Order is returned in such cases in this implementation.""" + if self._operations_topological_order: + return self._operations_topological_order + + no_inputs_queue = deque(list(filter(lambda op: op.input_count == 0, self.operations))) + remaining_inports_per_operation = {op: op.input_count for op in self.operations} + + # Maps number of input counts to a queue of seen objects with such a size. + seen_with_inputs_dict = defaultdict(deque) + seen = set() + top_order = [] + + assert len(no_inputs_queue) > 0, "Illegal SFG state, dangling signals in SFG." + + first_op = no_inputs_queue.popleft() + visited = set([first_op]) + p_queue = PriorityQueue() + p_queue.put((-first_op.output_count, first_op)) # Negative priority as max-heap popping is wanted + operations_left = len(self.operations) - 1 + + seen_but_not_visited_count = 0 + + while operations_left > 0: + while not p_queue.empty(): + op = p_queue.get()[1] + + operations_left -= 1 + top_order.append(op) + visited.add(op) + + for neighbor_op in op.subsequent_operations: + if neighbor_op not in visited: + remaining_inports_per_operation[neighbor_op] -= 1 + remaining_inports = remaining_inports_per_operation[neighbor_op] + + if remaining_inports == 0: + p_queue.put((-neighbor_op.output_count, neighbor_op)) + + elif remaining_inports > 0: + if neighbor_op in seen: + seen_with_inputs_dict[remaining_inports + 1].remove(neighbor_op) + else: + seen.add(neighbor_op) + seen_but_not_visited_count += 1 + + seen_with_inputs_dict[remaining_inports].append(neighbor_op) + + # Check if have to fetch Operations from somewhere else since p_queue is empty + if operations_left > 0: + # First check if can fetch from Operations with no input ports + if no_inputs_queue: + new_op = no_inputs_queue.popleft() + p_queue.put((new_op.output_count, new_op)) + + # Else fetch operation with lowest input count that is not zero + elif seen_but_not_visited_count > 0: + for i in itertools.count(start=1): + seen_inputs_queue = seen_with_inputs_dict[i] + if seen_inputs_queue: + new_op = seen_inputs_queue.popleft() + p_queue.put((-new_op.output_count, new_op)) + seen_but_not_visited_count -= 1 + break + else: + raise RuntimeError("Unallowed structure in SFG detected") + + self._operations_topological_order = top_order + + return self._operations_topological_order diff --git a/test/fixtures/signal_flow_graph.py b/test/fixtures/signal_flow_graph.py index 5a0ef25b94cec8e3fad9275cccf97882703de330..e2145b0a2a5974222c8c3d740cb3f53d76c7e445 100644 --- a/test/fixtures/signal_flow_graph.py +++ b/test/fixtures/signal_flow_graph.py @@ -1,12 +1,12 @@ import pytest -from b_asic import SFG, Input, Output, Constant, Register, ConstantMultiplication +from b_asic import SFG, Input, Output, Constant, Register, ConstantMultiplication, Addition, Butterfly @pytest.fixture def sfg_two_inputs_two_outputs(): """Valid SFG with two inputs and two outputs. - . . + . . in1-------+ +--------->out1 . | | . . v | . @@ -17,9 +17,9 @@ def sfg_two_inputs_two_outputs(): | . ^ . | . | . +------------+ . - . . + . . out1 = in1 + in2 - out2 = in1 + 2 * in2 + out2 = in1 + 2 * in2 """ in1 = Input() in2 = Input() @@ -27,13 +27,14 @@ def sfg_two_inputs_two_outputs(): add2 = add1 + in2 out1 = Output(add1) out2 = Output(add2) - return SFG(inputs = [in1, in2], outputs = [out1, out2]) + return SFG(inputs=[in1, in2], outputs=[out1, out2]) + @pytest.fixture def sfg_two_inputs_two_outputs_independent(): """Valid SFG with two inputs and two outputs, where the first output only depends on the first input and the second output only depends on the second input. - . . + . . in1-------------------->out1 . . . . @@ -44,17 +45,18 @@ def sfg_two_inputs_two_outputs_independent(): . | ^ . . | | . . +------+ . - . . + . . out1 = in1 - out2 = in2 + 3 + out2 = in2 + 3 """ - in1 = Input() - in2 = Input() - c1 = Constant(3) - add1 = in2 + c1 - out1 = Output(in1) - out2 = Output(add1) - return SFG(inputs = [in1, in2], outputs = [out1, out2]) + in1 = Input("IN1") + in2 = Input("IN2") + c1 = Constant(3, "C1") + add1 = Addition(in2, c1, "ADD1") + out1 = Output(in1, "OUT1") + out2 = Output(add1, "OUT2") + return SFG(inputs=[in1, in2], outputs=[out1, out2]) + @pytest.fixture def sfg_nested(): @@ -65,7 +67,7 @@ def sfg_nested(): mac_in2 = Input() mac_in3 = Input() mac_out1 = Output(mac_in1 + mac_in2 * mac_in3) - MAC = SFG(inputs = [mac_in1, mac_in2, mac_in3], outputs = [mac_out1]) + MAC = SFG(inputs=[mac_in1, mac_in2, mac_in3], outputs=[mac_out1]) in1 = Input() in2 = Input() @@ -73,7 +75,8 @@ def sfg_nested(): mac2 = MAC(in1, in2, mac1) mac3 = MAC(in1, mac1, mac2) out1 = Output(mac3) - return SFG(inputs = [in1, in2], outputs = [out1]) + return SFG(inputs=[in1, in2], outputs=[out1]) + @pytest.fixture def sfg_delay(): @@ -83,7 +86,8 @@ def sfg_delay(): in1 = Input() reg1 = Register(in1) out1 = Output(reg1) - return SFG(inputs = [in1], outputs = [out1]) + return SFG(inputs=[in1], outputs=[out1]) + @pytest.fixture def sfg_accumulator(): @@ -95,7 +99,8 @@ def sfg_accumulator(): reg = Register() reg.input(0).connect((reg + data_in) * (1 - reset)) data_out = Output(reg) - return SFG(inputs = [data_in, reset], outputs = [data_out]) + return SFG(inputs=[data_in, reset], outputs=[data_out]) + @pytest.fixture def simple_filter(): @@ -105,11 +110,70 @@ def simple_filter(): | | in1>------add1>------reg>------+------out1> """ - in1 = Input() - reg = Register() - constmul1 = ConstantMultiplication(0.5) - add1 = in1 + constmul1 - reg.input(0).connect(add1) + in1 = Input("IN1") + constmul1 = ConstantMultiplication(0.5, name="CMUL1") + add1 = Addition(in1, constmul1, "ADD1") + reg = Register(add1, name="REG1") constmul1.input(0).connect(reg) - out1 = Output(reg) - return SFG(inputs=[in1], outputs=[out1]) + out1 = Output(reg, "OUT1") + return SFG(inputs=[in1], outputs=[out1], name="simple_filter") + + +@pytest.fixture +def precedence_sfg_registers(): + """A sfg with registers and interesting layout for precednce list generation. + + IN1>--->C0>--->ADD1>--->Q1>---+--->A0>--->ADD4>--->OUT1 + ^ | ^ + | T1 | + | | | + ADD2<---<B1<---+--->A1>--->ADD3 + ^ | ^ + | T2 | + | | | + +-----<B2<---+--->A2>-----+ + """ + in1 = Input("IN1") + c0 = ConstantMultiplication(5, in1, "C0") + add1 = Addition(c0, None, "ADD1") + # Not sure what operation "Q" is supposed to be in the example + Q1 = ConstantMultiplication(1, add1, "Q1") + T1 = Register(Q1, 0, "T1") + T2 = Register(T1, 0, "T2") + b2 = ConstantMultiplication(2, T2, "B2") + b1 = ConstantMultiplication(3, T1, "B1") + add2 = Addition(b1, b2, "ADD2") + add1.input(1).connect(add2) + a1 = ConstantMultiplication(4, T1, "A1") + a2 = ConstantMultiplication(6, T2, "A2") + add3 = Addition(a1, a2, "ADD3") + a0 = ConstantMultiplication(7, Q1, "A0") + add4 = Addition(a0, add3, "ADD4") + out1 = Output(add4, "OUT1") + + return SFG(inputs=[in1], outputs=[out1], name="SFG") + + +@pytest.fixture +def precedence_sfg_registers_and_constants(): + in1 = Input("IN1") + c0 = ConstantMultiplication(5, in1, "C0") + add1 = Addition(c0, None, "ADD1") + # Not sure what operation "Q" is supposed to be in the example + Q1 = ConstantMultiplication(1, add1, "Q1") + T1 = Register(Q1, 0, "T1") + const1 = Constant(10, "CONST1") # Replace T2 register with a constant + b2 = ConstantMultiplication(2, const1, "B2") + b1 = ConstantMultiplication(3, T1, "B1") + add2 = Addition(b1, b2, "ADD2") + add1.input(1).connect(add2) + a1 = ConstantMultiplication(4, T1, "A1") + a2 = ConstantMultiplication(10, const1, "A2") + add3 = Addition(a1, a2, "ADD3") + a0 = ConstantMultiplication(7, Q1, "A0") + # Replace ADD4 with a butterfly to test multiple output ports + bfly1 = Butterfly(a0, add3, "BFLY1") + out1 = Output(bfly1.output(0), "OUT1") + out2 = Output(bfly1.output(1), "OUT2") + + return SFG(inputs=[in1], outputs=[out1], name="SFG") diff --git a/test/test_abstract_operation.py b/test/test_abstract_operation.py index 5423ecdf08c420df5dccc6393c3ad6637961172b..9163fce2a955c7fbc68d5d24de86896d251934da 100644 --- a/test/test_abstract_operation.py +++ b/test/test_abstract_operation.py @@ -89,4 +89,3 @@ def test_division_overload(): assert isinstance(div3, Division) assert div3.input(0).signals[0].source.operation.value == 5 assert div3.input(1).signals == div2.output(0).signals - diff --git a/test/test_core_operations.py b/test/test_core_operations.py index 2eb341da88a851ac0fd26939da64377ea27963a1..6a0493c60965579bd843e0b514bd7f9b9a0e4707 100644 --- a/test/test_core_operations.py +++ b/test/test_core_operations.py @@ -6,7 +6,6 @@ from b_asic import \ Constant, Addition, Subtraction, Multiplication, ConstantMultiplication, Division, \ SquareRoot, ComplexConjugate, Max, Min, Absolute, Butterfly - class TestConstant: def test_constant_positive(self): test_operation = Constant(3) diff --git a/test/test_operation.py b/test/test_operation.py index b76ba16d11425c0ce868e4fa0b4c88d9f862e23f..77e9ba3cbd0eaa75886b5a7e5d11f00f6cfeb479 100644 --- a/test/test_operation.py +++ b/test/test_operation.py @@ -1,6 +1,6 @@ import pytest -from b_asic import Constant, Addition +from b_asic import Constant, Addition, MAD, Butterfly, SquareRoot class TestTraverse: def test_traverse_single_tree(self, operation): @@ -22,4 +22,32 @@ class TestTraverse: assert len(list(filter(lambda type_: isinstance(type_, Constant), result))) == 4 def test_traverse_loop(self, operation_graph_with_cycle): - assert len(list(operation_graph_with_cycle.traverse())) == 8 \ No newline at end of file + assert len(list(operation_graph_with_cycle.traverse())) == 8 + +class TestToSfg: + def test_convert_mad_to_sfg(self): + mad1 = MAD() + mad1_sfg = mad1.to_sfg() + + assert mad1.evaluate(1,1,1) == mad1_sfg.evaluate(1,1,1) + assert len(mad1_sfg.operations) == 6 + + def test_butterfly_to_sfg(self): + but1 = Butterfly() + but1_sfg = but1.to_sfg() + + assert but1.evaluate(1,1)[0] == but1_sfg.evaluate(1,1)[0] + assert but1.evaluate(1,1)[1] == but1_sfg.evaluate(1,1)[1] + assert len(but1_sfg.operations) == 8 + + def test_add_to_sfg(self): + add1 = Addition() + add1_sfg = add1.to_sfg() + + assert len(add1_sfg.operations) == 4 + + def test_sqrt_to_sfg(self): + sqrt1 = SquareRoot() + sqrt1_sfg = sqrt1.to_sfg() + + assert len(sqrt1_sfg.operations) == 3 diff --git a/test/test_sfg.py b/test/test_sfg.py index 618dce7f3f2bd9b7efd9ec2f3018d7e48e8c025b..e9ee82b872d5ecf725f617a20c1af7c3ea7a9f57 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -1,4 +1,7 @@ import pytest +import io +import sys + from b_asic import SFG, Signal, Input, Output, Constant, ConstantMultiplication, Addition, Multiplication, Register, \ Butterfly, Subtraction, SquareRoot @@ -54,13 +57,17 @@ class TestPrintSfg: inp2 = Input("INP2") add1 = Addition(inp1, inp2, "ADD1") out1 = Output(add1, "OUT1") - sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="sf1") + sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="SFG1") assert sfg.__str__() == \ - "id: add1, name: ADD1, input: [s1, s2], output: [s3]\n" + \ - "id: in1, name: INP1, input: [], output: [s1]\n" + \ - "id: in2, name: INP2, input: [], output: [s2]\n" + \ - "id: out1, name: OUT1, input: [s3], output: []\n" + "id: no_id, \tname: SFG1, \tinputs: {0: '-'}, \toutputs: {0: '-'}\n" + \ + "Internal Operations:\n" + \ + "----------------------------------------------------------------------------------------------------\n" + \ + str(sfg.find_by_name("INP1")[0]) + "\n" + \ + str(sfg.find_by_name("INP2")[0]) + "\n" + \ + str(sfg.find_by_name("ADD1")[0]) + "\n" + \ + str(sfg.find_by_name("OUT1")[0]) + "\n" + \ + "----------------------------------------------------------------------------------------------------\n" def test_add_mul(self): inp1 = Input("INP1") @@ -72,12 +79,16 @@ class TestPrintSfg: sfg = SFG(inputs=[inp1, inp2, inp3], outputs=[out1], name="mac_sfg") assert sfg.__str__() == \ - "id: add1, name: ADD1, input: [s1, s2], output: [s5]\n" + \ - "id: in1, name: INP1, input: [], output: [s1]\n" + \ - "id: in2, name: INP2, input: [], output: [s2]\n" + \ - "id: mul1, name: MUL1, input: [s5, s3], output: [s4]\n" + \ - "id: in3, name: INP3, input: [], output: [s3]\n" + \ - "id: out1, name: OUT1, input: [s4], output: []\n" + "id: no_id, \tname: mac_sfg, \tinputs: {0: '-'}, \toutputs: {0: '-'}\n" + \ + "Internal Operations:\n" + \ + "----------------------------------------------------------------------------------------------------\n" + \ + str(sfg.find_by_name("INP1")[0]) + "\n" + \ + str(sfg.find_by_name("INP2")[0]) + "\n" + \ + str(sfg.find_by_name("ADD1")[0]) + "\n" + \ + str(sfg.find_by_name("INP3")[0]) + "\n" + \ + str(sfg.find_by_name("MUL1")[0]) + "\n" + \ + str(sfg.find_by_name("OUT1")[0]) + "\n" + \ + "----------------------------------------------------------------------------------------------------\n" def test_constant(self): inp1 = Input("INP1") @@ -88,18 +99,27 @@ class TestPrintSfg: sfg = SFG(inputs=[inp1], outputs=[out1], name="sfg") assert sfg.__str__() == \ - "id: add1, name: ADD1, input: [s3, s1], output: [s2]\n" + \ - "id: c1, name: CONST, value: 3, input: [], output: [s3]\n" + \ - "id: in1, name: INP1, input: [], output: [s1]\n" + \ - "id: out1, name: OUT1, input: [s2], output: []\n" + "id: no_id, \tname: sfg, \tinputs: {0: '-'}, \toutputs: {0: '-'}\n" + \ + "Internal Operations:\n" + \ + "----------------------------------------------------------------------------------------------------\n" + \ + str(sfg.find_by_name("CONST")[0]) + "\n" + \ + str(sfg.find_by_name("INP1")[0]) + "\n" + \ + str(sfg.find_by_name("ADD1")[0]) + "\n" + \ + str(sfg.find_by_name("OUT1")[0]) + "\n" + \ + "----------------------------------------------------------------------------------------------------\n" def test_simple_filter(self, simple_filter): + assert simple_filter.__str__() == \ - 'id: add1, name: , input: [s1, s3], output: [s4]\n' + \ - 'id: in1, name: , input: [], output: [s1]\n' + \ - 'id: cmul1, name: , input: [s5], output: [s3]\n' + \ - 'id: reg1, name: , input: [s4], output: [s5, s2]\n' + \ - 'id: out1, name: , input: [s2], output: []\n' + "id: no_id, \tname: simple_filter, \tinputs: {0: '-'}, \toutputs: {0: '-'}\n" + \ + "Internal Operations:\n" + \ + "----------------------------------------------------------------------------------------------------\n" + \ + str(simple_filter.find_by_name("IN1")[0]) + "\n" + \ + str(simple_filter.find_by_name("ADD1")[0]) + "\n" + \ + str(simple_filter.find_by_name("REG1")[0]) + "\n" + \ + str(simple_filter.find_by_name("CMUL1")[0]) + "\n" + \ + str(simple_filter.find_by_name("OUT1")[0]) + "\n" + \ + "----------------------------------------------------------------------------------------------------\n" class TestDeepCopy: @@ -267,7 +287,7 @@ class TestInsertComponent: _sfg = sfg.insert_operation(sqrt, sfg.find_by_name("constant4")[0].graph_id) assert _sfg.evaluate() != sfg.evaluate() - + assert any([isinstance(comp, SquareRoot) for comp in _sfg.operations]) assert not any([isinstance(comp, SquareRoot) for comp in sfg.operations]) @@ -275,7 +295,8 @@ class TestInsertComponent: assert isinstance(_sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation, SquareRoot) assert sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation is sfg.find_by_id("add3") - assert _sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation is not _sfg.find_by_id("add3") + assert _sfg.find_by_name("constant4")[0].output( + 0).signals[0].destination.operation is not _sfg.find_by_id("add3") assert _sfg.find_by_id("sqrt1").output(0).signals[0].destination.operation is _sfg.find_by_id("add3") def test_insert_invalid_component_in_sfg(self, large_operation_tree): @@ -304,22 +325,26 @@ class TestInsertComponent: assert len(_sfg.find_by_name("n_bfly")) == 1 # Correctly connected old output -> new input - assert _sfg.find_by_name("bfly3")[0].output(0).signals[0].destination.operation is _sfg.find_by_name("n_bfly")[0] - assert _sfg.find_by_name("bfly3")[0].output(1).signals[0].destination.operation is _sfg.find_by_name("n_bfly")[0] + assert _sfg.find_by_name("bfly3")[0].output( + 0).signals[0].destination.operation is _sfg.find_by_name("n_bfly")[0] + assert _sfg.find_by_name("bfly3")[0].output( + 1).signals[0].destination.operation is _sfg.find_by_name("n_bfly")[0] # Correctly connected new input -> old output assert _sfg.find_by_name("n_bfly")[0].input(0).signals[0].source.operation is _sfg.find_by_name("bfly3")[0] assert _sfg.find_by_name("n_bfly")[0].input(1).signals[0].source.operation is _sfg.find_by_name("bfly3")[0] # Correctly connected new output -> next input - assert _sfg.find_by_name("n_bfly")[0].output(0).signals[0].destination.operation is _sfg.find_by_name("bfly2")[0] - assert _sfg.find_by_name("n_bfly")[0].output(1).signals[0].destination.operation is _sfg.find_by_name("bfly2")[0] + assert _sfg.find_by_name("n_bfly")[0].output( + 0).signals[0].destination.operation is _sfg.find_by_name("bfly2")[0] + assert _sfg.find_by_name("n_bfly")[0].output( + 1).signals[0].destination.operation is _sfg.find_by_name("bfly2")[0] # Correctly connected next input -> new output assert _sfg.find_by_name("bfly2")[0].input(0).signals[0].source.operation is _sfg.find_by_name("n_bfly")[0] assert _sfg.find_by_name("bfly2")[0].input(1).signals[0].source.operation is _sfg.find_by_name("n_bfly")[0] - + class TestFindComponentsWithTypeName: def test_mac_components(self): inp1 = Input("INP1") @@ -358,28 +383,9 @@ class TestFindComponentsWithTypeName: class TestGetPrecedenceList: - def test_inputs_registers(self): - in1 = Input("IN1") - c0 = ConstantMultiplication(5, in1, "C0") - add1 = Addition(c0, None, "ADD1") - # Not sure what operation "Q" is supposed to be in the example - Q1 = ConstantMultiplication(1, add1, "Q1") - T1 = Register(Q1, 0, "T1") - T2 = Register(T1, 0, "T2") - b2 = ConstantMultiplication(2, T2, "B2") - b1 = ConstantMultiplication(3, T1, "B1") - add2 = Addition(b1, b2, "ADD2") - add1.input(1).connect(add2) - a1 = ConstantMultiplication(4, T1, "A1") - a2 = ConstantMultiplication(6, T2, "A2") - add3 = Addition(a1, a2, "ADD3") - a0 = ConstantMultiplication(7, Q1, "A0") - add4 = Addition(a0, add3, "ADD4") - out1 = Output(add4, "OUT1") - - sfg = SFG(inputs=[in1], outputs=[out1], name="SFG") + def test_inputs_registers(self, precedence_sfg_registers): - precedence_list = sfg.get_precedence_list() + precedence_list = precedence_sfg_registers.get_precedence_list() assert len(precedence_list) == 7 @@ -404,30 +410,9 @@ class TestGetPrecedenceList: assert set([port.operation.key(port.index, port.operation.name) for port in precedence_list[6]]) == {"ADD4"} - def test_inputs_constants_registers_multiple_outputs(self): - in1 = Input("IN1") - c0 = ConstantMultiplication(5, in1, "C0") - add1 = Addition(c0, None, "ADD1") - # Not sure what operation "Q" is supposed to be in the example - Q1 = ConstantMultiplication(1, add1, "Q1") - T1 = Register(Q1, 0, "T1") - const1 = Constant(10, "CONST1") # Replace T2 register with a constant - b2 = ConstantMultiplication(2, const1, "B2") - b1 = ConstantMultiplication(3, T1, "B1") - add2 = Addition(b1, b2, "ADD2") - add1.input(1).connect(add2) - a1 = ConstantMultiplication(4, T1, "A1") - a2 = ConstantMultiplication(10, const1, "A2") - add3 = Addition(a1, a2, "ADD3") - a0 = ConstantMultiplication(7, Q1, "A0") - # Replace ADD4 with a butterfly to test multiple output ports - bfly1 = Butterfly(a0, add3, "BFLY1") - out1 = Output(bfly1.output(0), "OUT1") - out2 = Output(bfly1.output(1), "OUT2") - - sfg = SFG(inputs=[in1], outputs=[out1], name="SFG") + def test_inputs_constants_registers_multiple_outputs(self, precedence_sfg_registers_and_constants): - precedence_list = sfg.get_precedence_list() + precedence_list = precedence_sfg_registers_and_constants.get_precedence_list() assert len(precedence_list) == 7 @@ -502,10 +487,48 @@ class TestGetPrecedenceList: for port in precedence_list[0]]) == {"IN1", "IN2"} assert set([port.operation.key(port.index, port.operation.name) - for port in precedence_list[1]]) == {"NESTED_SFG.0", "CMUL1"} + for port in precedence_list[1]]) == {"CMUL1"} assert set([port.operation.key(port.index, port.operation.name) - for port in precedence_list[2]]) == {"NESTED_SFG.1"} + for port in precedence_list[2]]) == {"NESTED_SFG.0", "NESTED_SFG.1"} + + +class TestPrintPrecedence: + def test_registers(self, precedence_sfg_registers): + sfg = precedence_sfg_registers + + captured_output = io.StringIO() + sys.stdout = captured_output + + sfg.print_precedence_graph() + + sys.stdout = sys.__stdout__ + + captured_output = captured_output.getvalue() + + assert captured_output == \ + "-" * 120 + "\n" + \ + "1.1 \t" + str(sfg.find_by_name("IN1")[0]) + "\n" + \ + "1.2 \t" + str(sfg.find_by_name("T1")[0]) + "\n" + \ + "1.3 \t" + str(sfg.find_by_name("T2")[0]) + "\n" + \ + "-" * 120 + "\n" + \ + "2.1 \t" + str(sfg.find_by_name("C0")[0]) + "\n" + \ + "2.2 \t" + str(sfg.find_by_name("A1")[0]) + "\n" + \ + "2.3 \t" + str(sfg.find_by_name("B1")[0]) + "\n" + \ + "2.4 \t" + str(sfg.find_by_name("A2")[0]) + "\n" + \ + "2.5 \t" + str(sfg.find_by_name("B2")[0]) + "\n" + \ + "-" * 120 + "\n" + \ + "3.1 \t" + str(sfg.find_by_name("ADD3")[0]) + "\n" + \ + "3.2 \t" + str(sfg.find_by_name("ADD2")[0]) + "\n" + \ + "-" * 120 + "\n" + \ + "4.1 \t" + str(sfg.find_by_name("ADD1")[0]) + "\n" + \ + "-" * 120 + "\n" + \ + "5.1 \t" + str(sfg.find_by_name("Q1")[0]) + "\n" + \ + "-" * 120 + "\n" + \ + "6.1 \t" + str(sfg.find_by_name("A0")[0]) + "\n" + \ + "-" * 120 + "\n" + \ + "7.1 \t" + str(sfg.find_by_name("ADD4")[0]) + "\n" + \ + "-" * 120 + "\n" class TestDepends: @@ -674,6 +697,17 @@ class TestConnectExternalSignalsToComponentsMultipleComp: assert test_sfg.evaluate(1, 2, 3, 4) == 16 assert not test_sfg.connect_external_signals_to_components() +class TestTopologicalOrderOperations: + def test_feedback_sfg(self, simple_filter): + topological_order = simple_filter.get_operations_topological_order() + + assert [comp.name for comp in topological_order] == ["IN1", "ADD1", "REG1", "CMUL1", "OUT1"] + + def test_multiple_independent_inputs(self, sfg_two_inputs_two_outputs_independent): + topological_order = sfg_two_inputs_two_outputs_independent.get_operations_topological_order() + + assert [comp.name for comp in topological_order] == ["IN1", "OUT1", "IN2", "C1", "ADD1", "OUT2"] + class TestShowPrecedenceGraph: def create_sfg(self, op_tree):