Skip to content
Snippets Groups Projects
Commit 5a905f37 authored by Oscar Gustafsson's avatar Oscar Gustafsson :bicyclist:
Browse files

Fix some typing

parent 722749dd
No related branches found
No related tags found
1 merge request!225Fix some typing
Pipeline #90267 passed
...@@ -120,7 +120,7 @@ class ProcessCollection: ...@@ -120,7 +120,7 @@ class ProcessCollection:
return self._collection return self._collection
def __len__(self): def __len__(self):
return len(self.__collection__) return len(self._collection)
def add_process(self, process: Process): def add_process(self, process: Process):
""" """
......
...@@ -79,8 +79,8 @@ class Schedule: ...@@ -79,8 +79,8 @@ class Schedule:
schedule_time: Optional[int] = None, schedule_time: Optional[int] = None,
cyclic: bool = False, cyclic: bool = False,
scheduling_algorithm: str = "ASAP", scheduling_algorithm: str = "ASAP",
start_times: Dict[GraphID, int] = None, start_times: Optional[Dict[GraphID, int]] = None,
laps: Dict[GraphID, int] = None, laps: Optional[Dict[GraphID, int]] = None,
): ):
"""Construct a Schedule from an SFG.""" """Construct a Schedule from an SFG."""
self._original_sfg = sfg() # Make a copy self._original_sfg = sfg() # Make a copy
...@@ -92,6 +92,10 @@ class Schedule: ...@@ -92,6 +92,10 @@ class Schedule:
if scheduling_algorithm == "ASAP": if scheduling_algorithm == "ASAP":
self._schedule_asap() self._schedule_asap()
elif scheduling_algorithm == "provided": elif scheduling_algorithm == "provided":
if start_times is None:
raise ValueError("Must provide start_times when using 'provided'")
if laps is None:
raise ValueError("Must provide laps when using 'provided'")
self._start_times = start_times self._start_times = start_times
self._laps.update(laps) self._laps.update(laps)
self._remove_delays_no_laps() self._remove_delays_no_laps()
...@@ -403,10 +407,10 @@ class Schedule: ...@@ -403,10 +407,10 @@ class Schedule:
""" """
if insert: if insert:
for gid, y_location in self._y_locations.items(): for gid in self._y_locations:
if y_location >= new_y: if self.get_y_location(gid) >= new_y:
self._y_locations[gid] += 1 self.set_y_location(gid, self.get_y_location(gid) + 1)
self._y_locations[graph_id] = new_y self.set_y_location(graph_id, new_y)
used_locations = {*self._y_locations.values()} used_locations = {*self._y_locations.values()}
possible_locations = set(range(max(used_locations) + 1)) possible_locations = set(range(max(used_locations) + 1))
if not possible_locations - used_locations: if not possible_locations - used_locations:
...@@ -889,7 +893,7 @@ class Schedule: ...@@ -889,7 +893,7 @@ class Schedule:
def _reset_y_locations(self) -> None: def _reset_y_locations(self) -> None:
"""Reset all the y-locations in the schedule to None""" """Reset all the y-locations in the schedule to None"""
self._y_locations = self._y_locations = defaultdict(lambda: None) self._y_locations = defaultdict(lambda: None)
def plot_in_axes(self, ax: Axes, operation_gap: Optional[float] = None) -> None: def plot_in_axes(self, ax: Axes, operation_gap: Optional[float] = None) -> None:
""" """
......
import pickle import pickle
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import networkx as nx
import pytest import pytest
from b_asic.process import Process
from b_asic.research.interleaver import ( from b_asic.research.interleaver import (
generate_matrix_transposer, generate_matrix_transposer,
generate_random_interleaver, generate_random_interleaver,
) )
from b_asic.resources import ProcessCollection, draw_exclusion_graph_coloring from b_asic.resources import ProcessCollection
class TestProcessCollectionPlainMemoryVariable: class TestProcessCollectionPlainMemoryVariable:
...@@ -44,3 +42,6 @@ class TestProcessCollectionPlainMemoryVariable: ...@@ -44,3 +42,6 @@ class TestProcessCollectionPlainMemoryVariable:
assert len(collection.split_ports(read_ports=1, write_ports=1)) == 1 assert len(collection.split_ports(read_ports=1, write_ports=1)) == 1
if any(var.execution_time for var in collection.collection): if any(var.execution_time for var in collection.collection):
assert len(collection.split_ports(total_ports=1)) == 2 assert len(collection.split_ports(total_ports=1)) == 2
def test_len_process_collection(self, simple_collection: ProcessCollection):
assert len(simple_collection) == 7
...@@ -495,6 +495,8 @@ class TestProcesses: ...@@ -495,6 +495,8 @@ class TestProcesses:
def test__get_memory_variables_list(self, secondorder_iir_schedule): def test__get_memory_variables_list(self, secondorder_iir_schedule):
mvl = secondorder_iir_schedule._get_memory_variables_list() mvl = secondorder_iir_schedule._get_memory_variables_list()
assert len(mvl) == 12 assert len(mvl) == 12
pc = secondorder_iir_schedule.get_memory_variables()
assert len(pc) == 12
class TestFigureGeneration: class TestFigureGeneration:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment