From fb1f6a0cd6ab68b4d78c28d8ceb6122ff5cbd93e Mon Sep 17 00:00:00 2001
From: Oscar Gustafsson <oscar.gustafsson@gmail.com>
Date: Mon, 14 Apr 2025 13:39:11 +0200
Subject: [PATCH] Improve scheduling performance

---
 b_asic/scheduler.py | 102 +++++++++++++++++++++++---------------------
 1 file changed, 53 insertions(+), 49 deletions(-)

diff --git a/b_asic/scheduler.py b/b_asic/scheduler.py
index 13193a4a..9afa14d8 100644
--- a/b_asic/scheduler.py
+++ b/b_asic/scheduler.py
@@ -397,6 +397,7 @@ class ListScheduler(Scheduler):
             self._remaining_ops = [
                 op_id for op_id in self._remaining_ops if op_id not in self._input_times
             ]
+            self._remaining_ops_set = set(self._remaining_ops)
 
         self._schedule_nonrecursive_ops()
 
@@ -423,11 +424,18 @@ class ListScheduler(Scheduler):
         sorted_table = sorted(priority_table, key=sort_key)
         return sorted_table[0][0]
 
-    def _get_priority_table(self) -> list[tuple["GraphID", int, int, int]]:
+    def _get_priority_table(
+        self, candidate_ids
+    ) -> list[tuple["GraphID", int, int, int]]:
+        schedule_time = (
+            self._schedule._schedule_time
+            if self._schedule._schedule_time is not None
+            else 0
+        )
         ready_ops = [
             op_id
-            for op_id in self._remaining_ops
-            if self._op_is_schedulable(self._sfg.find_by_id(op_id))
+            for op_id in candidate_ids
+            if self._op_is_schedulable(self._sfg.find_by_id(op_id), schedule_time)
         ]
         memory_reads = self._calculate_memory_reads(ready_ops)
 
@@ -504,15 +512,15 @@ class ListScheduler(Scheduler):
             time_slot = self._current_time % self._schedule._schedule_time
         else:
             time_slot = self._current_time
-        count = self._cached_execution_times_in_time[op.type_name()][time_slot]
-        return count < self._remaining_resources[op.type_name()]
+        op_type_name = op.type_name()
+        count = self._cached_execution_times_in_time[op_type_name][time_slot]
+        return count < self._remaining_resources[op_type_name]
 
     def _op_satisfies_concurrent_writes(self, op: "Operation") -> bool:
         if self._max_concurrent_writes:
             tmp_used_writes = {}
             if not isinstance(op, Output):
                 for output_port in op.outputs:
-
                     output_ready_time = self._current_time + output_port.latency_offset
                     if self._schedule._schedule_time:
                         output_ready_time %= self._schedule._schedule_time
@@ -568,7 +576,9 @@ class ListScheduler(Scheduler):
                         return False
         return True
 
-    def _op_satisfies_data_dependencies(self, op: "Operation") -> bool:
+    def _op_satisfies_data_dependencies(
+        self, op: "Operation", schedule_time: int
+    ) -> bool:
         for op_input in op.inputs:
             source_port = op_input.signals[0].source
             source_op = source_port.operation
@@ -576,30 +586,24 @@ class ListScheduler(Scheduler):
             if isinstance(source_op, (Delay, DontCare)):
                 continue
 
-            if source_op.graph_id in self._remaining_ops:
+            if source_op.graph_id in self._remaining_ops_set:
                 return False
 
-            if self._schedule._schedule_time is not None:
-                available_time = (
-                    self._schedule.start_times[source_op.graph_id]
-                    + self._op_laps[source_op.graph_id] * self._schedule._schedule_time
-                    + source_port.latency_offset
-                )
-            else:
-                available_time = (
-                    self._schedule.start_times[source_op.graph_id]
-                    + source_port.latency_offset
-                )
+            available_time = (
+                self._schedule.start_times[source_op.graph_id]
+                + self._op_laps[source_op.graph_id] * schedule_time
+                + source_port.latency_offset
+            )
 
             required_time = self._current_time + op_input.latency_offset
             if available_time > required_time:
                 return False
         return True
 
-    def _op_is_schedulable(self, op: "Operation") -> bool:
+    def _op_is_schedulable(self, op: "Operation", schedule_time: int) -> bool:
         return (
-            self._op_satisfies_data_dependencies(op)
-            and self._op_satisfies_resource_constraints(op)
+            self._op_satisfies_resource_constraints(op)
+            and self._op_satisfies_data_dependencies(op, schedule_time)
             and self._op_satisfies_concurrent_writes(op)
             and self._op_satisfies_concurrent_reads(op)
         )
@@ -692,12 +696,7 @@ class ListScheduler(Scheduler):
         self._remaining_ops = [
             op_id
             for op_id in self._remaining_ops
-            if not isinstance(self._sfg.find_by_id(op_id), DontCare)
-        ]
-        self._remaining_ops = [
-            op_id
-            for op_id in self._remaining_ops
-            if not isinstance(self._sfg.find_by_id(op_id), Delay)
+            if not isinstance(self._sfg.find_by_id(op_id), (Delay, DontCare))
         ]
         self._remaining_ops = [
             op_id
@@ -715,6 +714,7 @@ class ListScheduler(Scheduler):
                     f"Missing operation: {op_id}."
                 )
 
+        self._remaining_ops_set = set(self._remaining_ops)
         self._deadlines = self._calculate_deadlines()
         self._output_slacks = self._calculate_alap_output_slacks()
         self._fan_outs = self._calculate_fan_outs()
@@ -727,20 +727,22 @@ class ListScheduler(Scheduler):
     def _schedule_nonrecursive_ops(self) -> None:
         self._logger.debug("--- Non-Recursive Operation scheduling starting ---")
         while self._remaining_ops:
-            prio_table = self._get_priority_table()
+            prio_table = self._get_priority_table(self._remaining_ops)
             while prio_table:
-                next_op = self._sfg.find_by_id(self._get_next_op_id(prio_table))
+                next_op_id = self._get_next_op_id(prio_table)
+                next_op = self._sfg.find_by_id(next_op_id)
 
                 self._update_port_reads(next_op)
 
                 self._remaining_ops = [
-                    op_id for op_id in self._remaining_ops if op_id != next_op.graph_id
+                    op_id for op_id in self._remaining_ops if op_id != next_op_id
                 ]
+                self._remaining_ops_set.remove(next_op_id)
 
                 self._schedule.place_operation(
                     next_op, self._current_time, self._op_laps
                 )
-                self._op_laps[next_op.graph_id] = (
+                self._op_laps[next_op_id] = (
                     (self._current_time) // self._schedule._schedule_time
                     if self._schedule._schedule_time
                     else 0
@@ -758,7 +760,10 @@ class ListScheduler(Scheduler):
 
                 self._log_scheduled_op(next_op)
 
-                prio_table = self._get_priority_table()
+                prio_table = self._get_priority_table(
+                    self._remaining_ops
+                    # [r[0] for r in prio_table if r[0] != next_op_id]
+                )
 
             self._current_time += 1
         self._current_time -= 1
@@ -821,6 +826,7 @@ class RecursiveListScheduler(ListScheduler):
             self._remaining_ops = [
                 op_id for op_id in self._remaining_ops if op_id not in self._input_times
             ]
+            self._remaining_ops_set = set(self._remaining_ops)
 
         loops = self._sfg.loops
         if loops:
@@ -948,6 +954,7 @@ class RecursiveListScheduler(ListScheduler):
 
         self._logger.debug("--- Scheduling of recursive loops starting ---")
         self._recursive_ops = self._get_recursive_ops(loops)
+        self._recursive_ops_set = set(self._recursive_ops)
         self._remaining_recursive_ops = self._recursive_ops.copy()
         prio_table = self._get_recursive_priority_table()
         while prio_table:
@@ -975,6 +982,7 @@ class RecursiveListScheduler(ListScheduler):
             self._logger.debug(f"   Op: {op.graph_id} time: {op_sched_time}")
             self._remaining_recursive_ops.remove(op.graph_id)
             self._remaining_ops.remove(op.graph_id)
+            self._remaining_ops_set.remove(op.graph_id)
 
             for i in range(max(1, op.execution_time)):
                 time_slot = (
@@ -1006,16 +1014,18 @@ class RecursiveListScheduler(ListScheduler):
                 source_op = signal.source.operation
                 if (
                     not isinstance(source_op, Delay)
-                    and source_op.graph_id not in self._recursive_ops
+                    and source_op.graph_id not in self._recursive_ops_set
                 ):
                     # non-recursive to recursive edge found -> pipeline
                     self._schedule.laps[signal.graph_id] += 1
 
-    def _op_satisfies_data_dependencies(self, op: "Operation") -> bool:
+    def _op_satisfies_data_dependencies(
+        self, op: "Operation", schedule_time: int
+    ) -> bool:
         for output_port in op.outputs:
             destination_port = output_port.signals[0].destination
             destination_op = destination_port.operation
-            if destination_op.graph_id not in self._remaining_ops:
+            if destination_op.graph_id not in self._remaining_ops_set:
                 if isinstance(destination_op, Delay):
                     continue
                 # spotted a recursive operation -> check if ok
@@ -1024,7 +1034,7 @@ class RecursiveListScheduler(ListScheduler):
                 )
                 usage_time = (
                     self._schedule.start_times[destination_op.graph_id]
-                    + self._schedule._schedule_time
+                    + schedule_time
                     * self._schedule.laps[output_port.signals[0].graph_id]
                 )
                 if op_available_time > usage_time:
@@ -1037,19 +1047,13 @@ class RecursiveListScheduler(ListScheduler):
             source_op = source_port.operation
             if isinstance(source_op, (Delay, DontCare)):
                 continue
-            if source_op.graph_id in self._remaining_ops:
+            if source_op.graph_id in self._remaining_ops_set:
                 return False
-            if self._schedule._schedule_time is not None:
-                available_time = (
-                    self._schedule.start_times.get(source_op.graph_id)
-                    + self._op_laps[source_op.graph_id] * self._schedule._schedule_time
-                    + source_port.latency_offset
-                )
-            else:
-                available_time = (
-                    self._schedule.start_times.get(source_op.graph_id)
-                    + source_port.latency_offset
-                )
+            available_time = (
+                self._schedule.start_times.get(source_op.graph_id)
+                + self._op_laps[source_op.graph_id] * schedule_time
+                + source_port.latency_offset
+            )
             required_time = self._current_time + op_input.latency_offset
             if available_time > required_time:
                 return False
-- 
GitLab