Skip to content
Snippets Groups Projects
Commit a033e3bc authored by Simon Bjurek's avatar Simon Bjurek Committed by Oscar Gustafsson
Browse files

Added retiming step for RecursiveListScheduler

parent 8efb6349
No related branches found
No related tags found
1 merge request!502Added retiming step for RecursiveListScheduler
Pipeline #159507 passed
...@@ -81,7 +81,7 @@ class Scheduler(ABC): ...@@ -81,7 +81,7 @@ class Scheduler(ABC):
def _place_outputs_asap( def _place_outputs_asap(
self, schedule: "Schedule", non_schedulable_ops: list["GraphID"] | None = [] self, schedule: "Schedule", non_schedulable_ops: list["GraphID"] | None = []
) -> None: ) -> None:
for output in schedule.sfg.find_by_type(Output): for output in schedule._sfg.find_by_type(Output):
output = cast(Output, output) output = cast(Output, output)
source_port = cast(OutputPort, output.inputs[0].signals[0].source) source_port = cast(OutputPort, output.inputs[0].signals[0].source)
if source_port.operation.graph_id in non_schedulable_ops: if source_port.operation.graph_id in non_schedulable_ops:
...@@ -811,16 +811,12 @@ class RecursiveListScheduler(ListScheduler): ...@@ -811,16 +811,12 @@ class RecursiveListScheduler(ListScheduler):
self, self,
sort_order: tuple[tuple[int, bool], ...], sort_order: tuple[tuple[int, bool], ...],
max_resources: dict[TypeName, int] | None = None, max_resources: dict[TypeName, int] | None = None,
max_concurrent_reads: int | None = None,
max_concurrent_writes: int | None = None,
input_times: dict["GraphID", int] | None = None, input_times: dict["GraphID", int] | None = None,
output_delta_times: dict["GraphID", int] | None = None, output_delta_times: dict["GraphID", int] | None = None,
) -> None: ) -> None:
super().__init__( super().__init__(
sort_order=sort_order, sort_order=sort_order,
max_resources=max_resources, max_resources=max_resources,
max_concurrent_reads=max_concurrent_reads,
max_concurrent_writes=max_concurrent_writes,
input_times=input_times, input_times=input_times,
output_delta_times=output_delta_times, output_delta_times=output_delta_times,
) )
...@@ -846,7 +842,10 @@ class RecursiveListScheduler(ListScheduler): ...@@ -846,7 +842,10 @@ class RecursiveListScheduler(ListScheduler):
if self._schedule._schedule_time is None: if self._schedule._schedule_time is None:
self._schedule.set_schedule_time(self._schedule.get_max_end_time()) self._schedule.set_schedule_time(self._schedule.get_max_end_time())
period_bound = self._schedule._sfg.iteration_period_bound()
self._schedule.remove_delays() self._schedule.remove_delays()
if loops:
self._retime_ops(period_bound)
self._handle_dont_cares() self._handle_dont_cares()
self._schedule.sort_y_locations_on_start_times() self._schedule.sort_y_locations_on_start_times()
self._logger.debug("--- Scheduling completed ---") self._logger.debug("--- Scheduling completed ---")
...@@ -884,6 +883,73 @@ class RecursiveListScheduler(ListScheduler): ...@@ -884,6 +883,73 @@ class RecursiveListScheduler(ListScheduler):
] ]
return [(op_id, self._deadlines[op_id]) for op_id in ready_ops] return [(op_id, self._deadlines[op_id]) for op_id in ready_ops]
def _retime_ops(self, period_bound: int) -> None:
# calculate the time goal
time_goal = period_bound
for type_name, amount_of_resource in self._max_resources.items():
if type_name in ("out", "in"):
continue
time_required = self._schedule._sfg.resource_lower_bound(
type_name, time_goal
)
while (
self._schedule._sfg.resource_lower_bound(type_name, time_required)
> amount_of_resource
):
time_required += 1
time_goal = max(time_goal, time_required)
# retiming loop
time_out_counter = 100
while self._schedule._schedule_time > time_goal and time_out_counter > 0:
sorted_op_ids = sorted(
self._schedule._start_times,
key=self._schedule._start_times.get,
reverse=True,
)
# move all operations forward to the next valid step and check if period can be reduced
for op_id in sorted_op_ids:
op = self._schedule._sfg.find_by_id(op_id)
if self._schedule.forward_slack(op_id):
delta = 1
new_time = (
self._schedule._start_times[op_id] + delta
) % self._schedule.schedule_time
exec_count = self._execution_times_in_time(op, new_time)
while exec_count >= self._remaining_resources[op.type_name()]:
delta += 1
new_time = (
self._schedule._start_times[op_id] + delta
) % self._schedule.schedule_time
exec_count = self._execution_times_in_time(op, new_time)
if delta > self._schedule.forward_slack(op_id):
continue
self._schedule.move_operation(op_id, delta)
# adjust time if a gap exists on the right side of the schedule
self._schedule._schedule_time = min(
self._schedule._schedule_time, self._schedule.get_max_end_time()
)
# adjust time if a gap exists on the left side of the schedule
slack = min(self._schedule._start_times.values())
for other_op_id in sorted_op_ids:
op = self._schedule._sfg.find_by_id(op_id)
max_end_time = 0
op_start_time = self._schedule._start_times[other_op_id]
for outport in op.outputs:
max_end_time = max(
max_end_time,
op_start_time + cast(int, outport.latency_offset),
)
if max_end_time > self._schedule._schedule_time:
slack = min(slack, self._schedule.forward_slack(other_op_id))
for op_id in self._schedule._start_times:
self._schedule._start_times[op_id] -= slack
self._schedule._schedule_time = self._schedule._schedule_time - slack
time_out_counter -= 1
def _schedule_recursive_ops(self, loops: list[list["GraphID"]]) -> None: def _schedule_recursive_ops(self, loops: list[list["GraphID"]]) -> None:
saved_sched_time = self._schedule._schedule_time saved_sched_time = self._schedule._schedule_time
self._schedule._schedule_time = None self._schedule._schedule_time = None
......
...@@ -1853,6 +1853,10 @@ class TestRecursiveListScheduler: ...@@ -1853,6 +1853,10 @@ class TestRecursiveListScheduler:
), ),
) )
_validate_recreated_sfg_filter(sfg, schedule) _validate_recreated_sfg_filter(sfg, schedule)
assert schedule.schedule_time == sfg.iteration_period_bound()
for op_id in schedule.start_times:
assert schedule.backward_slack(op_id) >= 0
assert schedule.forward_slack(op_id) >= 0
def test_direct_form_2_iir(self): def test_direct_form_2_iir(self):
N = 3 N = 3
...@@ -1878,6 +1882,10 @@ class TestRecursiveListScheduler: ...@@ -1878,6 +1882,10 @@ class TestRecursiveListScheduler:
), ),
) )
_validate_recreated_sfg_filter(sfg, schedule) _validate_recreated_sfg_filter(sfg, schedule)
assert schedule.schedule_time == sfg.iteration_period_bound()
for op_id in schedule.start_times:
assert schedule.backward_slack(op_id) >= 0
assert schedule.forward_slack(op_id) >= 0
def test_large_direct_form_2_iir(self): def test_large_direct_form_2_iir(self):
N = 8 N = 8
...@@ -1903,6 +1911,9 @@ class TestRecursiveListScheduler: ...@@ -1903,6 +1911,9 @@ class TestRecursiveListScheduler:
), ),
) )
_validate_recreated_sfg_filter(sfg, schedule) _validate_recreated_sfg_filter(sfg, schedule)
for op_id in schedule.start_times:
assert schedule.backward_slack(op_id) >= 0
assert schedule.forward_slack(op_id) >= 0
def test_custom_recursive_filter(self): def test_custom_recursive_filter(self):
# Create the SFG for a digital filter (seen in an exam question from TSTE87). # Create the SFG for a digital filter (seen in an exam question from TSTE87).
...@@ -1939,6 +1950,10 @@ class TestRecursiveListScheduler: ...@@ -1939,6 +1950,10 @@ class TestRecursiveListScheduler:
), ),
) )
_validate_recreated_sfg_filter(sfg, schedule) _validate_recreated_sfg_filter(sfg, schedule)
assert schedule.schedule_time == 4 # all slots filled with cmul executions
for op_id in schedule.start_times:
assert schedule.backward_slack(op_id) >= 0
assert schedule.forward_slack(op_id) >= 0
def _validate_recreated_sfg_filter(sfg: SFG, schedule: Schedule) -> None: def _validate_recreated_sfg_filter(sfg: SFG, schedule: Schedule) -> None:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment