diff --git a/b_asic/scheduler.py b/b_asic/scheduler.py index 5bff185cebec4523706d451f59c4fca4b8fe1388..04f14bffb03e2354b28e645d661aea047f448bfa 100644 --- a/b_asic/scheduler.py +++ b/b_asic/scheduler.py @@ -1,5 +1,6 @@ import copy from abc import ABC, abstractmethod +from collections import defaultdict from typing import TYPE_CHECKING, cast import b_asic.logger as logger @@ -484,19 +485,18 @@ class ListScheduler(Scheduler): op_reads[op_id] = reads return op_reads - def _execution_times_in_time(self, op: "Operation", time: int) -> int: + def _execution_times_in_time(self, op_type: "Operation", time: int) -> int: count = 0 for other_op_id, start_time in self._schedule.start_times.items(): - if other_op_id != op._graph_id: - if self._schedule._schedule_time is not None: - start_time = start_time % self._schedule._schedule_time - if ( - time >= start_time - and time - < start_time + max(self._cached_execution_times[other_op_id], 1) - and isinstance(self._sfg.find_by_id(other_op_id), type(op)) - ): - count += 1 + if self._schedule._schedule_time is not None: + start_time = start_time % self._schedule._schedule_time + if ( + time >= start_time + and time + < start_time + max(self._cached_execution_times[other_op_id], 1) + and isinstance(self._sfg.find_by_id(other_op_id), op_type) + ): + count += 1 return count def _op_satisfies_resource_constraints(self, op: "Operation") -> bool: @@ -504,7 +504,7 @@ class ListScheduler(Scheduler): time_slot = self._current_time % self._schedule._schedule_time else: time_slot = self._current_time - count = self._execution_times_in_time(op, time_slot) + count = self._cached_execution_times_in_time[op.type_name()][time_slot] return count < self._remaining_resources[op.type_name()] def _op_satisfies_concurrent_writes(self, op: "Operation") -> bool: @@ -686,6 +686,9 @@ class ListScheduler(Scheduler): op_id: self._sfg.find_by_id(op_id).execution_time for op_id in self._remaining_ops } + self._cached_execution_times_in_time = { + op_type: defaultdict(int) for op_type in self._sfg.get_used_type_names() + } self._remaining_ops = [ op_id for op_id in self._remaining_ops @@ -743,6 +746,16 @@ class ListScheduler(Scheduler): else 0 ) + for i in range(max(1, next_op.execution_time)): + time_slot = ( + (self._current_time + i) % self._schedule._schedule_time + if self._schedule._schedule_time + else self._current_time + ) + self._cached_execution_times_in_time[next_op.type_name()][ + time_slot + ] += 1 + self._log_scheduled_op(next_op) prio_table = self._get_priority_table() @@ -894,13 +907,13 @@ class RecursiveListScheduler(ListScheduler): new_time = ( self._schedule._start_times[op_id] + delta ) % self._schedule.schedule_time - exec_count = self._execution_times_in_time(op, new_time) + exec_count = self._execution_times_in_time(type(op), new_time) while exec_count >= self._remaining_resources[op.type_name()]: delta += 1 new_time = ( self._schedule._start_times[op_id] + delta ) % self._schedule.schedule_time - exec_count = self._execution_times_in_time(op, new_time) + exec_count = self._execution_times_in_time(type(op), new_time) if delta > self._schedule.forward_slack(op_id): continue self._schedule.move_operation(op_id, delta) @@ -952,16 +965,25 @@ class RecursiveListScheduler(ListScheduler): op_sched_time, source_start_time + source_port.latency_offset ) - exec_count = self._execution_times_in_time(op, op_sched_time) + exec_count = self._execution_times_in_time(type(op), op_sched_time) while exec_count >= self._remaining_resources[op.type_name()]: op_sched_time += 1 - exec_count = self._execution_times_in_time(op, op_sched_time) + exec_count = self._execution_times_in_time(type(op), op_sched_time) self._schedule.place_operation(op, op_sched_time, self._op_laps) self._op_laps[op.graph_id] = 0 self._logger.debug(f" Op: {op.graph_id} time: {op_sched_time}") self._remaining_recursive_ops.remove(op.graph_id) self._remaining_ops.remove(op.graph_id) + + for i in range(max(1, op.execution_time)): + time_slot = ( + (self._current_time + i) % self._schedule._schedule_time + if self._schedule._schedule_time + else self._current_time + ) + self._cached_execution_times_in_time[op.type_name()][time_slot] += 1 + prio_table = self._get_recursive_priority_table() self._schedule._schedule_time = self._schedule.get_max_end_time()