From 2ce86e81176d1719e654c03b5c518465a35aa2c0 Mon Sep 17 00:00:00 2001
From: Oscar Gustafsson <oscar.gustafsson@gmail.com>
Date: Thu, 18 May 2023 00:18:52 +0200
Subject: [PATCH] Rewrite slack computation

---
 b_asic/resources.py              | 27 +++++++++++----------
 b_asic/schedule.py               | 40 ++++++++++++++++++--------------
 examples/fivepointwinograddft.py |  3 +++
 test/test_schedule.py            |  4 ++--
 4 files changed, 43 insertions(+), 31 deletions(-)

diff --git a/b_asic/resources.py b/b_asic/resources.py
index 15a3b0da..03ed8f22 100644
--- a/b_asic/resources.py
+++ b/b_asic/resources.py
@@ -1409,11 +1409,13 @@ class ProcessCollection:
         return max(self.read_port_accesses().values())
 
     def read_port_accesses(self) -> Dict[int, int]:
-        reads = []
-        for process in self._collection:
-            reads.extend(
-                set(read_time % self.schedule_time for read_time in process.read_times)
-            )
+        reads = sum(
+            (
+                [read_time % self.schedule_time for read_time in process.read_times]
+                for process in self._collection
+            ),
+            [],
+        )
         return dict(sorted(Counter(reads).items()))
 
     def write_ports_bound(self) -> int:
@@ -1444,13 +1446,14 @@ class ProcessCollection:
         return max(self.total_port_accesses().values())
 
     def total_port_accesses(self) -> Dict[int, int]:
-        accesses = [
-            process.start_time % self.schedule_time for process in self._collection
-        ]
-        for process in self._collection:
-            accesses.extend(
-                set(read_time % self.schedule_time for read_time in process.read_times)
-            )
+        accesses = sum(
+            (
+                list(read_time % self.schedule_time for read_time in process.read_times)
+                for process in self._collection
+            ),
+            [process.start_time % self.schedule_time for process in self._collection],
+        )
+
         return dict(sorted(Counter(accesses).items()))
 
     def from_name(self, name: str):
diff --git a/b_asic/schedule.py b/b_asic/schedule.py
index c80466a2..dc9976cb 100644
--- a/b_asic/schedule.py
+++ b/b_asic/schedule.py
@@ -72,7 +72,7 @@ class Schedule:
         algorithm.
     cyclic : bool, default: False
         If the schedule is cyclic.
-    algorithm : {'ASAP', 'ALAP', 'provided'}, optional
+    algorithm : {'ASAP', 'ALAP', 'provided'}, default: 'ASAP'
         The scheduling algorithm to use. The following algorithm are available:
            * ``'ASAP'``: As-soon-as-possible scheduling.
            * ``'ALAP'``: As-late-as-possible scheduling.
@@ -84,10 +84,10 @@ class Schedule:
         Dictionary with GraphIDs as keys and laps as values.
         Used when *algorithm* is 'provided'.
     max_resources : dict, optional
-        Dictionary like ``{'cmul': 2}`` denoting the maximum number of resources
-        for a given operation type if the scheduling algorithm considers that.
-        If not provided, or an operation type is not provided, at most one resource is
-        used.
+        Dictionary like ``{Addition.type_name(): 2}`` denoting the maximum number of
+        resources for a given operation type if the scheduling algorithm considers
+        that. If not provided, or an operation type is not provided, at most one
+        resource is used.
     """
 
     _sfg: SFG
@@ -181,13 +181,16 @@ class Schedule:
         """
         if graph_id not in self._start_times:
             raise ValueError(f"No operation with graph_id {graph_id} in schedule")
-        slack = sys.maxsize
         output_slacks = self._forward_slacks(graph_id)
-        # Make more pythonic
-        for signal_slacks in output_slacks.values():
-            for signal_slack in signal_slacks.values():
-                slack = min(slack, signal_slack)
-        return slack
+        return min(
+            sum(
+                (
+                    list(signal_slacks.values())
+                    for signal_slacks in output_slacks.values()
+                ),
+                [sys.maxsize],
+            )
+        )
 
     def _forward_slacks(
         self, graph_id: GraphID
@@ -241,13 +244,16 @@ class Schedule:
         """
         if graph_id not in self._start_times:
             raise ValueError(f"No operation with graph_id {graph_id} in schedule")
-        slack = sys.maxsize
         input_slacks = self._backward_slacks(graph_id)
-        # Make more pythonic
-        for signal_slacks in input_slacks.values():
-            for signal_slack in signal_slacks.values():
-                slack = min(slack, signal_slack)
-        return slack
+        return min(
+            sum(
+                (
+                    list(signal_slacks.values())
+                    for signal_slacks in input_slacks.values()
+                ),
+                [sys.maxsize],
+            )
+        )
 
     def _backward_slacks(self, graph_id: GraphID) -> Dict[InputPort, Dict[Signal, int]]:
         ret = {}
diff --git a/examples/fivepointwinograddft.py b/examples/fivepointwinograddft.py
index 798d3bc3..dea0cd55 100644
--- a/examples/fivepointwinograddft.py
+++ b/examples/fivepointwinograddft.py
@@ -130,6 +130,7 @@ schedule.move_operation('bfly3', -2)
 schedule.move_operation('bfly4', -1)
 schedule.show()
 
+# %%
 # Extract memory variables and operation executions
 operations = schedule.get_operations()
 adders = operations.get_by_type_name(AddSub.type_name())
@@ -168,6 +169,8 @@ fig, ax = plt.subplots()
 fig.suptitle('Exclusion graph based on ports')
 nx.draw(mem_vars.create_exclusion_graph_from_ports(1, 1, 2), ax=ax)
 
+# %%
+# Create architecture
 arch = Architecture(
     [addsub, butterfly, multiplier, pe_in, pe_out],
     memories,
diff --git a/test/test_schedule.py b/test/test_schedule.py
index 14ba916a..cfa66db0 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -70,9 +70,9 @@ class TestInit:
             print(op.latency_offsets)
 
         start_times_names = {}
-        for op_id, start_time in schedule._start_times.items():
+        for op_id in schedule.start_times:
             op_name = precedence_sfg_delays.find_by_id(op_id).name
-            start_times_names[op_name] = start_time
+            start_times_names[op_name] = schedule.start_time_of_operation(op_id)
 
         assert start_times_names == {
             "IN1": 4,
-- 
GitLab