diff --git a/b_asic/GUI/arrow.py b/b_asic/GUI/arrow.py index 02a72bde916458e39f9477921f99e6edf173208b..3e5c83eec5f39e4266bbce93440bca42e93db9f1 100644 --- a/b_asic/GUI/arrow.py +++ b/b_asic/GUI/arrow.py @@ -246,7 +246,7 @@ class Arrow(QGraphicsPathItem): p.lineTo(QPointF(x0 - offset, ymid)) p.lineTo(QPointF(x0 - offset, y1)) else: - offset = -OFFSET if source_flipped else -OFFSET + offset = -OFFSET p.lineTo(QPointF(x0 + offset, y0)) p.lineTo(QPointF(x0 + offset, y1)) else: diff --git a/b_asic/GUI/main_window.py b/b_asic/GUI/main_window.py index 40dd53101a3133b0ad774956db722645179dabeb..72a09f14e7c29c0e95db4d414aa37458f683d842 100644 --- a/b_asic/GUI/main_window.py +++ b/b_asic/GUI/main_window.py @@ -756,7 +756,7 @@ class SFGMainWindow(QMainWindow): ) def _create_operation_item(self, item) -> None: - self._logger.info(f"Creating operation of type: {str(item.text())}") + self._logger.info(f"Creating operation of type: {item.text()!s}") try: attr_operation = self._operations_from_name[item.text()]() self.add_operation(attr_operation) @@ -903,7 +903,7 @@ class SFGMainWindow(QMainWindow): self._thread = {} self._sim_worker = {} for sfg, properties in self._simulation_dialog._properties.items(): - self._logger.info(f"Simulating SFG with name: {str(sfg.name)}") + self._logger.info(f"Simulating SFG with name: {sfg.name!s}") self._sim_worker[sfg] = SimulationWorker(sfg, properties) self._thread[sfg] = QThread() self._sim_worker[sfg].moveToThread(self._thread[sfg]) diff --git a/b_asic/GUI/precedence_graph_window.py b/b_asic/GUI/precedence_graph_window.py index 57365bf9cef7d3d224bf71ca1fec6f5d0269b512..022f5a0ab970d34c1a55f9c565e6b249c6cd6323 100644 --- a/b_asic/GUI/precedence_graph_window.py +++ b/b_asic/GUI/precedence_graph_window.py @@ -53,7 +53,7 @@ class PrecedenceGraphWindow(QDialog): self._dialog_layout.addLayout(self._sfg_layout) if len(self._check_box_dict) == 1: - check_box = list(self._check_box_dict.keys())[0] + check_box = next(iter(self._check_box_dict.keys())) check_box.setChecked(True) def show_precedence_graph(self): diff --git a/b_asic/codegen/vhdl/common.py b/b_asic/codegen/vhdl/common.py index d0e7eff02d49ee7652a144657d02a8f6346f2bfd..ec0849ad9d2e546e26bfe91c55cfadb5fc1d9fa3 100644 --- a/b_asic/codegen/vhdl/common.py +++ b/b_asic/codegen/vhdl/common.py @@ -168,7 +168,7 @@ def constant_declaration( An optional left padding value applied to the name. """ name_pad = 0 if name_pad is None else name_pad - write(f, 1, f"constant {name:<{name_pad}} : {signal_type} := {str(value)};") + write(f, 1, f"constant {name:<{name_pad}} : {signal_type} := {value!s};") def type_declaration( diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index 30375a2ff24a308bd20a20d77b21b308561a0c6f..e5d9dc02f142b0d14d854d3deebd7f469de4aa65 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -610,7 +610,7 @@ class Min(AbstractOperation): def evaluate(self, a, b): if isinstance(a, complex) or isinstance(b, complex): raise ValueError("core_operations.Min does not support complex numbers.") - return a if a < b else b + return min(a, b) class Max(AbstractOperation): @@ -691,7 +691,7 @@ class Max(AbstractOperation): def evaluate(self, a, b): if isinstance(a, complex) or isinstance(b, complex): raise ValueError("core_operations.Max does not support complex numbers.") - return a if a > b else b + return max(a, b) class SquareRoot(AbstractOperation): diff --git a/b_asic/graph_component.py b/b_asic/graph_component.py index 2864904b38d1a99d9478c3eee07c0ce4e6e2e59d..449adce06331ff846374b3fead32d5c614d750e1 100644 --- a/b_asic/graph_component.py +++ b/b_asic/graph_component.py @@ -136,7 +136,7 @@ class AbstractGraphComponent(GraphComponent): f"id: {self.graph_id if self.graph_id else 'no_id'}, \tname:" f" {self.name if self.name else 'no_name'}" + "".join( - (f", \t{key}: {str(param)}" for key, param in self._parameters.items()) + (f", \t{key}: {param!s}" for key, param in self._parameters.items()) ) ) diff --git a/b_asic/list_schedulers.py b/b_asic/list_schedulers.py index 9b2351e5724c403d6efcbce8656e5714afbbdeea..cd36cae7d9858dc10871201b5a5a88a1079e5581 100644 --- a/b_asic/list_schedulers.py +++ b/b_asic/list_schedulers.py @@ -28,11 +28,11 @@ class LeastSlackTimeScheduler(ListScheduler): def __init__( self, - max_resources: dict[TypeName, int] = None, - max_concurrent_reads: int = None, - max_concurrent_writes: int = None, - input_times: dict["GraphID", int] = None, - output_delta_times: dict["GraphID", int] = None, + max_resources: dict[TypeName, int] | None = None, + max_concurrent_reads: int | None = None, + max_concurrent_writes: int | None = None, + input_times: dict["GraphID", int] | None = None, + output_delta_times: dict["GraphID", int] | None = None, ) -> None: super().__init__( sort_order=((2, True),), @@ -49,11 +49,11 @@ class MaxFanOutScheduler(ListScheduler): def __init__( self, - max_resources: dict[TypeName, int] = None, - max_concurrent_reads: int = None, - max_concurrent_writes: int = None, - input_times: dict["GraphID", int] = None, - output_delta_times: dict["GraphID", int] = None, + max_resources: dict[TypeName, int] | None = None, + max_concurrent_reads: int | None = None, + max_concurrent_writes: int | None = None, + input_times: dict["GraphID", int] | None = None, + output_delta_times: dict["GraphID", int] | None = None, ) -> None: super().__init__( sort_order=((3, False),), @@ -70,11 +70,11 @@ class HybridScheduler(ListScheduler): def __init__( self, - max_resources: dict[TypeName, int] = None, - max_concurrent_reads: int = None, - max_concurrent_writes: int = None, - input_times: dict["GraphID", int] = None, - output_delta_times: dict["GraphID", int] = None, + max_resources: dict[TypeName, int] | None = None, + max_concurrent_reads: int | None = None, + max_concurrent_writes: int | None = None, + input_times: dict["GraphID", int] | None = None, + output_delta_times: dict["GraphID", int] | None = None, ) -> None: super().__init__( sort_order=((2, True), (3, False)), diff --git a/b_asic/operation.py b/b_asic/operation.py index d22caccab4d9383a88f86e70fcaff22db395e8b7..0c9dfdb24038e70af7c83d6b99d19f816d6aa26a 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -584,7 +584,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): return ( super().__str__() - + f", \tinputs: {str(inputs_dict)}, \toutputs: {str(outputs_dict)}" + + f", \tinputs: {inputs_dict!s}, \toutputs: {outputs_dict!s}" ) @property diff --git a/b_asic/resources.py b/b_asic/resources.py index 9f4015ae670b4a8cf5faa36b1ae429b546e9b33f..acdcd1a01fbef4da96a1f162de3c3542058bb094 100644 --- a/b_asic/resources.py +++ b/b_asic/resources.py @@ -1,4 +1,5 @@ import io +import itertools import re from collections import Counter, defaultdict from collections.abc import Iterable @@ -1557,12 +1558,11 @@ class ProcessCollection: return max(self.read_port_accesses().values()) def read_port_accesses(self) -> dict[int, int]: - reads = sum( - ( + reads = list( + itertools.chain.from_iterable( [read_time % self.schedule_time for read_time in process.read_times] for process in self._collection - ), - [], + ) ) return dict(sorted(Counter(reads).items())) diff --git a/b_asic/schedule.py b/b_asic/schedule.py index 0ce9e42f634c42c5d935f320730a99745c2caf56..368a20f2efb123c270ac9ae6f38e8752dc930054 100644 --- a/b_asic/schedule.py +++ b/b_asic/schedule.py @@ -160,7 +160,7 @@ class Schedule: string_io.write("-" * (15 * len(header) + len(header) - 1) + "\n") for r in res_str: - row_str = "|".join(f"{str(item):^15}" for i, item in enumerate(r)) + row_str = "|".join(f"{item!s:^15}" for i, item in enumerate(r)) string_io.write(row_str + "\n") return string_io.getvalue() diff --git a/b_asic/scheduler.py b/b_asic/scheduler.py index 41ed7044831049f4edb59296de443d54363e06d7..767c38cfe7e4f2fe638ff101de6fb8655a7cc516 100644 --- a/b_asic/scheduler.py +++ b/b_asic/scheduler.py @@ -934,7 +934,7 @@ class RecursiveListScheduler(ListScheduler): # adjust time if a gap exists on the left side of the schedule slack = min(self._schedule._start_times.values()) for other_op_id in sorted_op_ids: - op = self._schedule._sfg.find_by_id(op_id) + op = self._schedule._sfg.find_by_id(other_op_id) max_end_time = 0 op_start_time = self._schedule._start_times[other_op_id] for outport in op.outputs: @@ -987,13 +987,9 @@ class RecursiveListScheduler(ListScheduler): prio_table = self._get_recursive_priority_table() self._schedule._schedule_time = self._schedule.get_max_end_time() - if ( - saved_sched_time is not None - and saved_sched_time < self._schedule._schedule_time - ): - raise ValueError( - f"Requested schedule time {saved_sched_time} cannot be reached, increase to {self._schedule._schedule_time} or assign more resources." - ) + + if saved_sched_time: + self._schedule._schedule_time = saved_sched_time self._logger.debug("--- Scheduling of recursive loops completed ---") def _get_next_recursive_op( diff --git a/b_asic/scheduler_gui/main_window.py b/b_asic/scheduler_gui/main_window.py index 75bb33f737d8333eb93bf699cbd15b6726dda239..3f616c6628b2bc4e5ba1d8f5c214f8424729d100 100644 --- a/b_asic/scheduler_gui/main_window.py +++ b/b_asic/scheduler_gui/main_window.py @@ -15,7 +15,7 @@ import webbrowser from collections import defaultdict, deque from copy import deepcopy from importlib.machinery import SourceFileLoader -from typing import TYPE_CHECKING, cast, overload +from typing import TYPE_CHECKING, ClassVar, cast, overload # Qt/qtpy import qtpy @@ -115,8 +115,8 @@ class ScheduleMainWindow(QMainWindow, Ui_MainWindow): _splitter_pos: int _splitter_min: int _zoom: float - _color_per_type: dict[str, QColor] = {} - _converted_color_per_type: dict[str, str] = {} + _color_per_type: ClassVar[dict[str, QColor]] = {} + _converted_color_per_type: ClassVar[dict[str, str]] = {} def __init__(self): """Initialize Scheduler-GUI.""" @@ -374,7 +374,7 @@ class ScheduleMainWindow(QMainWindow, Ui_MainWindow): return if len(schedule_obj_list) == 1: - schedule = list(schedule_obj_list.values())[0] + schedule = next(iter(schedule_obj_list.values())) else: ret_tuple = QInputDialog.getItem( self, diff --git a/b_asic/sfg_generators.py b/b_asic/sfg_generators.py index 7c85ad56188be85f527cecf5c25ebc2042c00275..b7e21aaaf12e1078e74553e3d76bda5a58c43330 100644 --- a/b_asic/sfg_generators.py +++ b/b_asic/sfg_generators.py @@ -369,14 +369,22 @@ def direct_form_1_iir( # construct the feed-forward part input_op = Input() - muls = [ConstantMultiplication(b[0], input_op, **mult_properties)] + if b[0] != 1: + muls = [ConstantMultiplication(b[0], input_op, **mult_properties)] + else: + muls = [input_op] delays = [] prev_delay = input_op for i, coeff in enumerate(b[1:]): prev_delay = Delay(prev_delay) delays.append(prev_delay) if i < len(b) - 1: - muls.append(ConstantMultiplication(coeff, prev_delay, **mult_properties)) + if coeff != 1: + muls.append( + ConstantMultiplication(coeff, prev_delay, **mult_properties) + ) + else: + muls.append(prev_delay) op_a = muls[-1] for i in range(len(muls) - 1): @@ -394,7 +402,12 @@ def direct_form_1_iir( prev_delay = Delay(prev_delay) delays.append(prev_delay) if i < len(a) - 1: - muls.append(ConstantMultiplication(-coeff, prev_delay, **mult_properties)) + if -coeff != 1: + muls.append( + ConstantMultiplication(-coeff, prev_delay, **mult_properties) + ) + else: + muls.append(prev_delay) op_a = muls[-1] for i in range(len(muls) - 1): @@ -443,12 +456,21 @@ def direct_form_2_iir( new_delay = Delay() delays[-1] <<= new_delay delays.append(new_delay) - left_muls.append( - ConstantMultiplication(-a_coeff, delays[-1], **mult_properties) - ) - right_muls.append( - ConstantMultiplication(b_coeff, delays[-1], **mult_properties) - ) + + if -a_coeff != 1: + left_muls.append( + ConstantMultiplication(-a_coeff, delays[-1], **mult_properties) + ) + else: + left_muls.append(delays[-1]) + + if b_coeff != 1: + right_muls.append( + ConstantMultiplication(b_coeff, delays[-1], **mult_properties) + ) + else: + right_muls.append(delays[-1]) + if len(left_muls) > 1: # not first iteration left_adds.append(Addition(op_a_left, left_muls[-1], **add_properties)) right_adds.append(Addition(op_a_right, right_muls[-1], **add_properties)) @@ -465,7 +487,12 @@ def direct_form_2_iir( else: left_adds.append(Addition(input_op, left_muls[-1], **add_properties)) delays[-1] <<= left_adds[-1] - mul = ConstantMultiplication(b[0], left_adds[-1], **mult_properties) + + if b[0] == 1: + mul = left_adds[-1] + else: + mul = ConstantMultiplication(b[0], left_adds[-1], **mult_properties) + if right_adds: add = Addition(mul, right_adds[-1], **add_properties) else: diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index d5ccbc045ee87169532c8ed471160a326c4af030..5b8d7c68e2f623a136e692291adf196e1aa0c3cd 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -14,11 +14,7 @@ from io import StringIO from math import ceil from numbers import Number from queue import PriorityQueue -from typing import ( - Optional, - Union, - cast, -) +from typing import ClassVar, Optional, Union, cast import numpy as np from graphviz import Digraph @@ -121,7 +117,7 @@ class SFG(AbstractOperation): _original_input_signals_to_indices: dict[Signal, int] _original_output_signals_to_indices: dict[Signal, int] _precedence_list: list[list[OutputPort]] | None - _used_ids: set[GraphID] = set() + _used_ids: ClassVar[set[GraphID]] = set() def __init__( self, @@ -1800,7 +1796,7 @@ class SFG(AbstractOperation): for next_state in graph[state]: if next_state in path: continue - fringe.append((next_state, path + [next_state])) + fringe.append((next_state, [*path, next_state])) def resource_lower_bound(self, type_name: TypeName, schedule_time: int) -> int: """ @@ -1906,7 +1902,7 @@ class SFG(AbstractOperation): else: raise ValueError("Destination does not exist") cycles = [ - [node] + path + [node, *path] for node in dict_of_sfg for path in self._dfs(dict_of_sfg, node, node) ] @@ -1995,7 +1991,7 @@ class SFG(AbstractOperation): if "".join([i for i in key if not i.isdigit()]) == "c": addition_with_constant[item[0]] = self.find_by_id(key).value cycles = [ - [node] + path + [node, *path] for node in dict_of_sfg if node[0] == "t" for path in self._dfs(dict_of_sfg, node, node) @@ -2129,7 +2125,7 @@ class SFG(AbstractOperation): if path is None: path = [] - path = path + [start] + path = [*path, start] if start == end: return [path] if start not in graph: diff --git a/pyproject.toml b/pyproject.toml index 9d44ead1b48a0697f21d7dd3eaffa048944797d9..88282db0f0b68511b29d006c68f9adbac121e1c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,8 +100,8 @@ precision = 2 exclude = ["examples"] [tool.ruff.lint] -select = ["E4", "E7", "E9", "F", "SIM", "B", "NPY", "C4", "UP"] -ignore = ["F403", "B008", "B021", "B006", "UP038"] +select = ["E4", "E7", "E9", "F", "SIM", "B", "NPY", "C4", "UP", "RUF"] +ignore = ["F403", "B008", "B021", "B006", "UP038", "RUF023"] [tool.typos] default.extend-identifiers = { ba = "ba", addd0 = "addd0", inout = "inout", ArChItEctUrE = "ArChItEctUrE" } diff --git a/test/unit/test_list_schedulers.py b/test/unit/test_list_schedulers.py index 9dbe448cb290c4426fb967988dc9e8efdd11d84c..1d4c1bd5ff177710afe66bd59eceefa9402a8466 100644 --- a/test/unit/test_list_schedulers.py +++ b/test/unit/test_list_schedulers.py @@ -39,7 +39,7 @@ class TestEarliestDeadlineScheduler: Schedule(sfg_empty, scheduler=EarliestDeadlineScheduler()) def test_direct_form_1_iir(self): - sfg = direct_form_1_iir([1, 2, 3], [1, 2, 3]) + sfg = direct_form_1_iir([0.1, 0.2, 0.3], [1, 2, 3]) sfg.set_latency_of_type_name(ConstantMultiplication.type_name(), 2) sfg.set_execution_time_of_type_name(ConstantMultiplication.type_name(), 1) @@ -217,7 +217,7 @@ class TestLeastSlackTimeScheduler: Schedule(sfg_empty, scheduler=LeastSlackTimeScheduler()) def test_direct_form_1_iir(self): - sfg = direct_form_1_iir([1, 2, 3], [1, 2, 3]) + sfg = direct_form_1_iir([0.1, 0.2, 0.3], [1, 2, 3]) sfg.set_latency_of_type_name(ConstantMultiplication.type_name(), 2) sfg.set_execution_time_of_type_name(ConstantMultiplication.type_name(), 1) @@ -395,7 +395,7 @@ class TestMaxFanOutScheduler: Schedule(sfg_empty, scheduler=MaxFanOutScheduler()) def test_direct_form_1_iir(self): - sfg = direct_form_1_iir([1, 2, 3], [1, 2, 3]) + sfg = direct_form_1_iir([0.1, 0.2, 0.3], [1, 2, 3]) sfg.set_latency_of_type_name(ConstantMultiplication.type_name(), 2) sfg.set_execution_time_of_type_name(ConstantMultiplication.type_name(), 1) @@ -482,7 +482,7 @@ class TestHybridScheduler: Schedule(sfg_empty, scheduler=HybridScheduler()) def test_direct_form_1_iir(self): - sfg = direct_form_1_iir([1, 2, 3], [1, 2, 3]) + sfg = direct_form_1_iir([0.1, 0.2, 0.3], [1, 2, 3]) sfg.set_latency_of_type_name(ConstantMultiplication.type_name(), 2) sfg.set_execution_time_of_type_name(ConstantMultiplication.type_name(), 1) @@ -1585,7 +1585,7 @@ class TestHybridScheduler: ) def test_iteration_period_bound(self): - sfg = direct_form_1_iir([1, 2, 3], [1, 2, 3]) + sfg = direct_form_1_iir([0.1, 0.2, 0.3], [1, 2, 3]) sfg.set_latency_of_type_name(ConstantMultiplication.type_name(), 2) sfg.set_execution_time_of_type_name(ConstantMultiplication.type_name(), 1) diff --git a/test/unit/test_process.py b/test/unit/test_process.py index 6b11e5e593348d76589911cd4bd2df949b9d2d05..267dba764a77a9338657135cd19d44cccb9bd05c 100644 --- a/test/unit/test_process.py +++ b/test/unit/test_process.py @@ -30,7 +30,7 @@ def test_MemoryVariables(secondorder_iir_schedule): "MemoryVariable\\(3, <b_asic.port.OutputPort object at 0x[a-fA-F0-9]+>," " {<b_asic.port.InputPort object at 0x[a-fA-F0-9]+>: 4}, 'cmul0.0'\\)" ) - mem_var = [m for m in mem_vars if m.name == "cmul0.0"][0] + mem_var = next(m for m in mem_vars if m.name == "cmul0.0") assert pattern.match(repr(mem_var)) assert mem_var.execution_time == 4 assert mem_var.start_time == 3 diff --git a/test/unit/test_schedule.py b/test/unit/test_schedule.py index 0a94212e7260d3369e4b671a1551579003a23ead..2efd6d849923cda2313a7e77870c6a363a03b675 100644 --- a/test/unit/test_schedule.py +++ b/test/unit/test_schedule.py @@ -258,7 +258,7 @@ class TestInit: assert schedule.schedule_time == 10 def test_provided_schedule(self): - sfg = direct_form_1_iir([1, 2, 3], [1, 2, 3]) + sfg = direct_form_1_iir([2, 2, 3], [1, 2, 3]) sfg.set_latency_of_type_name(Addition.type_name(), 1) sfg.set_latency_of_type_name(ConstantMultiplication.type_name(), 3) diff --git a/test/unit/test_scheduler.py b/test/unit/test_scheduler.py index 374f50212e1d7502623abd13b507495bac3588d7..b42df835516532f9efafb67c8264d0811054db4d 100644 --- a/test/unit/test_scheduler.py +++ b/test/unit/test_scheduler.py @@ -14,7 +14,7 @@ class TestASAPScheduler: Schedule(sfg_empty, scheduler=ASAPScheduler()) def test_direct_form_1_iir(self): - sfg = direct_form_1_iir([1, 2, 3], [1, 2, 3]) + sfg = direct_form_1_iir([0.1, 0.2, 0.3], [1, 2, 3]) sfg.set_latency_of_type_name(ConstantMultiplication.type_name(), 2) sfg.set_execution_time_of_type_name(ConstantMultiplication.type_name(), 1) @@ -144,7 +144,7 @@ class TestALAPScheduler: Schedule(sfg_empty, scheduler=ALAPScheduler()) def test_direct_form_1_iir(self): - sfg = direct_form_1_iir([1, 2, 3], [1, 2, 3]) + sfg = direct_form_1_iir([0.1, 0.2, 0.3], [1, 2, 3]) sfg.set_latency_of_type_name(ConstantMultiplication.type_name(), 2) sfg.set_execution_time_of_type_name(ConstantMultiplication.type_name(), 1) diff --git a/test/unit/test_sfg_generators.py b/test/unit/test_sfg_generators.py index f164ee50866565da9b20d4c8e2e58ab5921641bb..c28e1b9ee8e48ed1a416a7e09cc820a2e72a899f 100644 --- a/test/unit/test_sfg_generators.py +++ b/test/unit/test_sfg_generators.py @@ -350,10 +350,7 @@ class TestDirectFormIIRType1: sfg = direct_form_1_iir(b, a, name="test iir direct form 1") amount_of_muls = len(sfg.find_by_type_name(ConstantMultiplication.type_name())) - assert amount_of_muls == 2 * N + 1 - - amount_of_muls = len(sfg.find_by_type(ConstantMultiplication)) - assert amount_of_muls == 2 * N + 1 + assert amount_of_muls == 2 * N amount_of_adds = len(sfg.find_by_type_name(Addition.type_name())) assert amount_of_adds == 2 * N @@ -362,10 +359,44 @@ class TestDirectFormIIRType1: assert amount_of_delays == 2 * N amount_of_ops = len(sfg.operations) - assert amount_of_ops == 6 * N + 3 + assert amount_of_ops == 6 * N + 2 assert sfg.name == "test iir direct form 1" + b = [1, 0.1, 0.1, 1, 1] + a = [1, 0.1, 0.1, -1, -1] + + sfg = direct_form_1_iir(b, a, name="test iir direct form 1") + + amount_of_muls = len(sfg.find_by_type_name(ConstantMultiplication.type_name())) + assert amount_of_muls == 4 + + amount_of_adds = len(sfg.find_by_type_name(Addition.type_name())) + assert amount_of_adds == 8 + + amount_of_delays = len(sfg.find_by_type_name(Delay.type_name())) + assert amount_of_delays == 8 + + amount_of_ops = len(sfg.operations) + assert amount_of_ops == 22 + + b = [1, 1, 1, 1, 1] + a = [1, -1, -1, -1, -1] + + sfg = direct_form_1_iir(b, a, name="test iir direct form 1") + + amount_of_muls = len(sfg.find_by_type_name(ConstantMultiplication.type_name())) + assert amount_of_muls == 0 + + amount_of_adds = len(sfg.find_by_type_name(Addition.type_name())) + assert amount_of_adds == 8 + + amount_of_delays = len(sfg.find_by_type_name(Delay.type_name())) + assert amount_of_delays == 8 + + amount_of_ops = len(sfg.operations) + assert amount_of_ops == 18 + def test_b_single_coeff(self): with pytest.raises( ValueError, @@ -476,14 +507,27 @@ class TestDirectFormIIRType1: class TestDirectFormIIRType2: def test_correct_number_of_operations_and_name(self): N = 17 - - b = [i + 1 for i in range(N + 1)] + b = list(range(N + 1)) a = [i + 1 for i in range(N + 1)] + sfg = direct_form_2_iir(b, a, name="test iir direct form 2") + amount_of_muls = len(sfg.find_by_type_name(ConstantMultiplication.type_name())) + assert amount_of_muls == 2 * N + + amount_of_adds = len(sfg.find_by_type_name(Addition.type_name())) + assert amount_of_adds == 2 * N + + amount_of_delays = len(sfg.find_by_type_name(Delay.type_name())) + assert amount_of_delays == N + + amount_of_ops = len(sfg.operations) + assert amount_of_ops == 5 * N + 2 + + b = [i + 1 for i in range(N + 1)] sfg = direct_form_2_iir(b, a, name="test iir direct form 2") amount_of_muls = len(sfg.find_by_type_name(ConstantMultiplication.type_name())) - assert amount_of_muls == 2 * N + 1 + assert amount_of_muls == 2 * N amount_of_adds = len(sfg.find_by_type_name(Addition.type_name())) assert amount_of_adds == 2 * N @@ -492,10 +536,43 @@ class TestDirectFormIIRType2: assert amount_of_delays == N amount_of_ops = len(sfg.operations) - assert amount_of_ops == 5 * N + 3 + assert amount_of_ops == 5 * N + 2 assert sfg.name == "test iir direct form 2" + b = [1, 0.1, 1, 0.1] + a = [1, -1, -1, 0.1] + sfg = direct_form_2_iir(b, a) + + amount_of_muls = len(sfg.find_by_type_name(ConstantMultiplication.type_name())) + assert amount_of_muls == 3 + + amount_of_adds = len(sfg.find_by_type_name(Addition.type_name())) + assert amount_of_adds == 6 + + amount_of_delays = len(sfg.find_by_type_name(Delay.type_name())) + assert amount_of_delays == 3 + + amount_of_ops = len(sfg.operations) + assert amount_of_ops == 14 + + b = [1, 1, 1, 1, 1] + a = [1, -1, -1, -1, -1] + + sfg = direct_form_2_iir(b, a, name="test iir direct form 1") + + amount_of_muls = len(sfg.find_by_type_name(ConstantMultiplication.type_name())) + assert amount_of_muls == 0 + + amount_of_adds = len(sfg.find_by_type_name(Addition.type_name())) + assert amount_of_adds == 8 + + amount_of_delays = len(sfg.find_by_type_name(Delay.type_name())) + assert amount_of_delays == 4 + + amount_of_ops = len(sfg.operations) + assert amount_of_ops == 14 + def test_b_single_coeff(self): with pytest.raises( ValueError,