From 5a905f37e6846d566aa3e623c32ddeb288445625 Mon Sep 17 00:00:00 2001 From: Oscar Gustafsson <oscar.gustafsson@gmail.com> Date: Thu, 23 Feb 2023 14:49:28 +0100 Subject: [PATCH] Fix some typing --- b_asic/resources.py | 2 +- b_asic/schedule.py | 18 +++++++++++------- test/test_resources.py | 7 ++++--- test/test_schedule.py | 2 ++ 4 files changed, 18 insertions(+), 11 deletions(-) diff --git a/b_asic/resources.py b/b_asic/resources.py index 362e09dd..a9d5fbf2 100644 --- a/b_asic/resources.py +++ b/b_asic/resources.py @@ -120,7 +120,7 @@ class ProcessCollection: return self._collection def __len__(self): - return len(self.__collection__) + return len(self._collection) def add_process(self, process: Process): """ diff --git a/b_asic/schedule.py b/b_asic/schedule.py index 595c65c5..5bbfc2a6 100644 --- a/b_asic/schedule.py +++ b/b_asic/schedule.py @@ -79,8 +79,8 @@ class Schedule: schedule_time: Optional[int] = None, cyclic: bool = False, scheduling_algorithm: str = "ASAP", - start_times: Dict[GraphID, int] = None, - laps: Dict[GraphID, int] = None, + start_times: Optional[Dict[GraphID, int]] = None, + laps: Optional[Dict[GraphID, int]] = None, ): """Construct a Schedule from an SFG.""" self._original_sfg = sfg() # Make a copy @@ -92,6 +92,10 @@ class Schedule: if scheduling_algorithm == "ASAP": self._schedule_asap() elif scheduling_algorithm == "provided": + if start_times is None: + raise ValueError("Must provide start_times when using 'provided'") + if laps is None: + raise ValueError("Must provide laps when using 'provided'") self._start_times = start_times self._laps.update(laps) self._remove_delays_no_laps() @@ -403,10 +407,10 @@ class Schedule: """ if insert: - for gid, y_location in self._y_locations.items(): - if y_location >= new_y: - self._y_locations[gid] += 1 - self._y_locations[graph_id] = new_y + for gid in self._y_locations: + if self.get_y_location(gid) >= new_y: + self.set_y_location(gid, self.get_y_location(gid) + 1) + self.set_y_location(graph_id, new_y) used_locations = {*self._y_locations.values()} possible_locations = set(range(max(used_locations) + 1)) if not possible_locations - used_locations: @@ -889,7 +893,7 @@ class Schedule: def _reset_y_locations(self) -> None: """Reset all the y-locations in the schedule to None""" - self._y_locations = self._y_locations = defaultdict(lambda: None) + self._y_locations = defaultdict(lambda: None) def plot_in_axes(self, ax: Axes, operation_gap: Optional[float] = None) -> None: """ diff --git a/test/test_resources.py b/test/test_resources.py index 020fadfc..ea0b0d00 100644 --- a/test/test_resources.py +++ b/test/test_resources.py @@ -1,15 +1,13 @@ import pickle import matplotlib.pyplot as plt -import networkx as nx import pytest -from b_asic.process import Process from b_asic.research.interleaver import ( generate_matrix_transposer, generate_random_interleaver, ) -from b_asic.resources import ProcessCollection, draw_exclusion_graph_coloring +from b_asic.resources import ProcessCollection class TestProcessCollectionPlainMemoryVariable: @@ -44,3 +42,6 @@ class TestProcessCollectionPlainMemoryVariable: assert len(collection.split_ports(read_ports=1, write_ports=1)) == 1 if any(var.execution_time for var in collection.collection): assert len(collection.split_ports(total_ports=1)) == 2 + + def test_len_process_collection(self, simple_collection: ProcessCollection): + assert len(simple_collection) == 7 diff --git a/test/test_schedule.py b/test/test_schedule.py index 0f2c412c..5d22981a 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -495,6 +495,8 @@ class TestProcesses: def test__get_memory_variables_list(self, secondorder_iir_schedule): mvl = secondorder_iir_schedule._get_memory_variables_list() assert len(mvl) == 12 + pc = secondorder_iir_schedule.get_memory_variables() + assert len(pc) == 12 class TestFigureGeneration: -- GitLab