Skip to content
Snippets Groups Projects
Commit 75c839bf authored by Simon Bjurek's avatar Simon Bjurek
Browse files

Add IO times for list schedulers

parent 82b6e78a
Branches Elias-branch
No related tags found
No related merge requests found
...@@ -124,7 +124,8 @@ class EarliestDeadlineScheduler(ListScheduler): ...@@ -124,7 +124,8 @@ class EarliestDeadlineScheduler(ListScheduler):
deadlines = {} deadlines = {}
for op_id, start_time in schedule_copy.start_times.items(): for op_id, start_time in schedule_copy.start_times.items():
deadlines[op_id] = start_time + schedule.sfg.find_by_id(op_id).latency if not op_id.startswith("in"):
deadlines[op_id] = start_time + schedule.sfg.find_by_id(op_id).latency
return sorted(deadlines, key=deadlines.get) return sorted(deadlines, key=deadlines.get)
...@@ -137,7 +138,10 @@ class LeastSlackTimeScheduler(ListScheduler): ...@@ -137,7 +138,10 @@ class LeastSlackTimeScheduler(ListScheduler):
schedule_copy = copy.copy(schedule) schedule_copy = copy.copy(schedule)
ALAPScheduler().apply_scheduling(schedule_copy) ALAPScheduler().apply_scheduling(schedule_copy)
return sorted(schedule_copy.start_times, key=schedule_copy.start_times.get) sorted_ops = sorted(
schedule_copy.start_times, key=schedule_copy.start_times.get
)
return [op for op in sorted_ops if not op.startswith("in")]
class MaxFanOutScheduler(ListScheduler): class MaxFanOutScheduler(ListScheduler):
...@@ -152,7 +156,8 @@ class MaxFanOutScheduler(ListScheduler): ...@@ -152,7 +156,8 @@ class MaxFanOutScheduler(ListScheduler):
for op_id, start_time in schedule_copy.start_times.items(): for op_id, start_time in schedule_copy.start_times.items():
fan_outs[op_id] = len(schedule.sfg.find_by_id(op_id).output_signals) fan_outs[op_id] = len(schedule.sfg.find_by_id(op_id).output_signals)
return sorted(fan_outs, key=fan_outs.get, reverse=True) sorted_ops = sorted(fan_outs, key=fan_outs.get, reverse=True)
return [op for op in sorted_ops if not op.startswith("in")]
class HybridScheduler(ListScheduler): class HybridScheduler(ListScheduler):
...@@ -199,4 +204,4 @@ class HybridScheduler(ListScheduler): ...@@ -199,4 +204,4 @@ class HybridScheduler(ListScheduler):
sorted_op_list = [pair[0] for pair in fan_out_sorted_items] sorted_op_list = [pair[0] for pair in fan_out_sorted_items]
return sorted_op_list return [op for op in sorted_op_list if not op.startswith("in")]
...@@ -907,7 +907,7 @@ class ProcessCollection: ...@@ -907,7 +907,7 @@ class ProcessCollection:
def split_on_ports( def split_on_ports(
self, self,
heuristic: str = "left_edge", heuristic: str = "graph_color",
read_ports: Optional[int] = None, read_ports: Optional[int] = None,
write_ports: Optional[int] = None, write_ports: Optional[int] = None,
total_ports: Optional[int] = None, total_ports: Optional[int] = None,
......
...@@ -119,9 +119,9 @@ class Schedule: ...@@ -119,9 +119,9 @@ class Schedule:
self._remove_delays_no_laps() self._remove_delays_no_laps()
max_end_time = self.get_max_end_time() max_end_time = self.get_max_end_time()
if schedule_time is None: if not self._schedule_time:
self._schedule_time = max_end_time self._schedule_time = max_end_time
elif schedule_time < max_end_time: elif self._schedule_time < max_end_time:
raise ValueError(f"Too short schedule time. Minimum is {max_end_time}.") raise ValueError(f"Too short schedule time. Minimum is {max_end_time}.")
def __str__(self) -> str: def __str__(self) -> str:
......
import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional, cast from typing import TYPE_CHECKING, Optional, cast
...@@ -9,6 +10,7 @@ from b_asic.types import TypeName ...@@ -9,6 +10,7 @@ from b_asic.types import TypeName
if TYPE_CHECKING: if TYPE_CHECKING:
from b_asic.operation import Operation from b_asic.operation import Operation
from b_asic.schedule import Schedule from b_asic.schedule import Schedule
from b_asic.signal_flow_graph import SFG
from b_asic.types import GraphID from b_asic.types import GraphID
...@@ -44,9 +46,21 @@ class Scheduler(ABC): ...@@ -44,9 +46,21 @@ class Scheduler(ABC):
] + cast(int, source_port.latency_offset) ] + cast(int, source_port.latency_offset)
# TODO: Rename max_concurrent_reads/writes to max_concurrent_read_ports or something to signify difference
class ListScheduler(Scheduler, ABC): class ListScheduler(Scheduler, ABC):
def __init__(self, max_resources: Optional[dict[TypeName, int]] = None) -> None: def __init__(
if max_resources: self,
max_resources: Optional[dict[TypeName, int]] = None,
max_concurrent_reads: Optional[int] = None,
max_concurrent_writes: Optional[int] = None,
input_times: Optional[dict["GraphID", int]] = None,
output_delta_times: Optional[dict["GraphID", int]] = None,
cyclic: Optional[bool] = False,
) -> None:
super()
if max_resources is not None:
if not isinstance(max_resources, dict): if not isinstance(max_resources, dict):
raise ValueError("max_resources must be a dictionary.") raise ValueError("max_resources must be a dictionary.")
for key, value in max_resources.items(): for key, value in max_resources.items():
...@@ -54,12 +68,20 @@ class ListScheduler(Scheduler, ABC): ...@@ -54,12 +68,20 @@ class ListScheduler(Scheduler, ABC):
raise ValueError("max_resources key must be a valid type_name.") raise ValueError("max_resources key must be a valid type_name.")
if not isinstance(value, int): if not isinstance(value, int):
raise ValueError("max_resources value must be an integer.") raise ValueError("max_resources value must be an integer.")
if max_resources:
self._max_resources = max_resources self._max_resources = max_resources
else: else:
self._max_resources = {} self._max_resources = {}
self._max_concurrent_reads = (
max_concurrent_reads if max_concurrent_reads else sys.maxsize
)
self._max_concurrent_writes = (
max_concurrent_writes if max_concurrent_writes else sys.maxsize
)
self._input_times = input_times if input_times else {}
self._output_delta_times = output_delta_times if output_delta_times else {}
def apply_scheduling(self, schedule: "Schedule") -> None: def apply_scheduling(self, schedule: "Schedule") -> None:
"""Applies the scheduling algorithm on the given Schedule. """Applies the scheduling algorithm on the given Schedule.
...@@ -69,40 +91,63 @@ class ListScheduler(Scheduler, ABC): ...@@ -69,40 +91,63 @@ class ListScheduler(Scheduler, ABC):
Schedule to apply the scheduling algorithm on. Schedule to apply the scheduling algorithm on.
""" """
sfg = schedule.sfg sfg = schedule.sfg
start_times = schedule.start_times
used_resources_ready_times = {} used_resources_ready_times = {}
remaining_resources = self._max_resources.copy() remaining_resources = self._max_resources.copy()
sorted_operations = self._get_sorted_operations(schedule) sorted_operations = self._get_sorted_operations(schedule)
# place all inputs at time 0 schedule.start_times = {}
remaining_reads = self._max_concurrent_reads
# initial input placement
if self._input_times:
for input_id in self._input_times:
schedule.start_times[input_id] = self._input_times[input_id]
for input_op in sfg.find_by_type_name(Input.type_name()): for input_op in sfg.find_by_type_name(Input.type_name()):
start_times[input_op.graph_id] = 0 if input_op.graph_id not in self._input_times:
schedule.start_times[input_op.graph_id] = 0
current_time = 0 current_time = 0
timeout_counter = 0
while sorted_operations: while sorted_operations:
# generate the best schedulable candidate # generate the best schedulable candidate
candidate = sfg.find_by_id(sorted_operations[0]) candidate = sfg.find_by_id(sorted_operations[0])
counter = 0 counter = 0
while not self._candidate_is_schedulable( while not self._candidate_is_schedulable(
start_times, schedule.start_times,
sfg,
candidate, candidate,
current_time, current_time,
remaining_resources, remaining_resources,
remaining_reads,
self._max_concurrent_writes,
sorted_operations, sorted_operations,
): ):
if counter == len(sorted_operations): if counter == len(sorted_operations):
counter = 0 counter = 0
current_time += 1 current_time += 1
timeout_counter += 1
if timeout_counter > 10:
msg = "Algorithm did not schedule any operation for 10 time steps, try relaxing constraints."
raise TimeoutError(msg)
remaining_reads = self._max_concurrent_reads
# update available operators # update available operators
for operation, ready_time in used_resources_ready_times.items(): for operation, ready_time in used_resources_ready_times.items():
if ready_time == current_time: if ready_time == current_time:
remaining_resources[operation.type_name()] += 1 remaining_resources[operation.type_name()] += 1
else: else:
candidate = sfg.find_by_id(sorted_operations[counter]) candidate = sfg.find_by_id(sorted_operations[counter])
counter += 1 counter += 1
timeout_counter = 0
# if the resource is constrained, update remaining resources # if the resource is constrained, update remaining resources
if candidate.type_name() in remaining_resources: if candidate.type_name() in remaining_resources:
remaining_resources[candidate.type_name()] -= 1 remaining_resources[candidate.type_name()] -= 1
...@@ -115,19 +160,24 @@ class ListScheduler(Scheduler, ABC): ...@@ -115,19 +160,24 @@ class ListScheduler(Scheduler, ABC):
current_time + candidate.latency current_time + candidate.latency
) )
remaining_reads -= candidate.input_count
# schedule the best candidate to the current time # schedule the best candidate to the current time
sorted_operations.remove(candidate.graph_id) sorted_operations.remove(candidate.graph_id)
start_times[candidate.graph_id] = current_time schedule.start_times[candidate.graph_id] = current_time
schedule.set_schedule_time(current_time)
self._handle_outputs(schedule) self._handle_outputs(schedule)
if not schedule.cyclic:
max_start_time = max(schedule.start_times.values())
if current_time < max_start_time:
current_time = max_start_time
current_time = max(current_time, schedule.get_max_end_time())
schedule.set_schedule_time(current_time)
schedule.remove_delays() schedule.remove_delays()
# move all inputs ALAP now that operations have moved self._handle_inputs(schedule)
for input_op in schedule.sfg.find_by_type_name(Input.type_name()):
input_op = cast(Input, input_op)
schedule.move_operation_alap(input_op.graph_id)
# move all dont cares ALAP # move all dont cares ALAP
for dc_op in schedule.sfg.find_by_type_name(DontCare.type_name()): for dc_op in schedule.sfg.find_by_type_name(DontCare.type_name()):
...@@ -137,9 +187,12 @@ class ListScheduler(Scheduler, ABC): ...@@ -137,9 +187,12 @@ class ListScheduler(Scheduler, ABC):
@staticmethod @staticmethod
def _candidate_is_schedulable( def _candidate_is_schedulable(
start_times: dict["GraphID"], start_times: dict["GraphID"],
sfg: "SFG",
operation: "Operation", operation: "Operation",
current_time: int, current_time: int,
remaining_resources: dict["GraphID", int], remaining_resources: dict["GraphID", int],
remaining_reads: int,
max_concurrent_writes: int,
remaining_ops: list["GraphID"], remaining_ops: list["GraphID"],
) -> bool: ) -> bool:
if ( if (
...@@ -148,23 +201,79 @@ class ListScheduler(Scheduler, ABC): ...@@ -148,23 +201,79 @@ class ListScheduler(Scheduler, ABC):
): ):
return False return False
op_finish_time = current_time + operation.latency
future_ops = [
sfg.find_by_id(item[0])
for item in start_times.items()
if item[1] + sfg.find_by_id(item[0]).latency == op_finish_time
]
future_ops_writes = sum([op.input_count for op in future_ops])
if (
not operation.graph_id.startswith("out")
and future_ops_writes >= max_concurrent_writes
):
return False
read_counter = 0
earliest_start_time = 0 earliest_start_time = 0
for op_input in operation.inputs: for op_input in operation.inputs:
source_op = op_input.signals[0].source.operation source_op = op_input.signals[0].source.operation
if isinstance(source_op, Delay):
continue
source_op_graph_id = source_op.graph_id source_op_graph_id = source_op.graph_id
if source_op_graph_id in remaining_ops: if source_op_graph_id in remaining_ops:
return False return False
if start_times[source_op_graph_id] != current_time - 1:
# not a direct connection -> memory read required
read_counter += 1
if read_counter > remaining_reads:
return False
proceeding_op_start_time = start_times.get(source_op_graph_id) proceeding_op_start_time = start_times.get(source_op_graph_id)
proceeding_op_finish_time = proceeding_op_start_time + source_op.latency
# if not proceeding_op_finish_time == current_time:
# # not direct connection -> memory required, check if okay
# satisfying_remaining_reads = remaining_reads >= operation.input_count
# satisfying_remaining_writes = remaining_writes >= operation.output_count
# if not (satisfying_remaining_reads and satisfying_remaining_writes):
# return False
if not isinstance(source_op, Delay): earliest_start_time = max(earliest_start_time, proceeding_op_finish_time)
earliest_start_time = max(
earliest_start_time, proceeding_op_start_time + source_op.latency
)
return earliest_start_time <= current_time return earliest_start_time <= current_time
@abstractmethod @abstractmethod
def _get_sorted_operations(schedule: "Schedule") -> list["GraphID"]: def _get_sorted_operations(schedule: "Schedule") -> list["GraphID"]:
raise NotImplementedError raise NotImplementedError
def _handle_inputs(self, schedule: "Schedule") -> None:
    """Move every input without a user-specified time as late as possible.

    Input operations whose graph ID is a key of ``self._input_times`` are
    left where they are; every other Input operation of the schedule's SFG
    is passed to ``schedule.move_operation_alap``.

    Parameters
    ----------
    schedule : Schedule
        Schedule whose input operations are adjusted in place.
    """
    pinned_inputs = self._input_times
    for op in schedule.sfg.find_by_type_name(Input.type_name()):
        graph_id = cast(Input, op).graph_id
        if graph_id in pinned_inputs:
            # The user fixed this input's time; do not move it.
            continue
        schedule.move_operation_alap(graph_id)
