diff --git a/b_asic/scheduler.py b/b_asic/scheduler.py index 05fe476311c956dec56d84770db666dc5ab61ae5..88ce2c8465be62c5d4f9a1764768f76fe23bbf19 100644 --- a/b_asic/scheduler.py +++ b/b_asic/scheduler.py @@ -348,6 +348,11 @@ class ListScheduler(Scheduler): List-based scheduler that schedules the operations while complying to the given constraints. + .. admonition:: Important + + Will only work on non-recursive SFGs. + For recursive SFGs use RecursiveListScheduler instead. + Parameters ---------- sort_order : tuple[tuple[int, bool]] @@ -411,12 +416,6 @@ class ListScheduler(Scheduler): log.debug("Scheduler initializing") self._initialize_scheduler(schedule) - if self._sfg.loops and self._schedule.cyclic: - raise ValueError( - "ListScheduler does not support cyclic scheduling of " - "recursive algorithms. Use RecursiveListScheduler instead." - ) - if self._input_times: self._place_inputs_on_given_times() self._remaining_ops = [ @@ -504,7 +503,8 @@ class ListScheduler(Scheduler): for op_id in ready_ops: reads = 0 for op_input in self._sfg.find_by_id(op_id).inputs: - source_op = op_input.signals[0].source.operation + source_port = op_input.signals[0].source + source_op = source_port.operation if isinstance(source_op, DontCare): continue if isinstance(source_op, Delay): @@ -512,7 +512,8 @@ class ListScheduler(Scheduler): continue if ( self._schedule.start_times[source_op.graph_id] - != self._current_time - 1 + + source_port.latency_offset + != self._current_time + op_input.latency_offset ): reads += 1 op_reads[op_id] = reads @@ -584,14 +585,16 @@ class ListScheduler(Scheduler): if self._max_concurrent_reads: tmp_used_reads = {} for op_input in op.inputs: - source_op = op_input.signals[0].source.operation + source_port = op_input.signals[0].source + source_op = source_port.operation if isinstance(source_op, (Delay, DontCare)): continue + input_read_time = self._current_time + op_input.latency_offset if ( self._schedule.start_times[source_op.graph_id] - != self._current_time - 1 + + source_port.latency_offset + != input_read_time ): - input_read_time = self._current_time + op_input.latency_offset if self._schedule._schedule_time: input_read_time %= self._schedule._schedule_time @@ -821,14 +824,15 @@ class ListScheduler(Scheduler): def _update_port_reads(self, next_op: "Operation") -> None: for input_port in next_op.inputs: - source_op = input_port.signals[0].source.operation + source_port = input_port.signals[0].source + source_op = source_port.operation + time = self._current_time + input_port.latency_offset if ( - not isinstance(source_op, DontCare) - and not isinstance(source_op, Delay) + not isinstance(source_op, (DontCare, Delay)) and self._schedule.start_times[source_op.graph_id] - != self._current_time - 1 + + source_port.latency_offset + != time ): - time = self._current_time + input_port.latency_offset if self._schedule._schedule_time: time %= self._schedule._schedule_time diff --git a/test/unit/test_list_schedulers.py b/test/unit/test_list_schedulers.py index 1d4c1bd5ff177710afe66bd59eceefa9402a8466..e36a315282e7cabd76ef2cc78f7018ccca9f290c 100644 --- a/test/unit/test_list_schedulers.py +++ b/test/unit/test_list_schedulers.py @@ -4,6 +4,7 @@ import numpy as np import pytest from scipy import signal +from b_asic.architecture import Architecture, Memory, ProcessingElement from b_asic.core_operations import ( MADS, Addition, @@ -1531,6 +1532,8 @@ class TestHybridScheduler: } assert schedule.schedule_time == 6 + _validate_recreated_sfg_fft(schedule, points=4, delays=[0, 0, 1, 1]) + _, mem_vars = schedule.get_memory_variables().split_on_length() assert mem_vars.read_ports_bound() <= 2 assert mem_vars.write_ports_bound() <= 3 @@ -1784,37 +1787,167 @@ class TestListScheduler: ), ) - def test_cyclic_and_recursive_loops(self): - N = 3 - Wc = 0.2 - b, a = signal.butter(N, Wc, btype="lowpass", output="ba") - sfg = direct_form_1_iir(b, a) + def test_execution_time_not_one_port_constrained(self): + sfg = radix_2_dif_fft(points=16) - sfg.set_latency_of_type_name(ConstantMultiplication.type_name(), 2) - sfg.set_execution_time_of_type_name(ConstantMultiplication.type_name(), 1) - sfg.set_latency_of_type_name(Addition.type_name(), 3) - sfg.set_execution_time_of_type_name(Addition.type_name(), 1) + sfg.set_latency_of_type(Butterfly, 3) + sfg.set_latency_of_type(ConstantMultiplication, 10) + sfg.set_execution_time_of_type(Butterfly, 2) + sfg.set_execution_time_of_type(ConstantMultiplication, 10) - resources = { - Addition.type_name(): 1, - ConstantMultiplication.type_name(): 1, - Input.type_name(): 1, - Output.type_name(): 1, - } + resources = {Butterfly.type_name(): 1, ConstantMultiplication.type_name(): 1} - with pytest.raises( - ValueError, - match="ListScheduler does not support cyclic scheduling of recursive algorithms. Use RecursiveListScheduler instead.", - ): - Schedule( - sfg, - scheduler=ListScheduler( - sort_order=((1, True), (3, False), (4, False)), - max_resources=resources, - ), - cyclic=True, - schedule_time=sfg.iteration_period_bound(), - ) + schedule = Schedule( + sfg, + scheduler=ListScheduler( + sort_order=((2, True), (3, True)), + max_resources=resources, + max_concurrent_reads=2, + max_concurrent_writes=2, + ), + ) + + direct, mem_vars = schedule.get_memory_variables().split_on_length() + assert mem_vars.read_ports_bound() == 2 + assert mem_vars.write_ports_bound() == 2 + _validate_recreated_sfg_fft(schedule, points=16) + + schedule = Schedule( + sfg, + scheduler=ListScheduler( + sort_order=((1, True), (3, False)), + max_resources=resources, + max_concurrent_reads=2, + max_concurrent_writes=2, + ), + ) + + direct, mem_vars = schedule.get_memory_variables().split_on_length() + assert mem_vars.read_ports_bound() == 2 + assert mem_vars.write_ports_bound() == 2 + _validate_recreated_sfg_fft(schedule, points=16) + + operations = schedule.get_operations() + bfs = operations.get_by_type_name(Butterfly.type_name()) + const_muls = operations.get_by_type_name(ConstantMultiplication.type_name()) + inputs = operations.get_by_type_name(Input.type_name()) + outputs = operations.get_by_type_name(Output.type_name()) + + bf_pe = ProcessingElement(bfs, entity_name="bf1") + mul_pe = ProcessingElement(const_muls, entity_name="mul1") + + pe_in = ProcessingElement(inputs, entity_name="input") + pe_out = ProcessingElement(outputs, entity_name="output") + + processing_elements = [bf_pe, mul_pe, pe_in, pe_out] + + mem_vars = schedule.get_memory_variables() + direct, mem_vars = mem_vars.split_on_length() + + mem_vars_set = mem_vars.split_on_ports( + read_ports=1, + write_ports=1, + total_ports=2, + strategy="ilp_graph_color", + processing_elements=processing_elements, + max_colors=2, + ) + + memories = [] + for i, mem in enumerate(mem_vars_set): + memory = Memory(mem, memory_type="RAM", entity_name=f"memory{i}") + memories.append(memory) + memory.assign("graph_color") + + arch = Architecture( + processing_elements, + memories, + direct_interconnects=direct, + ) + assert len(arch.processing_elements) == 4 + assert len(arch.memories) == 2 + + def test_execution_time_not_one_and_latency_offsets_port_constrained(self): + sfg = radix_2_dif_fft(points=16) + + sfg.set_latency_offsets_of_type( + Butterfly, {"in0": 0, "in1": 1, "out0": 2, "out1": 3} + ) + sfg.set_latency_of_type(ConstantMultiplication, 7) + sfg.set_execution_time_of_type(Butterfly, 2) + sfg.set_execution_time_of_type(ConstantMultiplication, 5) + + resources = {Butterfly.type_name(): 1, ConstantMultiplication.type_name(): 1} + + schedule = Schedule( + sfg, + scheduler=ListScheduler( + sort_order=((2, True), (3, True)), + max_resources=resources, + max_concurrent_reads=2, + max_concurrent_writes=2, + ), + ) + + direct, mem_vars = schedule.get_memory_variables().split_on_length() + assert mem_vars.read_ports_bound() == 2 + assert mem_vars.write_ports_bound() == 2 + _validate_recreated_sfg_fft(schedule, points=16) + + schedule = Schedule( + sfg, + scheduler=ListScheduler( + sort_order=((1, True), (3, False)), + max_resources=resources, + max_concurrent_reads=2, + max_concurrent_writes=2, + ), + ) + + direct, mem_vars = schedule.get_memory_variables().split_on_length() + assert mem_vars.read_ports_bound() == 2 + assert mem_vars.write_ports_bound() == 2 + _validate_recreated_sfg_fft(schedule, points=16) + + operations = schedule.get_operations() + bfs = operations.get_by_type_name(Butterfly.type_name()) + const_muls = operations.get_by_type_name(ConstantMultiplication.type_name()) + inputs = operations.get_by_type_name(Input.type_name()) + outputs = operations.get_by_type_name(Output.type_name()) + + bf_pe = ProcessingElement(bfs, entity_name="bf1") + mul_pe = ProcessingElement(const_muls, entity_name="mul1") + + pe_in = ProcessingElement(inputs, entity_name="input") + pe_out = ProcessingElement(outputs, entity_name="output") + + processing_elements = [bf_pe, mul_pe, pe_in, pe_out] + + mem_vars = schedule.get_memory_variables() + direct, mem_vars = mem_vars.split_on_length() + + mem_vars_set = mem_vars.split_on_ports( + read_ports=1, + write_ports=1, + total_ports=2, + strategy="ilp_graph_color", + processing_elements=processing_elements, + max_colors=2, + ) + + memories = [] + for i, mem in enumerate(mem_vars_set): + memory = Memory(mem, memory_type="RAM", entity_name=f"memory{i}") + memories.append(memory) + memory.assign("graph_color") + + arch = Architecture( + processing_elements, + memories, + direct_interconnects=direct, + ) + assert len(arch.processing_elements) == 4 + assert len(arch.memories) == 2 class TestRecursiveListScheduler: @@ -1982,7 +2115,7 @@ def _validate_recreated_sfg_fft( # constant input -> impulse (with weight=points) output sim = Simulation(schedule.sfg, [Constant() for i in range(points)]) sim.run_for(128) - assert np.allclose(sim.results["0"], points) + assert np.allclose(sim.results["0"][delays[0] :], points) for i in range(1, points): assert np.all(np.isclose(sim.results[str(i)][delays[i] :], 0))