diff --git a/b_asic/scheduler.py b/b_asic/scheduler.py index 88ce2c8465be62c5d4f9a1764768f76fe23bbf19..a50e390d484c0a896e8bf0c7ba3cfc666a5ba5fa 100644 --- a/b_asic/scheduler.py +++ b/b_asic/scheduler.py @@ -1,21 +1,24 @@ import copy +import math from abc import ABC, abstractmethod from collections import defaultdict from typing import TYPE_CHECKING, cast import b_asic.logger from b_asic.core_operations import DontCare, Sink -from b_asic.port import OutputPort +from b_asic.operation import Operation +from b_asic.port import InputPort, OutputPort from b_asic.special_operations import Delay, Input, Output from b_asic.types import TypeName if TYPE_CHECKING: - from b_asic.operation import Operation from b_asic.schedule import Schedule from b_asic.types import GraphID log = b_asic.logger.getLogger(__name__) +PriorityTableType = list[tuple["GraphID", int, int, int, int]] + class Scheduler(ABC): """ @@ -31,6 +34,17 @@ class Scheduler(ABC): If the y-position should be sorted based on start time of operations. """ + __slots__ = ( + '_schedule', + '_op_laps', + '_input_times', + '_output_delta_times', + '_sort_y_location', + ) + + _schedule: "Schedule" + _op_laps: dict["GraphID", int] + def __init__( self, input_times: dict["GraphID", int] | None = None, @@ -98,8 +112,10 @@ class Scheduler(ABC): log.debug("Input placement completed") def _place_outputs_asap( - self, schedule: "Schedule", non_schedulable_ops: list["GraphID"] | None = [] + self, schedule: "Schedule", non_schedulable_ops: list["GraphID"] | None ) -> None: + if non_schedulable_ops is None: + non_schedulable_ops = [] for output in schedule._sfg.find_by_type(Output): output = cast(Output, output) source_port = cast(OutputPort, output.inputs[0].signals[0].source) @@ -122,7 +138,7 @@ class Scheduler(ABC): end = self._schedule._schedule_time else: end = self._schedule.get_max_end_time() - for output in self._sfg.find_by_type(Output): + for output in self._schedule._sfg.find_by_type(Output): output = cast(Output, output) if output.graph_id in self._output_delta_times: delta_time = self._output_delta_times[output.graph_id] @@ -141,7 +157,7 @@ class Scheduler(ABC): count = -1 for op_id, time in self._schedule.start_times.items(): if time == new_time and isinstance( - self._sfg.find_by_id(op_id), Output + self._schedule._sfg.find_by_id(op_id), Output ): count += 1 @@ -156,10 +172,10 @@ class Scheduler(ABC): log.debug("Output placement optimization starting") min_slack = min( self._schedule.backward_slack(op.graph_id) - for op in self._sfg.find_by_type(Output) + for op in self._schedule._sfg.find_by_type(Output) ) if min_slack != 0: - for output in self._sfg.find_by_type(Output): + for output in self._schedule._sfg.find_by_type(Output): if self._schedule._cyclic and self._schedule._schedule_time is not None: self._schedule.move_operation(output.graph_id, -min_slack) else: @@ -185,19 +201,23 @@ class Scheduler(ABC): def _handle_dont_cares(self) -> None: # schedule all dont cares ALAP - for dc_op in self._sfg.find_by_type(DontCare): + for dc_op in self._schedule._sfg.find_by_type(DontCare): self._schedule.start_times[dc_op.graph_id] = 0 self._schedule.place_operation( - dc_op, self._schedule.forward_slack(dc_op.graph_id), self._op_laps + cast(Operation, dc_op), + self._schedule.forward_slack(dc_op.graph_id), + self._op_laps, ) self._op_laps[dc_op.graph_id] = 0 def _handle_sinks(self) -> None: # schedule all sinks ASAP - for sink_op in self._sfg.find_by_type(Sink): + for sink_op in self._schedule._sfg.find_by_type(Sink): self._schedule.start_times[sink_op.graph_id] = self._schedule._schedule_time self._schedule.place_operation( - sink_op, self._schedule.backward_slack(sink_op.graph_id), self._op_laps + cast(Operation, sink_op), + self._schedule.backward_slack(sink_op.graph_id), + self._op_laps, ) self._op_laps[sink_op.graph_id] = 0 @@ -208,8 +228,7 @@ class ASAPScheduler(Scheduler): def apply_scheduling(self, schedule: "Schedule") -> None: # Doc-string inherited self._schedule = schedule - self._sfg = schedule._sfg - prec_list = schedule.sfg.get_precedence_list() + prec_list = schedule._sfg.get_precedence_list() if len(prec_list) < 2: raise ValueError("Empty signal flow graph cannot be scheduled.") @@ -234,7 +253,7 @@ class ASAPScheduler(Scheduler): if operation.graph_id not in schedule.start_times: op_start_time = 0 for current_input in operation.inputs: - source_port = current_input.signals[0].source + source_port = cast(OutputPort, current_input.signals[0].source) if source_port.operation.graph_id in non_schedulable_ops: source_end_time = 0 @@ -293,7 +312,6 @@ class ALAPScheduler(Scheduler): def apply_scheduling(self, schedule: "Schedule") -> None: # Doc-string inherited self._schedule = schedule - self._sfg = schedule._sfg ASAPScheduler( self._input_times, self._output_delta_times, @@ -345,13 +363,12 @@ class ALAPScheduler(Scheduler): class ListScheduler(Scheduler): """ - List-based scheduler that schedules the operations while complying to the given - constraints. + List-based scheduler that schedules the operations with constraints. .. admonition:: Important - Will only work on non-recursive SFGs. - For recursive SFGs use RecursiveListScheduler instead. + Only works on non-recursive SFGs. + For recursive SFGs use :class:`RecursiveListScheduler` instead. Parameters ---------- @@ -364,12 +381,43 @@ class ListScheduler(Scheduler): Max number of conccurent reads, by default None max_concurrent_writes : int | None, optional Max number of conccurent writes, by default None - input_times : dict[GraphID, int] | None, optional + input_times : dict["GraphID", int] | None, optional Specified input times, by default None - output_delta_times : dict[GraphID, int] | None, optional + output_delta_times : dict["GraphID", int] | None, optional Specified output delta times, by default None """ + __slots__ = ( + '_remaining_ops', + '_deadlines', + '_output_slacks', + '_fan_outs', + '_current_time', + '_cached_execution_times_in_time', + '_alap_start_times', + '_sort_order', + '_max_resources', + '_max_concurrent_reads', + '_max_concurrent_writes', + '_remaining_ops_set', + '_alap_op_laps', + '_alap_schedule_time', + '_used_reads', + '_remaining_resources', + '_cached_execution_times', + ) + _remaining_ops: list["GraphID"] + _deadlines: dict["GraphID", int] + _output_slacks: dict["GraphID", int] + _fan_outs: dict["GraphID", int] + _current_time: int + _cached_execution_times_in_time: dict[type[Operation], defaultdict[int, int]] + _alap_start_times: dict["GraphID", int] + _alap_op_laps: dict["GraphID", int] + _alap_schedule_time: int + _used_reads: defaultdict[int, int] + _cached_execution_times: dict["GraphID", int] + def __init__( self, sort_order: tuple[tuple[int, bool], ...], @@ -436,9 +484,7 @@ class ListScheduler(Scheduler): schedule.sort_y_locations_on_start_times() log.debug("Scheduling completed") - def _get_next_op_id( - self, priority_table: list[tuple["GraphID", int, ...]] - ) -> "GraphID": + def _get_next_op_id(self, priority_table: PriorityTableType) -> "GraphID": def sort_key(item): return tuple( (item[index] * (-1 if not asc else 1),) @@ -448,9 +494,7 @@ class ListScheduler(Scheduler): sorted_table = sorted(priority_table, key=sort_key) return sorted_table[0][0] - def _get_priority_table( - self, candidate_ids - ) -> list[tuple["GraphID", int, int, int]]: + def _get_priority_table(self, candidate_ids) -> PriorityTableType: schedule_time = ( self._schedule._schedule_time if self._schedule._schedule_time is not None @@ -459,7 +503,9 @@ class ListScheduler(Scheduler): ready_ops = [ op_id for op_id in candidate_ids - if self._op_is_schedulable(self._sfg.find_by_id(op_id), schedule_time) + if self._op_is_schedulable( + cast(Operation, self._schedule._sfg.find_by_id(op_id)), schedule_time + ) ] memory_reads = self._calculate_memory_reads(ready_ops) @@ -477,9 +523,8 @@ class ListScheduler(Scheduler): def _calculate_deadlines(self) -> dict["GraphID", int]: deadlines = {} for op_id, start_time in self._alap_start_times.items(): - output_offsets = [ - output.latency_offset for output in self._sfg.find_by_id(op_id).outputs - ] + op = cast(Operation, self._schedule._sfg.find_by_id(op_id)) + output_offsets = [cast(int, output.latency_offset) for output in op.outputs] start_time += self._alap_op_laps[op_id] * self._alap_schedule_time deadlines[op_id] = start_time + min(output_offsets, default=0) return deadlines @@ -492,7 +537,9 @@ class ListScheduler(Scheduler): def _calculate_fan_outs(self) -> dict["GraphID", int]: return { - op_id: len(self._sfg.find_by_id(op_id).output_signals) + op_id: len( + cast(Operation, self._schedule._sfg.find_by_id(op_id)).output_signals + ) for op_id in self._alap_start_times } @@ -502,24 +549,23 @@ class ListScheduler(Scheduler): op_reads = {} for op_id in ready_ops: reads = 0 - for op_input in self._sfg.find_by_id(op_id).inputs: - source_port = op_input.signals[0].source + op = cast(Operation, self._schedule._sfg.find_by_id(op_id)) + for op_input in op.inputs: + source_port = cast(OutputPort, op_input.signals[0].source) source_op = source_port.operation if isinstance(source_op, DontCare): continue if isinstance(source_op, Delay): reads += 1 continue - if ( - self._schedule.start_times[source_op.graph_id] - + source_port.latency_offset - != self._current_time + op_input.latency_offset - ): + if self._schedule.start_times[source_op.graph_id] + cast( + int, source_port.latency_offset + ) != self._current_time + cast(int, op_input.latency_offset): reads += 1 op_reads[op_id] = reads return op_reads - def _execution_times_in_time(self, op_type: "Operation", time: int) -> int: + def _execution_times_in_time(self, op_type, time: int) -> int: count = 0 for other_op_id, start_time in self._schedule.start_times.items(): if self._schedule._schedule_time is not None: @@ -528,12 +574,12 @@ class ListScheduler(Scheduler): time >= start_time and time < start_time + max(self._cached_execution_times[other_op_id], 1) - and isinstance(self._sfg.find_by_id(other_op_id), op_type) + and isinstance(self._schedule._sfg.find_by_id(other_op_id), op_type) ): count += 1 return count - def _op_satisfies_resource_constraints(self, op: "Operation") -> bool: + def _op_satisfies_resource_constraints(self, op: Operation) -> bool: op_type = type(op) if self._schedule._schedule_time is None: for i in range(max(1, op.execution_time)): @@ -542,35 +588,35 @@ class ListScheduler(Scheduler): if count >= self._remaining_resources[op_type]: return False else: - for i in range(max(1, op.execution_time)): + for i in range(max(1, cast(int, op.execution_time))): time_slot = (self._current_time + i) % self._schedule._schedule_time count = self._cached_execution_times_in_time[op_type][time_slot] if count >= self._remaining_resources[op_type]: return False return True - def _op_satisfies_concurrent_writes(self, op: "Operation") -> bool: + def _op_satisfies_concurrent_writes(self, op: Operation) -> bool: if self._max_concurrent_writes: - tmp_used_writes = {} + tmp_used_writes: defaultdict[int, int] = defaultdict(int) if not isinstance(op, Output): for output_port in op.outputs: - output_ready_time = self._current_time + output_port.latency_offset + output_ready_time = self._current_time + cast( + int, output_port.latency_offset + ) if self._schedule._schedule_time: output_ready_time %= self._schedule._schedule_time writes_in_time = 0 - for item in self._schedule.start_times.items(): + for op_id, start_time in self._schedule.start_times.items(): + tmp_op = cast(Operation, self._schedule._sfg.find_by_id(op_id)) offsets = [ - output.latency_offset - for output in self._sfg.find_by_id(item[0]).outputs + cast(int, output.latency_offset) + for output in tmp_op.outputs ] - write_times = [item[1] + offset for offset in offsets] + write_times = [start_time + offset for offset in offsets] writes_in_time += write_times.count(output_ready_time) - if tmp_used_writes.get(output_ready_time): - tmp_used_writes[output_ready_time] += 1 - else: - tmp_used_writes[output_ready_time] = 1 + tmp_used_writes[output_ready_time] += 1 if ( self._max_concurrent_writes @@ -581,41 +627,40 @@ class ListScheduler(Scheduler): return False return True - def _op_satisfies_concurrent_reads(self, op: "Operation") -> bool: + def _op_satisfies_concurrent_reads(self, op: Operation) -> bool: if self._max_concurrent_reads: - tmp_used_reads = {} + tmp_used_reads: defaultdict[int, int] = defaultdict(int) for op_input in op.inputs: - source_port = op_input.signals[0].source + source_port = cast(OutputPort, op_input.signals[0].source) source_op = source_port.operation if isinstance(source_op, (Delay, DontCare)): continue - input_read_time = self._current_time + op_input.latency_offset + input_read_time = self._current_time + cast( + int, op_input.latency_offset + ) if ( self._schedule.start_times[source_op.graph_id] - + source_port.latency_offset + + cast(int, source_port.latency_offset) != input_read_time ): if self._schedule._schedule_time: input_read_time %= self._schedule._schedule_time - if tmp_used_reads.get(input_read_time): - tmp_used_reads[input_read_time] += 1 - else: - tmp_used_reads[input_read_time] = 1 + tmp_used_reads[input_read_time] += 1 - prev_used = self._used_reads.get(input_read_time) or 0 if ( self._max_concurrent_reads - < prev_used + tmp_used_reads[input_read_time] + < self._used_reads[input_read_time] + + tmp_used_reads[input_read_time] ): return False return True def _op_satisfies_data_dependencies( - self, op: "Operation", schedule_time: int + self, op: Operation, schedule_time: int ) -> bool: for op_input in op.inputs: - source_port = op_input.signals[0].source + source_port = cast(OutputPort, op_input.signals[0].source) source_op = source_port.operation if isinstance(source_op, (Delay, DontCare)): @@ -627,15 +672,15 @@ class ListScheduler(Scheduler): available_time = ( self._schedule.start_times[source_op.graph_id] + self._op_laps[source_op.graph_id] * schedule_time - + source_port.latency_offset + + cast(int, source_port.latency_offset) ) - required_time = self._current_time + op_input.latency_offset + required_time = self._current_time + cast(int, op_input.latency_offset) if available_time > required_time: return False return True - def _op_is_schedulable(self, op: "Operation", schedule_time: int) -> bool: + def _op_is_schedulable(self, op: Operation, schedule_time: int) -> bool: return ( self._op_satisfies_resource_constraints(op) and self._op_satisfies_data_dependencies(op, schedule_time) @@ -645,17 +690,16 @@ class ListScheduler(Scheduler): def _initialize_scheduler(self, schedule: "Schedule") -> None: self._schedule = schedule - self._sfg = schedule._sfg for resource_type in self._max_resources: - if not self._sfg.find_by_type_name(resource_type): + if not self._schedule._sfg.find_by_type_name(resource_type): raise ValueError( f"Provided max resource of type {resource_type} cannot be found in the provided SFG." ) differing_elems = [ resource - for resource in self._sfg.get_used_type_names() + for resource in self._schedule._sfg.get_used_type_names() if resource not in self._max_resources and resource != Delay.type_name() and resource != DontCare.type_name() @@ -665,19 +709,19 @@ class ListScheduler(Scheduler): self._max_resources[type_name] = 1 for key in self._input_times: - if self._sfg.find_by_id(key) is None: + if self._schedule._sfg.find_by_id(key) is None: raise ValueError( f"Provided input time with GraphID {key} cannot be found in the provided SFG." ) for key in self._output_delta_times: - if self._sfg.find_by_id(key) is None: + if self._schedule._sfg.find_by_id(key) is None: raise ValueError( f"Provided output delta time with GraphID {key} cannot be found in the provided SFG." ) if self._schedule._cyclic and self._schedule._schedule_time is not None: - iteration_period_bound = self._sfg.iteration_period_bound() + iteration_period_bound = self._schedule._sfg.iteration_period_bound() if self._schedule._schedule_time < iteration_period_bound: raise ValueError( f"Provided scheduling time {self._schedule._schedule_time} must be larger or equal to the" @@ -686,7 +730,7 @@ class ListScheduler(Scheduler): if self._schedule._schedule_time is not None: for resource_type, resource_amount in self._max_resources.items(): - if resource_amount < self._sfg.resource_lower_bound( + if resource_amount < self._schedule._sfg.resource_lower_bound( resource_type, self._schedule._schedule_time ): raise ValueError( @@ -718,7 +762,7 @@ class ListScheduler(Scheduler): f"{alap_schedule.schedule_time}." ) - used_op_types = self._sfg.get_used_operation_types() + used_op_types = self._schedule._sfg.get_used_operation_types() def find_type_from_type_name(type_name): for op_type in used_op_types: @@ -730,9 +774,9 @@ class ListScheduler(Scheduler): for type_name, cnt in self._max_resources.items() } - self._remaining_ops = [op.graph_id for op in self._sfg.operations] + self._remaining_ops = [op.graph_id for op in self._schedule._sfg.operations] self._cached_execution_times = { - op_id: self._sfg.find_by_id(op_id).execution_time + op_id: cast(Operation, self._schedule._sfg.find_by_id(op_id)).execution_time for op_id in self._remaining_ops } self._cached_execution_times_in_time = { @@ -741,19 +785,20 @@ class ListScheduler(Scheduler): self._remaining_ops = [ op_id for op_id in self._remaining_ops - if not isinstance(self._sfg.find_by_id(op_id), (Delay, DontCare)) + if not isinstance(self._schedule._sfg.find_by_id(op_id), (Delay, DontCare)) ] self._remaining_ops = [ op_id for op_id in self._remaining_ops if not ( - isinstance(self._sfg.find_by_id(op_id), Output) + isinstance(self._schedule._sfg.find_by_id(op_id), Output) and op_id in self._output_delta_times ) ] for op_id in self._remaining_ops: - if self._sfg.find_by_id(op_id).execution_time is None: + op = cast(Operation, self._schedule._sfg.find_by_id(op_id)) + if op.execution_time is None: raise ValueError( "All operations in the SFG must have a specified execution time. " f"Missing operation: {op_id}." @@ -765,17 +810,18 @@ class ListScheduler(Scheduler): self._fan_outs = self._calculate_fan_outs() self._schedule.start_times = {} - self._used_reads = {0: 0} + self._used_reads = defaultdict(int) self._current_time = 0 def _schedule_nonrecursive_ops(self) -> None: - log.debug("Non-Recursive Operation scheduling starting") + log.debug("Non-recursive operation scheduling starting") while self._remaining_ops: prio_table = self._get_priority_table(self._remaining_ops) while prio_table: next_op_id = self._get_next_op_id(prio_table) - next_op = self._sfg.find_by_id(next_op_id) + next_op = self._schedule._sfg.find_by_id(next_op_id) + next_op = cast(Operation, next_op) self._update_port_reads(next_op) @@ -800,7 +846,7 @@ class ListScheduler(Scheduler): time_slot ] += 1 else: - for i in range(max(1, next_op.execution_time)): + for i in range(max(1, cast(int, next_op.execution_time))): time_slot = ( self._current_time + i ) % self._schedule._schedule_time @@ -822,27 +868,26 @@ class ListScheduler(Scheduler): self._current_time -= 1 log.debug("Non-recursive operation scheduling completed") - def _update_port_reads(self, next_op: "Operation") -> None: + def _update_port_reads(self, next_op: Operation) -> None: for input_port in next_op.inputs: - source_port = input_port.signals[0].source + source_port = cast(OutputPort, input_port.signals[0].source) source_op = source_port.operation - time = self._current_time + input_port.latency_offset + time = self._current_time + cast(int, input_port.latency_offset) if ( not isinstance(source_op, (DontCare, Delay)) and self._schedule.start_times[source_op.graph_id] - + source_port.latency_offset + + cast(int, source_port.latency_offset) != time ): if self._schedule._schedule_time: time %= self._schedule._schedule_time - if self._used_reads.get(time): - self._used_reads[time] += 1 - else: - self._used_reads[time] = 1 + self._used_reads[time] += 1 class RecursiveListScheduler(ListScheduler): + __slots__ = ('_recursive_ops', '_recursive_ops_set', '_remaining_recursive_ops') + def __init__( self, sort_order: tuple[tuple[int, bool], ...], @@ -868,7 +913,7 @@ class RecursiveListScheduler(ListScheduler): ] self._remaining_ops_set = set(self._remaining_ops) - loops = self._sfg.loops + loops = self._schedule._sfg.loops if loops: self._schedule_recursive_ops(loops) @@ -882,7 +927,9 @@ class RecursiveListScheduler(ListScheduler): period_bound = self._schedule._sfg.iteration_period_bound() self._schedule.remove_delays() if loops: - self._retime_ops(period_bound) + if int(period_bound) != period_bound: + log.warning("Rational iteration period bound: %d", period_bound) + self._retime_ops(math.ceil(period_bound)) self._handle_dont_cares() if self._sort_y_location: schedule.sort_y_locations_on_start_times() @@ -893,14 +940,14 @@ class RecursiveListScheduler(ListScheduler): for loop in loops: for op_id in loop: if op_id not in recursive_ops and not isinstance( - self._sfg.find_by_id(op_id), Delay + self._schedule._sfg.find_by_id(op_id), Delay ): recursive_ops.append(op_id) return recursive_ops - def _recursive_op_satisfies_data_dependencies(self, op: "Operation") -> bool: + def _recursive_op_satisfies_data_dependencies(self, op: Operation) -> bool: for op_input in op.inputs: - source_port = source_op = op_input.signals[0].source + source_port = cast(OutputPort, op_input.signals[0].source) source_op = source_port.operation if isinstance(source_op, (Delay, DontCare)): continue @@ -916,7 +963,7 @@ class RecursiveListScheduler(ListScheduler): op_id for op_id in self._remaining_recursive_ops if self._recursive_op_satisfies_data_dependencies( - self._sfg.find_by_id(op_id) + self._schedule._sfg.find_by_id(op_id) ) ] return [(op_id, self._deadlines[op_id]) for op_id in ready_ops] @@ -973,7 +1020,7 @@ class RecursiveListScheduler(ListScheduler): # adjust time if a gap exists on the left side of the schedule slack = min(self._schedule._start_times.values()) for other_op_id in sorted_op_ids: - op = self._schedule._sfg.find_by_id(other_op_id) + op = cast(Operation, self._schedule._sfg.find_by_id(other_op_id)) max_end_time = 0 op_start_time = self._schedule._start_times[other_op_id] for outport in op.outputs: @@ -1003,7 +1050,7 @@ class RecursiveListScheduler(ListScheduler): op = self._get_next_recursive_op(prio_table) op_sched_time = 0 for input_port in op.inputs: - source_port = input_port.signals[0].source + source_port = cast(OutputPort, input_port.signals[0].source) source_op = source_port.operation if isinstance(source_op, Delay): continue @@ -1029,7 +1076,7 @@ class RecursiveListScheduler(ListScheduler): self._remaining_ops.remove(op.graph_id) self._remaining_ops_set.remove(op.graph_id) - for i in range(max(1, op.execution_time)): + for i in range(max(1, cast(int, op.execution_time))): time_slot = self._current_time + i self._cached_execution_times_in_time[op_type][time_slot] += 1 @@ -1041,18 +1088,16 @@ class RecursiveListScheduler(ListScheduler): self._schedule._schedule_time = saved_sched_time log.debug("Scheduling of recursive loops completed") - def _get_next_recursive_op( - self, priority_table: list[tuple["GraphID", int, ...]] - ) -> "Operation": + def _get_next_recursive_op(self, priority_table: PriorityTableType) -> Operation: sorted_table = sorted(priority_table, key=lambda row: row[1]) - return self._sfg.find_by_id(sorted_table[0][0]) + return cast(Operation, self._schedule._sfg.find_by_id(sorted_table[0][0])) def _pipeline_input_to_recursive_sections(self) -> None: for op_id in self._recursive_ops: - op = self._sfg.find_by_id(op_id) + op = cast(Operation, self._schedule._sfg.find_by_id(op_id)) for input_port in op.inputs: signal = input_port.signals[0] - source_op = signal.source.operation + source_op = cast(OutputPort, signal.source).operation if ( not isinstance(source_op, Delay) and source_op.graph_id not in self._recursive_ops_set @@ -1061,17 +1106,17 @@ class RecursiveListScheduler(ListScheduler): self._schedule.laps[signal.graph_id] += 1 def _op_satisfies_data_dependencies( - self, op: "Operation", schedule_time: int + self, op: Operation, schedule_time: int ) -> bool: for output_port in op.outputs: - destination_port = output_port.signals[0].destination + destination_port = cast(InputPort, output_port.signals[0].destination) destination_op = destination_port.operation if destination_op.graph_id not in self._remaining_ops_set: if isinstance(destination_op, Delay): continue # spotted a recursive operation -> check if ok - op_available_time = ( - self._current_time + op.latency_offsets[f"out{output_port.index}"] + op_available_time = self._current_time + cast( + int, output_port.latency_offset ) usage_time = ( self._schedule.start_times[destination_op.graph_id] @@ -1084,18 +1129,18 @@ class RecursiveListScheduler(ListScheduler): self._pipeline_input_to_recursive_sections() for op_input in op.inputs: - source_port = op_input.signals[0].source + source_port = cast(OutputPort, op_input.signals[0].source) source_op = source_port.operation if isinstance(source_op, (Delay, DontCare)): continue if source_op.graph_id in self._remaining_ops_set: return False available_time = ( - self._schedule.start_times.get(source_op.graph_id) + self._schedule.start_times[source_op.graph_id] + self._op_laps[source_op.graph_id] * schedule_time - + source_port.latency_offset + + cast(int, source_port.latency_offset) ) - required_time = self._current_time + op_input.latency_offset + required_time = self._current_time + cast(int, op_input.latency_offset) if available_time > required_time: return False return True diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 6073467b15776d2c9229d259bc5a5110ee154864..859e3647a0bd707505d38ab7df03a353bb8476aa 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -12,9 +12,8 @@ from collections.abc import Iterable, MutableSet, Sequence from fractions import Fraction from io import StringIO from math import ceil -from numbers import Number from queue import PriorityQueue -from typing import ClassVar, Literal, Optional, Union, cast +from typing import ClassVar, Literal, Union, cast import numpy as np from graphviz import Digraph @@ -116,8 +115,8 @@ class SFG(AbstractOperation): _original_components_to_new: dict[GraphComponent, GraphComponent] _original_input_signals_to_indices: dict[Signal, int] _original_output_signals_to_indices: dict[Signal, int] - _precedence_list: list[list[OutputPort]] | None - _used_ids: ClassVar[set[GraphID]] = set() + _precedence_list: list[list[OutputPort]] + _used_ids: set[GraphID] def __init__( self, @@ -152,7 +151,7 @@ class SFG(AbstractOperation): self._original_components_to_new = {} self._original_input_signals_to_indices = {} self._original_output_signals_to_indices = {} - self._precedence_list = None + self._precedence_list = [] # Setup input signals. if input_signals is not None: @@ -348,7 +347,7 @@ class SFG(AbstractOperation): prefix: str = "", bits_override: int | None = None, quantize: bool = True, - ) -> Number: + ) -> Num: # doc-string inherited if index < 0 or index >= self.output_count: raise IndexError( @@ -547,11 +546,11 @@ class SFG(AbstractOperation): """Get all operations of this graph in depth-first order.""" return list(self._operations_dfs_order) - def find_by_type_name(self, type_name: TypeName) -> Sequence[GraphComponent]: + def find_by_type_name(self, type_name: TypeName) -> list[GraphComponent]: """ Find all components in this graph with the specified type name. - Returns an empty sequence if no components were found. + Returns an empty list if no components were found. Parameters ---------- @@ -565,15 +564,17 @@ class SFG(AbstractOperation): ] return components - def find_by_type(self, component_type: GraphComponent) -> Sequence[GraphComponent]: + def find_by_type( + self, component_type: type[GraphComponent] + ) -> list[GraphComponent]: """ Find all components in this graph with the specified type. - Returns an empty sequence if no components were found. + Returns an empty list if no components were found. Parameters ---------- - component_type : GraphComponent + component_type : type of GraphComponent The TypeName of the desired components. """ components = [ @@ -666,19 +667,17 @@ class SFG(AbstractOperation): signal.remove_source() signal.set_source(component.output(index_out)) - if component_copy.type_name() == "out": + if isinstance(component_copy, Output): sfg_copy._output_operations.remove(component_copy) warnings.warn( f"Output port {component_copy.graph_id} has been removed", stacklevel=2 ) - if component.type_name() == "out": + if isinstance(component, Output): sfg_copy._output_operations.append(component) return sfg_copy() # Copy again to update IDs. - def insert_operation( - self, component: Operation, output_comp_id: GraphID - ) -> Optional["SFG"]: + def insert_operation(self, component: Operation, output_comp_id: GraphID) -> "SFG": """ Insert an operation in the SFG after a given source operation. @@ -726,7 +725,7 @@ class SFG(AbstractOperation): self, output_comp_id: GraphID, new_operation: Operation, - ) -> Optional["SFG"]: + ) -> "SFG": """ Insert an operation in the SFG after a given source operation. @@ -753,8 +752,8 @@ class SFG(AbstractOperation): "Only operations with one input and one output can be inserted." ) if "." in output_comp_id: - output_comp_id, port_id = output_comp_id.split(".") - port_id = int(port_id) + output_comp_id, port_id_str = output_comp_id.split(".") + port_id = int(port_id_str) else: port_id = None @@ -780,7 +779,7 @@ class SFG(AbstractOperation): input_comp_id: GraphID, new_operation: Operation, port: int | None = None, - ) -> Optional["SFG"]: + ) -> "SFG": """ Insert an operation in the SFG before a given source operation. @@ -1211,13 +1210,15 @@ class SFG(AbstractOperation): for op in self.find_by_type_name(type_name): cast(Operation, op).set_latency(latency) - def set_latency_of_type(self, operation_type: Operation, latency: int) -> None: + def set_latency_of_type( + self, operation_type: type[Operation], latency: int + ) -> None: """ Set the latency of all operations with the given type. Parameters ---------- - operation_type : Operation + operation_type : type of Operation The operation type. For example, ``Addition``. latency : int The latency of the operation. @@ -1243,14 +1244,14 @@ class SFG(AbstractOperation): cast(Operation, op).execution_time = execution_time def set_execution_time_of_type( - self, operation_type: Operation, execution_time: int + self, operation_type: type[Operation], execution_time: int ) -> None: """ Set the latency of all operations with the given type. Parameters ---------- - operation_type : Operation + operation_type : type of Operation The operation type. For example, ``Addition``. execution_time : int The execution time of the operation. @@ -1276,14 +1277,14 @@ class SFG(AbstractOperation): cast(Operation, op).set_latency_offsets(latency_offsets) def set_latency_offsets_of_type( - self, operation_type: Operation, latency_offsets: dict[str, int] + self, operation_type: type[Operation], latency_offsets: dict[str, int] ) -> None: """ Set the latency offsets of all operations with the given type. Parameters ---------- - operation_type : Operation + operation_type : type of Operation The operation type. For example, ``Addition``. latency_offsets : {"in1": int, ...} The latency offsets of the inputs and outputs. @@ -1818,13 +1819,13 @@ class SFG(AbstractOperation): raise ValueError( f"Schedule time must be positive, current schedule time is: {schedule_time}." ) - exec_times = [op.execution_time for op in ops] + exec_times = [cast(Operation, op).execution_time for op in ops] if any(time is None for time in exec_times): raise ValueError( f"Execution times not set for all operations of type {type_name}." ) - total_exec_time = sum([op.execution_time for op in ops]) + total_exec_time = sum(exec_times) return ceil(total_exec_time / schedule_time) def iteration_period_bound(self) -> Fraction: @@ -1839,7 +1840,7 @@ class SFG(AbstractOperation): """ loops = self.loops if not loops: - return -1 + return Fraction(-1) op_and_latency = {} for op in self.operations: @@ -2277,7 +2278,7 @@ class SFG(AbstractOperation): ret.sort() return ret - def get_used_operation_types(self) -> list[Operation]: + def get_used_operation_types(self) -> list[type[Operation]]: """Get a list of all Operations used in the SFG.""" ret = list({type(op) for op in self.operations}) ret.sort(key=lambda op: op.type_name()) diff --git a/test/unit/test_sfg.py b/test/unit/test_sfg.py index c5691aaf9155423c384d09cf59b02e5a7f6f1664..9d0a7678c2b248d0fbf4ec1f13b026e49457be9d 100644 --- a/test/unit/test_sfg.py +++ b/test/unit/test_sfg.py @@ -497,7 +497,7 @@ class TestFindComponentsWithTypeName: class TestGetPrecedenceList: def test_inputs_delays(self, precedence_sfg_delays): # No cached precedence list - assert precedence_sfg_delays._precedence_list is None + assert not precedence_sfg_delays._precedence_list precedence_list = precedence_sfg_delays.get_precedence_list()