def _handle_outputs(
    self,
    schedule: "Schedule",
    non_schedulable_ops: Optional[list["GraphID"]] = None,
) -> None:
    """Place the outputs, then apply the user-specified output delta times.

    The parent implementation is run first, after which the schedule time
    is set to ``schedule.get_max_end_time()``. Every Output operation whose
    graph ID is a key of ``self._output_delta_times`` is then repositioned
    relative to that schedule time.

    Parameters
    ----------
    schedule : Schedule
        Schedule to modify in place.
    non_schedulable_ops : list of GraphID, optional
        Forwarded unchanged to the parent ``_handle_outputs``. When omitted
        (or ``None``) a new empty list is used for this call.
    """
    # Use a fresh list per call: a literal ``[]`` default would be one shared
    # object that any callee mutation would leak into later calls.
    if non_schedulable_ops is None:
        non_schedulable_ops = []
    super()._handle_outputs(schedule, non_schedulable_ops)

    schedule.set_schedule_time(schedule.get_max_end_time())

    for output in schedule.sfg.find_by_type_name(Output.type_name()):
        output = cast(Output, output)
        if output.graph_id in self._output_delta_times:
            delta_time = self._output_delta_times[output.graph_id]
            if schedule.cyclic:
                # Start at the schedule time and let move_operation apply
                # the offset.
                schedule.start_times[output.graph_id] = schedule.schedule_time
                schedule.move_operation(output.graph_id, delta_time)
            else:
                # Non-cyclic: write the offset start time directly.
                schedule.start_times[output.graph_id] = (
                    schedule.schedule_time + delta_time
                )
...@@ -415,7 +415,7 @@ def radix_2_dif_fft(points: int) -> SFG: ...@@ -415,7 +415,7 @@ def radix_2_dif_fft(points: int) -> SFG:
inputs = [] inputs = []
for i in range(points): for i in range(points):
inputs.append(Input(name=f"Input: {i}")) inputs.append(Input())
ports = inputs ports = inputs
number_of_stages = int(np.log2(points)) number_of_stages = int(np.log2(points))
...@@ -430,7 +430,7 @@ def radix_2_dif_fft(points: int) -> SFG: ...@@ -430,7 +430,7 @@ def radix_2_dif_fft(points: int) -> SFG:
ports = _get_bit_reversed_ports(ports) ports = _get_bit_reversed_ports(ports)
outputs = [] outputs = []
for i, port in enumerate(ports): for i, port in enumerate(ports):
outputs.append(Output(port, name=f"Output: {i}")) outputs.append(Output(port))
return SFG(inputs=inputs, outputs=outputs) return SFG(inputs=inputs, outputs=outputs)
......
"""
=========================================
Auto Scheduling With Custom IO times
=========================================
"""
from b_asic.core_operations import Butterfly, ConstantMultiplication
from b_asic.core_schedulers import ASAPScheduler, HybridScheduler
from b_asic.schedule import Schedule
from b_asic.sfg_generators import radix_2_dif_fft
# Build the example SFG with the radix-2 DIF FFT generator (8 points).
sfg = radix_2_dif_fft(points=8)

