diff --git a/b_asic/scheduler.py b/b_asic/scheduler.py
index 9afa14d836bb5c6c6e1f6810d161474018fbe9a0..cbd5160b62d038ff5deedcf8ec73ecae2e24f686 100644
--- a/b_asic/scheduler.py
+++ b/b_asic/scheduler.py
@@ -432,10 +432,17 @@ class ListScheduler(Scheduler):
             if self._schedule._schedule_time is not None
             else 0
         )
+        time_slot = (
+            self._current_time
+            if self._schedule._schedule_time is None
+            else self._current_time % self._schedule._schedule_time
+        )
         ready_ops = [
             op_id
             for op_id in candidate_ids
-            if self._op_is_schedulable(self._sfg.find_by_id(op_id), schedule_time)
+            if self._op_is_schedulable(
+                self._sfg.find_by_id(op_id), schedule_time, time_slot
+            )
         ]
 
         memory_reads = self._calculate_memory_reads(ready_ops)
@@ -507,14 +514,14 @@ class ListScheduler(Scheduler):
                 count += 1
         return count
 
-    def _op_satisfies_resource_constraints(self, op: "Operation") -> bool:
-        if self._schedule._schedule_time is not None:
-            time_slot = self._current_time % self._schedule._schedule_time
-        else:
-            time_slot = self._current_time
-        op_type_name = op.type_name()
-        count = self._cached_execution_times_in_time[op_type_name][time_slot]
-        return count < self._remaining_resources[op_type_name]
+    def _op_satisfies_resource_constraints(
+        self, op: "Operation", time_slot: int
+    ) -> bool:
+        op_type = type(op)
+        return (
+            self._cached_execution_times_in_time[op_type][time_slot]
+            < self._remaining_resources[op_type]
+        )
 
     def _op_satisfies_concurrent_writes(self, op: "Operation") -> bool:
         if self._max_concurrent_writes:
@@ -600,9 +607,11 @@ class ListScheduler(Scheduler):
                 return False
         return True
 
-    def _op_is_schedulable(self, op: "Operation", schedule_time: int) -> bool:
+    def _op_is_schedulable(
+        self, op: "Operation", schedule_time: int, time_slot: int
+    ) -> bool:
         return (
-            self._op_satisfies_resource_constraints(op)
+            self._op_satisfies_resource_constraints(op, time_slot)
            and self._op_satisfies_data_dependencies(op, schedule_time)
            and self._op_satisfies_concurrent_writes(op)
            and self._op_satisfies_concurrent_reads(op)
@@ -683,7 +692,17 @@ class ListScheduler(Scheduler):
                 f"{alap_schedule.schedule_time}."
             )
 
-        self._remaining_resources = self._max_resources.copy()
+        used_op_types = self._sfg.get_used_operation_types()
+
+        def find_type_from_type_name(type_name):
+            for op_type in used_op_types:
+                if op_type.type_name() == type_name:
+                    return op_type
+
+        self._remaining_resources = {
+            find_type_from_type_name(type_name): cnt
+            for type_name, cnt in self._max_resources.items()
+        }
         self._remaining_ops = [op.graph_id for op in self._sfg.operations]
 
         self._cached_execution_times = {
@@ -691,7 +710,7 @@ class ListScheduler(Scheduler):
             for op_id in self._remaining_ops
         }
         self._cached_execution_times_in_time = {
-            op_type: defaultdict(int) for op_type in self._sfg.get_used_type_names()
+            op_type: defaultdict(int) for op_type in used_op_types
         }
         self._remaining_ops = [
             op_id
@@ -754,9 +773,7 @@ class ListScheduler(Scheduler):
                 if self._schedule._schedule_time
                 else self._current_time
             )
-            self._cached_execution_times_in_time[next_op.type_name()][
-                time_slot
-            ] += 1
+            self._cached_execution_times_in_time[type(next_op)][time_slot] += 1
 
             self._log_scheduled_op(next_op)
 
@@ -913,8 +930,9 @@ class RecursiveListScheduler(ListScheduler):
             new_time = (
                 self._schedule._start_times[op_id] + delta
             ) % self._schedule.schedule_time
-            exec_count = self._execution_times_in_time(type(op), new_time)
-            while exec_count >= self._remaining_resources[op.type_name()]:
+            op_type = type(op)
+            exec_count = self._execution_times_in_time(op_type, new_time)
+            while exec_count >= self._remaining_resources[op_type]:
                 delta += 1
                 new_time = (
                     self._schedule._start_times[op_id] + delta
@@ -956,6 +974,7 @@ class RecursiveListScheduler(ListScheduler):
         self._recursive_ops = self._get_recursive_ops(loops)
         self._recursive_ops_set = set(self._recursive_ops)
         self._remaining_recursive_ops = self._recursive_ops.copy()
+        self._logger.debug("--- Generating initial recursive priority table ---")
         prio_table = self._get_recursive_priority_table()
         while prio_table:
             op = self._get_next_recursive_op(prio_table)
@@ -972,10 +991,11 @@ class RecursiveListScheduler(ListScheduler):
                     op_sched_time, source_start_time + source_port.latency_offset
                 )
 
-            exec_count = self._execution_times_in_time(type(op), op_sched_time)
-            while exec_count >= self._remaining_resources[op.type_name()]:
+            op_type = type(op)
+            exec_count = self._execution_times_in_time(op_type, op_sched_time)
+            while exec_count >= self._remaining_resources[op_type]:
                 op_sched_time += 1
-                exec_count = self._execution_times_in_time(type(op), op_sched_time)
+                exec_count = self._execution_times_in_time(op_type, op_sched_time)
 
             self._schedule.place_operation(op, op_sched_time, self._op_laps)
             self._op_laps[op.graph_id] = 0
@@ -990,7 +1010,7 @@ class RecursiveListScheduler(ListScheduler):
                 if self._schedule._schedule_time
                 else self._current_time
             )
-            self._cached_execution_times_in_time[op.type_name()][time_slot] += 1
+            self._cached_execution_times_in_time[op_type][time_slot] += 1
 
             prio_table = self._get_recursive_priority_table()
 
diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py
index 616d544b57fef9482b800775c0758f31c9b0820e..6073467b15776d2c9229d259bc5a5110ee154864 100644
--- a/b_asic/signal_flow_graph.py
+++ b/b_asic/signal_flow_graph.py
@@ -2277,6 +2277,12 @@ class SFG(AbstractOperation):
         ret.sort()
         return ret
 
+    def get_used_operation_types(self) -> list[type[Operation]]:
+        """Get a list of all operation types used in the SFG."""
+        ret = list({type(op) for op in self.operations})
+        ret.sort(key=lambda op: op.type_name())
+        return ret
+
     def get_used_graph_ids(self) -> set[GraphID]:
         """Get a list of all GraphID:s used in the SFG."""
         ret = set({op.graph_id for op in self.operations})
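
Reviewer note: the sketch below uses hypothetical stand-in classes, not the real b_asic operations, to illustrate the bookkeeping change this patch makes in ListScheduler. The user-facing max_resources dict stays keyed by type-name strings, is translated once to operation classes via the types used in the SFG, and the per-slot counters are then indexed by type(op), so the hot path no longer calls op.type_name() for every candidate.

from collections import defaultdict


# Hypothetical stand-ins for b_asic operation classes (illustration only).
class Addition:
    @classmethod
    def type_name(cls):
        return "add"


class ConstantMultiplication:
    @classmethod
    def type_name(cls):
        return "cmul"


# What SFG.get_used_operation_types() conceptually returns: the distinct classes.
used_op_types = [Addition, ConstantMultiplication]

# The user-facing constraint dict stays keyed by type-name strings.
max_resources = {"add": 1, "cmul": 2}


def find_type_from_type_name(type_name):
    # One-time translation from type-name string to operation class.
    for op_type in used_op_types:
        if op_type.type_name() == type_name:
            return op_type


# Bookkeeping keyed by class, so the hot path can index with type(op).
remaining_resources = {
    find_type_from_type_name(name): cnt for name, cnt in max_resources.items()
}
execution_times_in_time = {op_type: defaultdict(int) for op_type in used_op_types}


def op_satisfies_resource_constraints(op, time_slot):
    # Mirrors the rewritten _op_satisfies_resource_constraints.
    op_type = type(op)
    return execution_times_in_time[op_type][time_slot] < remaining_resources[op_type]


op, time_slot = Addition(), 3
assert op_satisfies_resource_constraints(op, time_slot)      # 0 < 1 -> schedulable
execution_times_in_time[type(op)][time_slot] += 1            # the op gets placed
assert not op_satisfies_resource_constraints(op, time_slot)  # slot for "add" now full

Passing the precomputed time_slot into _op_is_schedulable follows the same idea: compute it once per scheduling step instead of once per candidate operation.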