From ee87d701553abe3fa965df0bebe190376c992575 Mon Sep 17 00:00:00 2001
From: Simon Bjurek <simbj106@student.liu.se>
Date: Mon, 24 Mar 2025 10:12:23 +0100
Subject: [PATCH] added retiming step for RecursiveListScheduler

---
 b_asic/scheduler.py               | 76 +++++++++++++++++++++++++++++--
 test/unit/test_list_schedulers.py | 15 ++++++
 2 files changed, 86 insertions(+), 5 deletions(-)

diff --git a/b_asic/scheduler.py b/b_asic/scheduler.py
index b81879df..306cdccd 100644
--- a/b_asic/scheduler.py
+++ b/b_asic/scheduler.py
@@ -81,7 +81,7 @@ class Scheduler(ABC):
     def _place_outputs_asap(
         self, schedule: "Schedule", non_schedulable_ops: list["GraphID"] | None = []
     ) -> None:
-        for output in schedule.sfg.find_by_type(Output):
+        for output in schedule._sfg.find_by_type(Output):
             output = cast(Output, output)
             source_port = cast(OutputPort, output.inputs[0].signals[0].source)
             if source_port.operation.graph_id in non_schedulable_ops:
@@ -811,16 +811,12 @@ class RecursiveListScheduler(ListScheduler):
         self,
         sort_order: tuple[tuple[int, bool], ...],
         max_resources: dict[TypeName, int] | None = None,
-        max_concurrent_reads: int | None = None,
-        max_concurrent_writes: int | None = None,
         input_times: dict["GraphID", int] | None = None,
         output_delta_times: dict["GraphID", int] | None = None,
     ) -> None:
         super().__init__(
             sort_order=sort_order,
             max_resources=max_resources,
-            max_concurrent_reads=max_concurrent_reads,
-            max_concurrent_writes=max_concurrent_writes,
             input_times=input_times,
             output_delta_times=output_delta_times,
         )
@@ -846,7 +842,10 @@ class RecursiveListScheduler(ListScheduler):
 
         if self._schedule._schedule_time is None:
             self._schedule.set_schedule_time(self._schedule.get_max_end_time())
+        period_bound = self._schedule._sfg.iteration_period_bound()
         self._schedule.remove_delays()
+        if loops:
+            self._retime_ops(period_bound)
         self._handle_dont_cares()
         self._schedule.sort_y_locations_on_start_times()
         self._logger.debug("--- Scheduling completed ---")
@@ -884,6 +883,73 @@ class RecursiveListScheduler(ListScheduler):
         ]
         return [(op_id, self._deadlines[op_id]) for op_id in ready_ops]
 
+    def _retime_ops(self, period_bound: int) -> None:
+        # calculate the time goal
+        time_goal = period_bound
+        for type_name, amount_of_resource in self._max_resources.items():
+            if type_name in ("out", "in"):
+                continue
+            time_required = self._schedule._sfg.resource_lower_bound(
+                type_name, time_goal
+            )
+            while (
+                self._schedule._sfg.resource_lower_bound(type_name, time_required)
+                > amount_of_resource
+            ):
+                time_required += 1
+            time_goal = max(time_goal, time_required)
+
+        # retiming loop
+        time_out_counter = 100
+        while self._schedule._schedule_time > time_goal and time_out_counter > 0:
+            sorted_op_ids = sorted(
+                self._schedule._start_times,
+                key=self._schedule._start_times.get,
+                reverse=True,
+            )
+            # move all operations forward to the next valid step and check if period can be reduced
+            for op_id in sorted_op_ids:
+                op = self._schedule._sfg.find_by_id(op_id)
+                if self._schedule.forward_slack(op_id):
+                    delta = 1
+                    new_time = (
+                        self._schedule._start_times[op_id] + delta
+                    ) % self._schedule.schedule_time
+                    exec_count = self._execution_times_in_time(op, new_time)
+                    while exec_count >= self._remaining_resources[op.type_name()]:
+                        delta += 1
+                        new_time = (
+                            self._schedule._start_times[op_id] + delta
+                        ) % self._schedule.schedule_time
+                        exec_count = self._execution_times_in_time(op, new_time)
+                    if delta > self._schedule.forward_slack(op_id):
+                        continue
+                    self._schedule.move_operation(op_id, delta)
+
+                # adjust time if a gap exists on the right side of the schedule
+                self._schedule._schedule_time = min(
+                    self._schedule._schedule_time, self._schedule.get_max_end_time()
+                )
+
+                # adjust time if a gap exists on the left side of the schedule
+                slack = min(self._schedule._start_times.values())
+                for other_op_id in sorted_op_ids:
+                    op = self._schedule._sfg.find_by_id(op_id)
+                    max_end_time = 0
+                    op_start_time = self._schedule._start_times[other_op_id]
+                    for outport in op.outputs:
+                        max_end_time = max(
+                            max_end_time,
+                            op_start_time + cast(int, outport.latency_offset),
+                        )
+                    if max_end_time > self._schedule._schedule_time:
+                        slack = min(slack, self._schedule.forward_slack(other_op_id))
+                for op_id in self._schedule._start_times:
+                    self._schedule._start_times[op_id] -= slack
+                self._schedule._schedule_time = self._schedule._schedule_time - slack
+
+            time_out_counter -= 1
+
     def _schedule_recursive_ops(self, loops: list[list["GraphID"]]) -> None:
         saved_sched_time = self._schedule._schedule_time
         self._schedule._schedule_time = None
diff --git a/test/unit/test_list_schedulers.py b/test/unit/test_list_schedulers.py
index edff8a5d..9dbe448c 100644
--- a/test/unit/test_list_schedulers.py
+++ b/test/unit/test_list_schedulers.py
@@ -1853,6 +1853,10 @@ class TestRecursiveListScheduler:
             ),
         )
         _validate_recreated_sfg_filter(sfg, schedule)
+        assert schedule.schedule_time == sfg.iteration_period_bound()
+        for op_id in schedule.start_times:
+            assert schedule.backward_slack(op_id) >= 0
+            assert schedule.forward_slack(op_id) >= 0
 
     def test_direct_form_2_iir(self):
         N = 3
@@ -1878,6 +1882,10 @@ class TestRecursiveListScheduler:
             ),
         )
         _validate_recreated_sfg_filter(sfg, schedule)
+        assert schedule.schedule_time == sfg.iteration_period_bound()
+        for op_id in schedule.start_times:
+            assert schedule.backward_slack(op_id) >= 0
+            assert schedule.forward_slack(op_id) >= 0
 
     def test_large_direct_form_2_iir(self):
         N = 8
@@ -1903,6 +1911,9 @@ class TestRecursiveListScheduler:
             ),
         )
         _validate_recreated_sfg_filter(sfg, schedule)
+        for op_id in schedule.start_times:
+            assert schedule.backward_slack(op_id) >= 0
+            assert schedule.forward_slack(op_id) >= 0
 
     def test_custom_recursive_filter(self):
         # Create the SFG for a digital filter (seen in an exam question from TSTE87).
@@ -1939,6 +1950,10 @@ class TestRecursiveListScheduler:
             ),
         )
         _validate_recreated_sfg_filter(sfg, schedule)
+        assert schedule.schedule_time == 4  # all slots filled with cmul executions
+        for op_id in schedule.start_times:
+            assert schedule.backward_slack(op_id) >= 0
+            assert schedule.forward_slack(op_id) >= 0
 
 
 def _validate_recreated_sfg_filter(sfg: SFG, schedule: Schedule) -> None:
-- 
GitLab