# %%
# The SFG is
sfg

# %%
# Set latencies and execution times: latency 3 for Butterfly and 2 for
# ConstantMultiplication, execution time 1 for both operation types.
sfg.set_latency_of_type(Butterfly.type_name(), 3)
sfg.set_latency_of_type(ConstantMultiplication.type_name(), 2)
sfg.set_execution_time_of_type(Butterfly.type_name(), 1)
sfg.set_execution_time_of_type(ConstantMultiplication.type_name(), 1)

# %%
# Generate an ASAP schedule for reference
schedule = Schedule(sfg, scheduler=ASAPScheduler())
schedule.show()
# %%
# Generate a non-cyclic Schedule from HybridScheduler with custom IO times.
# ``resources`` is passed as the scheduler's resource limit, here 1 per
# operation type. ``input_times`` maps each input graph ID to its time
# (in0..in7 at times 0..7). ``output_delta_times`` maps each output graph ID
# to an offset that the list scheduler adds to the schedule time.
resources = {Butterfly.type_name(): 1, ConstantMultiplication.type_name(): 1}
input_times = {
    "in0": 0,
    "in1": 1,
    "in2": 2,
    "in3": 3,
    "in4": 4,
    "in5": 5,
    "in6": 6,
    "in7": 7,
}
output_delta_times = {
    "out0": -2,
    "out1": -1,
    "out2": 0,
    "out3": 1,
    "out4": 2,
    "out5": 3,
    "out6": 4,
    "out7": 5,
}
schedule = Schedule(
    sfg,
    scheduler=HybridScheduler(
        resources,
        input_times=input_times,
        output_delta_times=output_delta_times,
    ),
)
schedule.show()
# %%
# Generate a new Schedule with cyclic scheduling enabled.
# The output delta times are now 0..7 for out0..out7, and ``cyclic=True`` is
# passed to the Schedule. ``resources`` and ``input_times`` are the values
# defined earlier in this script.
output_delta_times = {
    "out0": 0,
    "out1": 1,
    "out2": 2,
    "out3": 3,
    "out4": 4,
    "out5": 5,
    "out6": 6,
    "out7": 7,
}
schedule = Schedule(
    sfg,
    scheduler=HybridScheduler(
        resources,
        input_times=input_times,
        output_delta_times=output_delta_times,
    ),
    cyclic=True,
)
schedule.show()
...@@ -64,8 +64,31 @@ print("Scheduling time:", schedule.schedule_time) ...@@ -64,8 +64,31 @@ print("Scheduling time:", schedule.schedule_time)
schedule.show() schedule.show()
# %% # %%
# Create a HybridScheduler schedule that satisfies the resource constraints. # Create a HybridScheduler schedule that satisfies the resource constraints with custom IO times.
schedule = Schedule(sfg, scheduler=HybridScheduler(resources)) # This is the schedule we will synthesize an architecture for.
input_times = {
"in0": 0,
"in1": 1,
"in2": 2,
"in3": 3,
"in4": 4,
"in5": 5,
}
output_delta_times = {
"out0": 0,
"out1": 1,
"out2": 2,
"out3": 3,
"out4": 4,
"out5": 5,
}
schedule = Schedule(
sfg,
scheduler=HybridScheduler(
resources, input_times=input_times, output_delta_times=output_delta_times
),
cyclic=True,
)
print("Scheduling time:", schedule.schedule_time) print("Scheduling time:", schedule.schedule_time)
schedule.show() schedule.show()
...@@ -116,4 +139,3 @@ arch = Architecture( ...@@ -116,4 +139,3 @@ arch = Architecture(
# %% # %%
arch arch
# schedule.edit()
"""
=========================================
Memory Constrained Scheduling
=========================================
"""
from b_asic.architecture import Architecture, Memory, ProcessingElement
from b_asic.core_operations import Butterfly, ConstantMultiplication
from b_asic.core_schedulers import ASAPScheduler, HybridScheduler
from b_asic.schedule import Schedule
from b_asic.sfg_generators import radix_2_dif_fft
from b_asic.special_operations import Input, Output
# Build the example SFG with the radix-2 DIF FFT generator (16 points).
sfg = radix_2_dif_fft(points=16)

# %%
# The SFG is
sfg

# %%
# Set latencies and execution times: latency 3 for Butterfly and 2 for
# ConstantMultiplication, execution time 1 for both operation types.
sfg.set_latency_of_type(Butterfly.type_name(), 3)
sfg.set_latency_of_type(ConstantMultiplication.type_name(), 2)
sfg.set_execution_time_of_type(Butterfly.type_name(), 1)
sfg.set_execution_time_of_type(ConstantMultiplication.type_name(), 1)

# %%
# Generate an ASAP schedule for reference
schedule = Schedule(sfg, scheduler=ASAPScheduler())
schedule.show()

# %%
# Generate a PE constrained HybridSchedule.
# ``resources`` is passed as the scheduler's resource limit, here 1 per
# operation type.
resources = {Butterfly.type_name(): 1, ConstantMultiplication.type_name(): 1}
schedule = Schedule(sfg, scheduler=HybridScheduler(resources))
schedule.show()

# %%
# Group the scheduled operations by type, plot each group, and create one
# processing element per group. Then split the memory variables with
# ``split_on_length`` and ``split_on_ports``.
operations = schedule.get_operations()
bfs = operations.get_by_type_name(Butterfly.type_name())
bfs.show(title="Butterfly executions")
const_muls = operations.get_by_type_name(ConstantMultiplication.type_name())
const_muls.show(title="ConstMul executions")
inputs = operations.get_by_type_name(Input.type_name())
inputs.show(title="Input executions")
outputs = operations.get_by_type_name(Output.type_name())
outputs.show(title="Output executions")

bf_pe = ProcessingElement(bfs, entity_name="bf")
mul_pe = ProcessingElement(const_muls, entity_name="mul")
pe_in = ProcessingElement(inputs, entity_name='input')
pe_out = ProcessingElement(outputs, entity_name='output')

mem_vars = schedule.get_memory_variables()
mem_vars.show(title="All memory variables")
# ``direct`` is later used as the direct interconnects; the remaining
# variables are plotted as the non-zero time memory variables.
direct, mem_vars = mem_vars.split_on_length()
mem_vars.show(title="Non-zero time memory variables")
# Partition the remaining variables into sets for memories with one read
# port and one write port (two ports in total).
mem_vars_set = mem_vars.split_on_ports(read_ports=1, write_ports=1, total_ports=2)
# %%
# Create one RAM memory per memory-variable set, plot the set, assign its
# cells with the "left_edge" strategy and plot the assigned content.
memories = []
for i, mem in enumerate(mem_vars_set):
    memory = Memory(mem, memory_type="RAM", entity_name=f"memory{i}")
    memories.append(memory)
    mem.show(title=f"{memory.entity_name}")
    memory.assign("left_edge")
    memory.show_content(title=f"Assigned {memory.entity_name}")

# Plotted once, after all memories have been created.
direct.show(title="Direct interconnects")

# %%
# Build the architecture from the four processing elements, the memories and
# the direct interconnects.
arch = Architecture(
    {bf_pe, mul_pe, pe_in, pe_out},
    memories,
    direct_interconnects=direct,
)

arch
# %%
# Generate another HybridSchedule, but this time constrain the number of
# concurrent reads and writes (2 each) to reduce the number of memories.
resources = {Butterfly.type_name(): 1, ConstantMultiplication.type_name(): 1}
schedule = Schedule(
    sfg,
    scheduler=HybridScheduler(
        resources, max_concurrent_reads=2, max_concurrent_writes=2
    ),
)
schedule.show()

# %% Print the max number of read and write port accesses to non-direct memories
direct, mem_vars = schedule.get_memory_variables().split_on_length()
print("Max read ports:", mem_vars.read_ports_bound())
print("Max write ports:", mem_vars.write_ports_bound())

# %% Proceed to construct PEs and plot executions and non-direct memory variables
operations = schedule.get_operations()
bfs = operations.get_by_type_name(Butterfly.type_name())
bfs.show(title="Butterfly executions")
const_muls = operations.get_by_type_name(ConstantMultiplication.type_name())
const_muls.show(title="ConstMul executions")
inputs = operations.get_by_type_name(Input.type_name())
inputs.show(title="Input executions")
outputs = operations.get_by_type_name(Output.type_name())
outputs.show(title="Output executions")

bf_pe = ProcessingElement(bfs, entity_name="bf")
mul_pe = ProcessingElement(const_muls, entity_name="mul")
pe_in = ProcessingElement(inputs, entity_name='input')
pe_out = ProcessingElement(outputs, entity_name='output')

mem_vars.show(title="Non-zero time memory variables")
# Partition the memory variables into sets for memories with one read port
# and one write port (two ports in total).
mem_vars_set = mem_vars.split_on_ports(read_ports=1, write_ports=1, total_ports=2)
# %% Create the memories and assign their cells

# NOTE(review): this cell was previously titled "Allocate memories by
# graph-coloring", but the cell assignment below uses "left_edge". Confirm
# which strategy is intended here.
memories = []
for i, mem in enumerate(mem_vars_set):
    memory = Memory(mem, memory_type="RAM", entity_name=f"memory{i}")
    memories.append(memory)
    mem.show(title=f"{memory.entity_name}")
    memory.assign("left_edge")
    memory.show_content(title=f"Assigned {memory.entity_name}")

# Plotted once, after all memories have been created.
direct.show(title="Direct interconnects")

# %% Synthesize the new architecture, now only using two memories

# NOTE(review): the original title ended with the unfinished phrase
# "but with data rate"; complete the intended trade-off statement.
arch = Architecture(
    {bf_pe, mul_pe, pe_in, pe_out},
    memories,
    direct_interconnects=direct,
)
arch
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment