Skip to content
Snippets Groups Projects
Commit 977d3f5a authored by Simon Bjurek's avatar Simon Bjurek
Browse files

started on port constrained scheduling, first version functional

parent 8b1d1222
Branches
No related tags found
No related merge requests found
Pipeline #156536 failed
...@@ -907,7 +907,7 @@ class ProcessCollection: ...@@ -907,7 +907,7 @@ class ProcessCollection:
def split_on_ports( def split_on_ports(
self, self,
heuristic: str = "left_edge", heuristic: str = "graph_color",
read_ports: Optional[int] = None, read_ports: Optional[int] = None,
write_ports: Optional[int] = None, write_ports: Optional[int] = None,
total_ports: Optional[int] = None, total_ports: Optional[int] = None,
......
import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional, cast from typing import TYPE_CHECKING, Optional, cast
...@@ -9,6 +10,7 @@ from b_asic.types import TypeName ...@@ -9,6 +10,7 @@ from b_asic.types import TypeName
if TYPE_CHECKING: if TYPE_CHECKING:
from b_asic.operation import Operation from b_asic.operation import Operation
from b_asic.schedule import Schedule from b_asic.schedule import Schedule
from b_asic.signal_flow_graph import SFG
from b_asic.types import GraphID from b_asic.types import GraphID
...@@ -44,10 +46,15 @@ class Scheduler(ABC): ...@@ -44,10 +46,15 @@ class Scheduler(ABC):
] + cast(int, source_port.latency_offset) ] + cast(int, source_port.latency_offset)
# TODO: Rename max_concurrent_reads/writes to max_concurrent_read_ports or something to signify difference
class ListScheduler(Scheduler, ABC): class ListScheduler(Scheduler, ABC):
def __init__( def __init__(
self, self,
max_resources: Optional[dict[TypeName, int]] = None, max_resources: Optional[dict[TypeName, int]] = None,
max_concurrent_reads: Optional[int] = None,
max_concurrent_writes: Optional[int] = None,
input_times: Optional[dict["GraphID", int]] = None, input_times: Optional[dict["GraphID", int]] = None,
output_delta_times: Optional[dict["GraphID", int]] = None, output_delta_times: Optional[dict["GraphID", int]] = None,
cyclic: Optional[bool] = False, cyclic: Optional[bool] = False,
...@@ -65,6 +72,13 @@ class ListScheduler(Scheduler, ABC): ...@@ -65,6 +72,13 @@ class ListScheduler(Scheduler, ABC):
else: else:
self._max_resources = {} self._max_resources = {}
self._max_concurrent_reads = (
max_concurrent_reads if max_concurrent_reads else sys.maxsize
)
self._max_concurrent_writes = (
max_concurrent_writes if max_concurrent_writes else sys.maxsize
)
self._input_times = input_times if input_times else {} self._input_times = input_times if input_times else {}
self._output_delta_times = output_delta_times if output_delta_times else {} self._output_delta_times = output_delta_times if output_delta_times else {}
...@@ -77,45 +91,63 @@ class ListScheduler(Scheduler, ABC): ...@@ -77,45 +91,63 @@ class ListScheduler(Scheduler, ABC):
Schedule to apply the scheduling algorithm on. Schedule to apply the scheduling algorithm on.
""" """
sfg = schedule.sfg sfg = schedule.sfg
start_times = schedule.start_times
used_resources_ready_times = {} used_resources_ready_times = {}
remaining_resources = self._max_resources.copy() remaining_resources = self._max_resources.copy()
sorted_operations = self._get_sorted_operations(schedule) sorted_operations = self._get_sorted_operations(schedule)
schedule.start_times = {}
remaining_reads = self._max_concurrent_reads
# initial input placement # initial input placement
if self._input_times: if self._input_times:
for input_id in self._input_times: for input_id in self._input_times:
start_times[input_id] = self._input_times[input_id] schedule.start_times[input_id] = self._input_times[input_id]
for input_op in sfg.find_by_type_name(Input.type_name()): for input_op in sfg.find_by_type_name(Input.type_name()):
if input_op.graph_id not in self._input_times: if input_op.graph_id not in self._input_times:
start_times[input_op.graph_id] = 0 schedule.start_times[input_op.graph_id] = 0
current_time = 0 current_time = 0
timeout_counter = 0
while sorted_operations: while sorted_operations:
# generate the best schedulable candidate # generate the best schedulable candidate
candidate = sfg.find_by_id(sorted_operations[0]) candidate = sfg.find_by_id(sorted_operations[0])
counter = 0 counter = 0
while not self._candidate_is_schedulable( while not self._candidate_is_schedulable(
start_times, schedule.start_times,
sfg,
candidate, candidate,
current_time, current_time,
remaining_resources, remaining_resources,
remaining_reads,
self._max_concurrent_writes,
sorted_operations, sorted_operations,
): ):
if counter == len(sorted_operations): if counter == len(sorted_operations):
counter = 0 counter = 0
current_time += 1 current_time += 1
timeout_counter += 1
if timeout_counter > 10:
msg = "Algorithm did not schedule any operation for 10 time steps, try relaxing constraints."
raise TimeoutError(msg)
remaining_reads = self._max_concurrent_reads
# update available operators # update available operators
for operation, ready_time in used_resources_ready_times.items(): for operation, ready_time in used_resources_ready_times.items():
if ready_time == current_time: if ready_time == current_time:
remaining_resources[operation.type_name()] += 1 remaining_resources[operation.type_name()] += 1
else: else:
candidate = sfg.find_by_id(sorted_operations[counter]) candidate = sfg.find_by_id(sorted_operations[counter])
counter += 1 counter += 1
timeout_counter = 0
# if the resource is constrained, update remaining resources # if the resource is constrained, update remaining resources
if candidate.type_name() in remaining_resources: if candidate.type_name() in remaining_resources:
remaining_resources[candidate.type_name()] -= 1 remaining_resources[candidate.type_name()] -= 1
...@@ -128,9 +160,11 @@ class ListScheduler(Scheduler, ABC): ...@@ -128,9 +160,11 @@ class ListScheduler(Scheduler, ABC):
current_time + candidate.latency current_time + candidate.latency
) )
remaining_reads -= candidate.input_count
# schedule the best candidate to the current time # schedule the best candidate to the current time
sorted_operations.remove(candidate.graph_id) sorted_operations.remove(candidate.graph_id)
start_times[candidate.graph_id] = current_time schedule.start_times[candidate.graph_id] = current_time
self._handle_outputs(schedule) self._handle_outputs(schedule)
...@@ -138,6 +172,7 @@ class ListScheduler(Scheduler, ABC): ...@@ -138,6 +172,7 @@ class ListScheduler(Scheduler, ABC):
max_start_time = max(schedule.start_times.values()) max_start_time = max(schedule.start_times.values())
if current_time < max_start_time: if current_time < max_start_time:
current_time = max_start_time current_time = max_start_time
current_time = max(current_time, schedule.get_max_end_time())
schedule.set_schedule_time(current_time) schedule.set_schedule_time(current_time)
schedule.remove_delays() schedule.remove_delays()
...@@ -152,9 +187,12 @@ class ListScheduler(Scheduler, ABC): ...@@ -152,9 +187,12 @@ class ListScheduler(Scheduler, ABC):
@staticmethod @staticmethod
def _candidate_is_schedulable( def _candidate_is_schedulable(
start_times: dict["GraphID"], start_times: dict["GraphID"],
sfg: "SFG",
operation: "Operation", operation: "Operation",
current_time: int, current_time: int,
remaining_resources: dict["GraphID", int], remaining_resources: dict["GraphID", int],
remaining_reads: int,
max_concurrent_writes: int,
remaining_ops: list["GraphID"], remaining_ops: list["GraphID"],
) -> bool: ) -> bool:
if ( if (
...@@ -163,20 +201,51 @@ class ListScheduler(Scheduler, ABC): ...@@ -163,20 +201,51 @@ class ListScheduler(Scheduler, ABC):
): ):
return False return False
op_finish_time = current_time + operation.latency
future_ops = [
sfg.find_by_id(item[0])
for item in start_times.items()
if item[1] + sfg.find_by_id(item[0]).latency == op_finish_time
]
future_ops_writes = sum([op.input_count for op in future_ops])
if (
not operation.graph_id.startswith("out")
and future_ops_writes >= max_concurrent_writes
):
return False
read_counter = 0
earliest_start_time = 0 earliest_start_time = 0
for op_input in operation.inputs: for op_input in operation.inputs:
source_op = op_input.signals[0].source.operation source_op = op_input.signals[0].source.operation
if isinstance(source_op, Delay):
continue
source_op_graph_id = source_op.graph_id source_op_graph_id = source_op.graph_id
if source_op_graph_id in remaining_ops: if source_op_graph_id in remaining_ops:
return False return False
if start_times[source_op_graph_id] != current_time - 1:
# not a direct connection -> memory read required
read_counter += 1
if read_counter > remaining_reads:
return False
proceeding_op_start_time = start_times.get(source_op_graph_id) proceeding_op_start_time = start_times.get(source_op_graph_id)
proceeding_op_finish_time = proceeding_op_start_time + source_op.latency
# if not proceeding_op_finish_time == current_time:
# # not direct connection -> memory required, check if okay
# satisfying_remaining_reads = remaining_reads >= operation.input_count
# satisfying_remaining_writes = remaining_writes >= operation.output_count
# if not (satisfying_remaining_reads and satisfying_remaining_writes):
# return False
if not isinstance(source_op, Delay): earliest_start_time = max(earliest_start_time, proceeding_op_finish_time)
earliest_start_time = max(
earliest_start_time, proceeding_op_start_time + source_op.latency
)
return earliest_start_time <= current_time return earliest_start_time <= current_time
......
...@@ -52,7 +52,12 @@ output_delta_times = { ...@@ -52,7 +52,12 @@ output_delta_times = {
"out7": 5, "out7": 5,
} }
schedule = Schedule( schedule = Schedule(
sfg, scheduler=HybridScheduler(resources, input_times, output_delta_times) sfg,
scheduler=HybridScheduler(
resources,
input_times=input_times,
output_delta_times=output_delta_times,
),
) )
schedule.show() schedule.show()
...@@ -70,7 +75,11 @@ output_delta_times = { ...@@ -70,7 +75,11 @@ output_delta_times = {
} }
schedule = Schedule( schedule = Schedule(
sfg, sfg,
scheduler=HybridScheduler(resources, input_times, output_delta_times), scheduler=HybridScheduler(
resources,
input_times=input_times,
output_delta_times=output_delta_times,
),
cyclic=True, cyclic=True,
) )
schedule.show() schedule.show()
...@@ -84,7 +84,9 @@ output_delta_times = { ...@@ -84,7 +84,9 @@ output_delta_times = {
} }
schedule = Schedule( schedule = Schedule(
sfg, sfg,
scheduler=HybridScheduler(resources, input_times, output_delta_times), scheduler=HybridScheduler(
resources, input_times=input_times, output_delta_times=output_delta_times
),
cyclic=True, cyclic=True,
) )
print("Scheduling time:", schedule.schedule_time) print("Scheduling time:", schedule.schedule_time)
...@@ -137,4 +139,3 @@ arch = Architecture( ...@@ -137,4 +139,3 @@ arch = Architecture(
# %% # %%
arch arch
# schedule.edit()
"""
=========================================
Memory Constrained Scheduling
=========================================
"""
from b_asic.architecture import Architecture, Memory, ProcessingElement
from b_asic.core_operations import Butterfly, ConstantMultiplication
from b_asic.core_schedulers import ASAPScheduler, HybridScheduler
from b_asic.schedule import Schedule
from b_asic.sfg_generators import radix_2_dif_fft
from b_asic.special_operations import Input, Output
sfg = radix_2_dif_fft(points=16)
# %%
# The SFG is
sfg
# %%
# Set latencies and execution times.
sfg.set_latency_of_type(Butterfly.type_name(), 3)
sfg.set_latency_of_type(ConstantMultiplication.type_name(), 2)
sfg.set_execution_time_of_type(Butterfly.type_name(), 1)
sfg.set_execution_time_of_type(ConstantMultiplication.type_name(), 1)
# # %%
# Generate an ASAP schedule for reference
schedule = Schedule(sfg, scheduler=ASAPScheduler())
schedule.show()
# %%
# Generate a PE constrained HybridSchedule
resources = {Butterfly.type_name(): 1, ConstantMultiplication.type_name(): 1}
schedule = Schedule(sfg, scheduler=HybridScheduler(resources))
schedule.show()
# %%
operations = schedule.get_operations()
bfs = operations.get_by_type_name(Butterfly.type_name())
bfs.show(title="Butterfly executions")
const_muls = operations.get_by_type_name(ConstantMultiplication.type_name())
const_muls.show(title="ConstMul executions")
inputs = operations.get_by_type_name(Input.type_name())
inputs.show(title="Input executions")
outputs = operations.get_by_type_name(Output.type_name())
outputs.show(title="Output executions")
bf_pe = ProcessingElement(bfs, entity_name="bf")
mul_pe = ProcessingElement(const_muls, entity_name="mul")
pe_in = ProcessingElement(inputs, entity_name='input')
pe_out = ProcessingElement(outputs, entity_name='output')
mem_vars = schedule.get_memory_variables()
mem_vars.show(title="All memory variables")
direct, mem_vars = mem_vars.split_on_length()
mem_vars.show(title="Non-zero time memory variables")
mem_vars_set = mem_vars.split_on_ports(read_ports=1, write_ports=1, total_ports=2)
# %%
memories = []
for i, mem in enumerate(mem_vars_set):
memory = Memory(mem, memory_type="RAM", entity_name=f"memory{i}")
memories.append(memory)
mem.show(title=f"{memory.entity_name}")
memory.assign("left_edge")
memory.show_content(title=f"Assigned {memory.entity_name}")
direct.show(title="Direct interconnects")
# %%
arch = Architecture(
{bf_pe, mul_pe, pe_in, pe_out},
memories,
direct_interconnects=direct,
)
arch
# %%
# Generate another HybridSchedule but this time constrain the amount of reads and writes to reduce the amount of memories
resources = {Butterfly.type_name(): 1, ConstantMultiplication.type_name(): 1}
schedule = Schedule(
sfg,
scheduler=HybridScheduler(
resources, max_concurrent_reads=2, max_concurrent_writes=2
),
)
schedule.show()
# %% Print the max number of read and write port accesses to non-direct memories
direct, mem_vars = schedule.get_memory_variables().split_on_length()
print("Max read ports:", mem_vars.read_ports_bound())
print("Max write ports:", mem_vars.write_ports_bound())
# %% Proceed to construct PEs and plot executions and non-direct memory variables
operations = schedule.get_operations()
bfs = operations.get_by_type_name(Butterfly.type_name())
bfs.show(title="Butterfly executions")
const_muls = operations.get_by_type_name(ConstantMultiplication.type_name())
const_muls.show(title="ConstMul executions")
inputs = operations.get_by_type_name(Input.type_name())
inputs.show(title="Input executions")
outputs = operations.get_by_type_name(Output.type_name())
outputs.show(title="Output executions")
bf_pe = ProcessingElement(bfs, entity_name="bf")
mul_pe = ProcessingElement(const_muls, entity_name="mul")
pe_in = ProcessingElement(inputs, entity_name='input')
pe_out = ProcessingElement(outputs, entity_name='output')
mem_vars.show(title="Non-zero time memory variables")
mem_vars_set = mem_vars.split_on_ports(read_ports=1, write_ports=1, total_ports=2)
# %% Allocate memories by graph-coloring
memories = []
for i, mem in enumerate(mem_vars_set):
memory = Memory(mem, memory_type="RAM", entity_name=f"memory{i}")
memories.append(memory)
mem.show(title=f"{memory.entity_name}")
memory.assign("left_edge")
memory.show_content(title=f"Assigned {memory.entity_name}")
direct.show(title="Direct interconnects")
# %% Synthesize the new architecture, now only using two memories but with data rate
arch = Architecture(
{bf_pe, mul_pe, pe_in, pe_out},
memories,
direct_interconnects=direct,
)
arch
...@@ -861,7 +861,9 @@ class TestHybridScheduler: ...@@ -861,7 +861,9 @@ class TestHybridScheduler:
} }
schedule = Schedule( schedule = Schedule(
sfg, sfg,
scheduler=HybridScheduler(resources, input_times, output_times), scheduler=HybridScheduler(
resources, input_times=input_times, output_delta_times=output_times
),
cyclic=True, cyclic=True,
) )
...@@ -933,7 +935,9 @@ class TestHybridScheduler: ...@@ -933,7 +935,9 @@ class TestHybridScheduler:
} }
schedule = Schedule( schedule = Schedule(
sfg, sfg,
scheduler=HybridScheduler(resources, input_times, output_times), scheduler=HybridScheduler(
resources, input_times=input_times, output_delta_times=output_times
),
cyclic=False, cyclic=False,
) )
...@@ -1027,7 +1031,9 @@ class TestHybridScheduler: ...@@ -1027,7 +1031,9 @@ class TestHybridScheduler:
} }
schedule = Schedule( schedule = Schedule(
sfg, sfg,
scheduler=HybridScheduler(resources, input_times, output_times), scheduler=HybridScheduler(
resources, input_times=input_times, output_delta_times=output_times
),
cyclic=True, cyclic=True,
) )
...@@ -1078,3 +1084,41 @@ class TestHybridScheduler: ...@@ -1078,3 +1084,41 @@ class TestHybridScheduler:
resources = {MADS.type_name(): "test"} resources = {MADS.type_name(): "test"}
with pytest.raises(ValueError, match="max_resources value must be an integer."): with pytest.raises(ValueError, match="max_resources value must be an integer."):
Schedule(sfg, scheduler=HybridScheduler(resources)) Schedule(sfg, scheduler=HybridScheduler(resources))
# def test_ldlt_inverse_2x2_read_constrained(self):
# sfg = ldlt_matrix_inverse(N=2)
# sfg.set_latency_of_type(MADS.type_name(), 3)
# sfg.set_latency_of_type(Reciprocal.type_name(), 2)
# sfg.set_execution_time_of_type(MADS.type_name(), 1)
# sfg.set_execution_time_of_type(Reciprocal.type_name(), 1)
# resources = {MADS.type_name(): 1, Reciprocal.type_name(): 1}
# schedule = Schedule(
# sfg,
# scheduler=HybridScheduler(
# max_resources = resources,
# max_concurrent_reads = 3,
# ),
# )
def test_ldlt_inverse_2x2_read_constrained_too_low(self):
sfg = ldlt_matrix_inverse(N=2)
sfg.set_latency_of_type(MADS.type_name(), 3)
sfg.set_latency_of_type(Reciprocal.type_name(), 2)
sfg.set_execution_time_of_type(MADS.type_name(), 1)
sfg.set_execution_time_of_type(Reciprocal.type_name(), 1)
resources = {MADS.type_name(): 1, Reciprocal.type_name(): 1}
with pytest.raises(
TimeoutError,
match="Algorithm did not schedule any operation for 10 time steps, try relaxing constraints.",
):
Schedule(
sfg,
scheduler=HybridScheduler(
max_resources=resources,
max_concurrent_reads=2,
),
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment