From 2b07f32fa2d219b2993b10c0e1daf5295f5e9b56 Mon Sep 17 00:00:00 2001
From: Simon Bjurek <simbj106@student.liu.se>
Date: Fri, 11 Apr 2025 15:44:04 +0200
Subject: [PATCH] Avoid unneccecary sort_y_locations in Schedulers

---
 b_asic/scheduler.py | 23 +++++++++++++++++------
 1 file changed, 17 insertions(+), 6 deletions(-)

diff --git a/b_asic/scheduler.py b/b_asic/scheduler.py
index 300ff625..65f86e25 100644
--- a/b_asic/scheduler.py
+++ b/b_asic/scheduler.py
@@ -19,6 +19,7 @@ class Scheduler(ABC):
         self,
         input_times: dict["GraphID", int] | None = None,
         output_delta_times: dict["GraphID", int] | None = None,
+        sort_y_direction: bool = True,
     ):
         self._logger = logger.getLogger("scheduler")
         self._op_laps = {}
@@ -57,6 +58,8 @@ class Scheduler(ABC):
         else:
             self._output_delta_times = {}
 
+        self._sort_y_direction = sort_y_direction
+
     @abstractmethod
     def apply_scheduling(self, schedule: "Schedule") -> None:
         """Applies the scheduling algorithm on the given Schedule.
@@ -246,7 +249,8 @@ class ASAPScheduler(Scheduler):
         elif schedule._schedule_time < max_end_time:
             raise ValueError(f"Too short schedule time. Minimum is {max_end_time}.")
 
-        schedule.sort_y_locations_on_start_times()
+        if self._sort_y_direction:
+            schedule.sort_y_locations_on_start_times()
 
 
 class ALAPScheduler(Scheduler):
@@ -265,6 +269,7 @@ class ALAPScheduler(Scheduler):
         ASAPScheduler(
             self._input_times,
             self._output_delta_times,
+            False,
         ).apply_scheduling(schedule)
         self._op_laps = {}
 
@@ -302,7 +307,8 @@ class ALAPScheduler(Scheduler):
             schedule.move_operation(op_id, -slack)
         schedule.set_schedule_time(schedule._schedule_time - slack)
 
-        schedule.sort_y_locations_on_start_times()
+        if self._sort_y_direction:
+            schedule.sort_y_locations_on_start_times()
 
 
 class ListScheduler(Scheduler):
@@ -335,8 +341,9 @@ class ListScheduler(Scheduler):
         max_concurrent_writes: int | None = None,
         input_times: dict["GraphID", int] | None = None,
         output_delta_times: dict["GraphID", int] | None = None,
+        sort_y_locations: bool = True,
     ) -> None:
-        super().__init__(input_times, output_delta_times)
+        super().__init__(input_times, output_delta_times, sort_y_locations)
         self._sort_order = sort_order
 
         if max_resources is not None:
@@ -399,7 +406,8 @@ class ListScheduler(Scheduler):
             self._schedule.set_schedule_time(self._schedule.get_max_end_time())
         self._schedule.remove_delays()
         self._handle_dont_cares()
-        self._schedule.sort_y_locations_on_start_times()
+        if self._sort_y_direction:
+            schedule.sort_y_locations_on_start_times()
         self._logger.debug("--- Scheduling completed ---")
 
     def _get_next_op_id(
@@ -675,7 +683,9 @@ class ListScheduler(Scheduler):
 
         alap_schedule = copy.copy(self._schedule)
         alap_schedule._schedule_time = None
-        alap_scheduler = ALAPScheduler(self._input_times, self._output_delta_times)
+        alap_scheduler = ALAPScheduler(
+            self._input_times, self._output_delta_times, False
+        )
         alap_scheduler.apply_scheduling(alap_schedule)
         self._alap_start_times = alap_schedule.start_times
         self._alap_op_laps = alap_scheduler._op_laps
@@ -847,7 +857,8 @@ class RecursiveListScheduler(ListScheduler):
         if loops:
             self._retime_ops(period_bound)
         self._handle_dont_cares()
-        self._schedule.sort_y_locations_on_start_times()
+        if self._sort_y_direction:
+            schedule.sort_y_locations_on_start_times()
         self._logger.debug("--- Scheduling completed ---")
 
     def _get_recursive_ops(self, loops: list[list["GraphID"]]) -> list["GraphID"]:
-- 
GitLab