From fb1f6a0cd6ab68b4d78c28d8ceb6122ff5cbd93e Mon Sep 17 00:00:00 2001 From: Oscar Gustafsson <oscar.gustafsson@gmail.com> Date: Mon, 14 Apr 2025 13:39:11 +0200 Subject: [PATCH] Improve scheduling performance --- b_asic/scheduler.py | 102 +++++++++++++++++++++++--------------------- 1 file changed, 53 insertions(+), 49 deletions(-) diff --git a/b_asic/scheduler.py b/b_asic/scheduler.py index 13193a4a..9afa14d8 100644 --- a/b_asic/scheduler.py +++ b/b_asic/scheduler.py @@ -397,6 +397,7 @@ class ListScheduler(Scheduler): self._remaining_ops = [ op_id for op_id in self._remaining_ops if op_id not in self._input_times ] + self._remaining_ops_set = set(self._remaining_ops) self._schedule_nonrecursive_ops() @@ -423,11 +424,18 @@ class ListScheduler(Scheduler): sorted_table = sorted(priority_table, key=sort_key) return sorted_table[0][0] - def _get_priority_table(self) -> list[tuple["GraphID", int, int, int]]: + def _get_priority_table( + self, candidate_ids + ) -> list[tuple["GraphID", int, int, int]]: + schedule_time = ( + self._schedule._schedule_time + if self._schedule._schedule_time is not None + else 0 + ) ready_ops = [ op_id - for op_id in self._remaining_ops - if self._op_is_schedulable(self._sfg.find_by_id(op_id)) + for op_id in candidate_ids + if self._op_is_schedulable(self._sfg.find_by_id(op_id), schedule_time) ] memory_reads = self._calculate_memory_reads(ready_ops) @@ -504,15 +512,15 @@ class ListScheduler(Scheduler): time_slot = self._current_time % self._schedule._schedule_time else: time_slot = self._current_time - count = self._cached_execution_times_in_time[op.type_name()][time_slot] - return count < self._remaining_resources[op.type_name()] + op_type_name = op.type_name() + count = self._cached_execution_times_in_time[op_type_name][time_slot] + return count < self._remaining_resources[op_type_name] def _op_satisfies_concurrent_writes(self, op: "Operation") -> bool: if self._max_concurrent_writes: tmp_used_writes = {} if not isinstance(op, Output): for output_port in op.outputs: - output_ready_time = self._current_time + output_port.latency_offset if self._schedule._schedule_time: output_ready_time %= self._schedule._schedule_time @@ -568,7 +576,9 @@ class ListScheduler(Scheduler): return False return True - def _op_satisfies_data_dependencies(self, op: "Operation") -> bool: + def _op_satisfies_data_dependencies( + self, op: "Operation", schedule_time: int + ) -> bool: for op_input in op.inputs: source_port = op_input.signals[0].source source_op = source_port.operation @@ -576,30 +586,24 @@ class ListScheduler(Scheduler): if isinstance(source_op, (Delay, DontCare)): continue - if source_op.graph_id in self._remaining_ops: + if source_op.graph_id in self._remaining_ops_set: return False - if self._schedule._schedule_time is not None: - available_time = ( - self._schedule.start_times[source_op.graph_id] - + self._op_laps[source_op.graph_id] * self._schedule._schedule_time - + source_port.latency_offset - ) - else: - available_time = ( - self._schedule.start_times[source_op.graph_id] - + source_port.latency_offset - ) + available_time = ( + self._schedule.start_times[source_op.graph_id] + + self._op_laps[source_op.graph_id] * schedule_time + + source_port.latency_offset + ) required_time = self._current_time + op_input.latency_offset if available_time > required_time: return False return True - def _op_is_schedulable(self, op: "Operation") -> bool: + def _op_is_schedulable(self, op: "Operation", schedule_time: int) -> bool: return ( - self._op_satisfies_data_dependencies(op) - and self._op_satisfies_resource_constraints(op) + self._op_satisfies_resource_constraints(op) + and self._op_satisfies_data_dependencies(op, schedule_time) and self._op_satisfies_concurrent_writes(op) and self._op_satisfies_concurrent_reads(op) ) @@ -692,12 +696,7 @@ class ListScheduler(Scheduler): self._remaining_ops = [ op_id for op_id in self._remaining_ops - if not isinstance(self._sfg.find_by_id(op_id), DontCare) - ] - self._remaining_ops = [ - op_id - for op_id in self._remaining_ops - if not isinstance(self._sfg.find_by_id(op_id), Delay) + if not isinstance(self._sfg.find_by_id(op_id), (Delay, DontCare)) ] self._remaining_ops = [ op_id @@ -715,6 +714,7 @@ class ListScheduler(Scheduler): f"Missing operation: {op_id}." ) + self._remaining_ops_set = set(self._remaining_ops) self._deadlines = self._calculate_deadlines() self._output_slacks = self._calculate_alap_output_slacks() self._fan_outs = self._calculate_fan_outs() @@ -727,20 +727,22 @@ class ListScheduler(Scheduler): def _schedule_nonrecursive_ops(self) -> None: self._logger.debug("--- Non-Recursive Operation scheduling starting ---") while self._remaining_ops: - prio_table = self._get_priority_table() + prio_table = self._get_priority_table(self._remaining_ops) while prio_table: - next_op = self._sfg.find_by_id(self._get_next_op_id(prio_table)) + next_op_id = self._get_next_op_id(prio_table) + next_op = self._sfg.find_by_id(next_op_id) self._update_port_reads(next_op) self._remaining_ops = [ - op_id for op_id in self._remaining_ops if op_id != next_op.graph_id + op_id for op_id in self._remaining_ops if op_id != next_op_id ] + self._remaining_ops_set.remove(next_op_id) self._schedule.place_operation( next_op, self._current_time, self._op_laps ) - self._op_laps[next_op.graph_id] = ( + self._op_laps[next_op_id] = ( (self._current_time) // self._schedule._schedule_time if self._schedule._schedule_time else 0 @@ -758,7 +760,10 @@ class ListScheduler(Scheduler): self._log_scheduled_op(next_op) - prio_table = self._get_priority_table() + prio_table = self._get_priority_table( + self._remaining_ops + # [r[0] for r in prio_table if r[0] != next_op_id] + ) self._current_time += 1 self._current_time -= 1 @@ -821,6 +826,7 @@ class RecursiveListScheduler(ListScheduler): self._remaining_ops = [ op_id for op_id in self._remaining_ops if op_id not in self._input_times ] + self._remaining_ops_set = set(self._remaining_ops) loops = self._sfg.loops if loops: @@ -948,6 +954,7 @@ class RecursiveListScheduler(ListScheduler): self._logger.debug("--- Scheduling of recursive loops starting ---") self._recursive_ops = self._get_recursive_ops(loops) + self._recursive_ops_set = set(self._recursive_ops) self._remaining_recursive_ops = self._recursive_ops.copy() prio_table = self._get_recursive_priority_table() while prio_table: @@ -975,6 +982,7 @@ class RecursiveListScheduler(ListScheduler): self._logger.debug(f" Op: {op.graph_id} time: {op_sched_time}") self._remaining_recursive_ops.remove(op.graph_id) self._remaining_ops.remove(op.graph_id) + self._remaining_ops_set.remove(op.graph_id) for i in range(max(1, op.execution_time)): time_slot = ( @@ -1006,16 +1014,18 @@ class RecursiveListScheduler(ListScheduler): source_op = signal.source.operation if ( not isinstance(source_op, Delay) - and source_op.graph_id not in self._recursive_ops + and source_op.graph_id not in self._recursive_ops_set ): # non-recursive to recursive edge found -> pipeline self._schedule.laps[signal.graph_id] += 1 - def _op_satisfies_data_dependencies(self, op: "Operation") -> bool: + def _op_satisfies_data_dependencies( + self, op: "Operation", schedule_time: int + ) -> bool: for output_port in op.outputs: destination_port = output_port.signals[0].destination destination_op = destination_port.operation - if destination_op.graph_id not in self._remaining_ops: + if destination_op.graph_id not in self._remaining_ops_set: if isinstance(destination_op, Delay): continue # spotted a recursive operation -> check if ok @@ -1024,7 +1034,7 @@ class RecursiveListScheduler(ListScheduler): ) usage_time = ( self._schedule.start_times[destination_op.graph_id] - + self._schedule._schedule_time + + schedule_time * self._schedule.laps[output_port.signals[0].graph_id] ) if op_available_time > usage_time: @@ -1037,19 +1047,13 @@ class RecursiveListScheduler(ListScheduler): source_op = source_port.operation if isinstance(source_op, (Delay, DontCare)): continue - if source_op.graph_id in self._remaining_ops: + if source_op.graph_id in self._remaining_ops_set: return False - if self._schedule._schedule_time is not None: - available_time = ( - self._schedule.start_times.get(source_op.graph_id) - + self._op_laps[source_op.graph_id] * self._schedule._schedule_time - + source_port.latency_offset - ) - else: - available_time = ( - self._schedule.start_times.get(source_op.graph_id) - + source_port.latency_offset - ) + available_time = ( + self._schedule.start_times.get(source_op.graph_id) + + self._op_laps[source_op.graph_id] * schedule_time + + source_port.latency_offset + ) required_time = self._current_time + op_input.latency_offset if available_time > required_time: return False -- GitLab