From 5a905f37e6846d566aa3e623c32ddeb288445625 Mon Sep 17 00:00:00 2001
From: Oscar Gustafsson <oscar.gustafsson@gmail.com>
Date: Thu, 23 Feb 2023 14:49:28 +0100
Subject: [PATCH] Fix some typing

---
 b_asic/resources.py    |  2 +-
 b_asic/schedule.py     | 18 +++++++++++-------
 test/test_resources.py |  7 ++++---
 test/test_schedule.py  |  2 ++
 4 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/b_asic/resources.py b/b_asic/resources.py
index 362e09dd..a9d5fbf2 100644
--- a/b_asic/resources.py
+++ b/b_asic/resources.py
@@ -120,7 +120,7 @@ class ProcessCollection:
         return self._collection
 
     def __len__(self):
-        return len(self.__collection__)
+        return len(self._collection)
 
     def add_process(self, process: Process):
         """
diff --git a/b_asic/schedule.py b/b_asic/schedule.py
index 595c65c5..5bbfc2a6 100644
--- a/b_asic/schedule.py
+++ b/b_asic/schedule.py
@@ -79,8 +79,8 @@ class Schedule:
         schedule_time: Optional[int] = None,
         cyclic: bool = False,
         scheduling_algorithm: str = "ASAP",
-        start_times: Dict[GraphID, int] = None,
-        laps: Dict[GraphID, int] = None,
+        start_times: Optional[Dict[GraphID, int]] = None,
+        laps: Optional[Dict[GraphID, int]] = None,
     ):
         """Construct a Schedule from an SFG."""
         self._original_sfg = sfg()  # Make a copy
@@ -92,6 +92,10 @@ class Schedule:
         if scheduling_algorithm == "ASAP":
             self._schedule_asap()
         elif scheduling_algorithm == "provided":
+            if start_times is None:
+                raise ValueError("Must provide start_times when using 'provided'")
+            if laps is None:
+                raise ValueError("Must provide laps when using 'provided'")
             self._start_times = start_times
             self._laps.update(laps)
             self._remove_delays_no_laps()
@@ -403,10 +407,10 @@ class Schedule:
 
         """
         if insert:
-            for gid, y_location in self._y_locations.items():
-                if y_location >= new_y:
-                    self._y_locations[gid] += 1
-        self._y_locations[graph_id] = new_y
+            for gid in self._y_locations:
+                if self.get_y_location(gid) >= new_y:
+                    self.set_y_location(gid, self.get_y_location(gid) + 1)
+        self.set_y_location(graph_id, new_y)
         used_locations = {*self._y_locations.values()}
         possible_locations = set(range(max(used_locations) + 1))
         if not possible_locations - used_locations:
@@ -889,7 +893,7 @@ class Schedule:
 
     def _reset_y_locations(self) -> None:
         """Reset all the y-locations in the schedule to None"""
-        self._y_locations = self._y_locations = defaultdict(lambda: None)
+        self._y_locations = defaultdict(lambda: None)
 
     def plot_in_axes(self, ax: Axes, operation_gap: Optional[float] = None) -> None:
         """
diff --git a/test/test_resources.py b/test/test_resources.py
index 020fadfc..ea0b0d00 100644
--- a/test/test_resources.py
+++ b/test/test_resources.py
@@ -1,15 +1,13 @@
 import pickle
 
 import matplotlib.pyplot as plt
-import networkx as nx
 import pytest
 
-from b_asic.process import Process
 from b_asic.research.interleaver import (
     generate_matrix_transposer,
     generate_random_interleaver,
 )
-from b_asic.resources import ProcessCollection, draw_exclusion_graph_coloring
+from b_asic.resources import ProcessCollection
 
 
 class TestProcessCollectionPlainMemoryVariable:
@@ -44,3 +42,6 @@ class TestProcessCollectionPlainMemoryVariable:
                 assert len(collection.split_ports(read_ports=1, write_ports=1)) == 1
                 if any(var.execution_time for var in collection.collection):
                     assert len(collection.split_ports(total_ports=1)) == 2
+
+    def test_len_process_collection(self, simple_collection: ProcessCollection):
+        assert len(simple_collection) == 7
diff --git a/test/test_schedule.py b/test/test_schedule.py
index 0f2c412c..5d22981a 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -495,6 +495,8 @@ class TestProcesses:
     def test__get_memory_variables_list(self, secondorder_iir_schedule):
         mvl = secondorder_iir_schedule._get_memory_variables_list()
         assert len(mvl) == 12
+        pc = secondorder_iir_schedule.get_memory_variables()
+        assert len(pc) == 12
 
 
 class TestFigureGeneration:
-- 
GitLab