diff --git a/b_asic/architecture.py b/b_asic/architecture.py
index 81dbed5e4a2747a70e98cc31d7f9c14231704112..815b00873e2abf85164a6154c7d9bc4d6a0a4218 100644
--- a/b_asic/architecture.py
+++ b/b_asic/architecture.py
@@ -924,6 +924,53 @@ of :class:`~b_asic.architecture.ProcessingElement`
                 raise KeyError(f"{proc} not in {source.entity_name}")
         self._build_dicts()
 
+    def show(
+        self,
+        fmt: str | None = None,
+        branch_node: bool = True,
+        cluster: bool = True,
+        splines: str = "spline",
+        io_cluster: bool = True,
+        multiplexers: bool = True,
+        colored: bool = True,
+    ) -> None:
+        """
+        Display a visual representation of the Architecture using the default system viewer.
+
+        Parameters
+        ----------
+        fmt : str, optional
+            File format of the generated graph. Output formats can be found at
+            https://www.graphviz.org/doc/info/output.html.
+            The most common are "pdf", "eps", "png", and "svg". Default is None,
+            which results in a PDF.
+        branch_node : bool, default: True
+            Whether to create a branch node for outputs with a fan-out of two or higher.
+        cluster : bool, default: True
+            Whether to draw memories and PEs in separate clusters.
+        splines : {"spline", "line", "ortho", "polyline", "curved"}, default: "spline"
+            Spline style, see https://graphviz.org/docs/attrs/splines/ for more info.
+        io_cluster : bool, default: True
+            Whether Inputs and Outputs are drawn inside an IO cluster. Only relevant
+            if *cluster* is True.
+        multiplexers : bool, default: True
+            Whether input multiplexers are included.
+        colored : bool, default: True
+            Whether to color the nodes.
+        """
+
+        dg = self._digraph(
+            branch_node=branch_node,
+            cluster=cluster,
+            splines=splines,
+            io_cluster=io_cluster,
+            multiplexers=multiplexers,
+            colored=colored,
+        )
+        if fmt is not None:
+            dg.format = fmt
+        dg.view()
+
     def _digraph(
         self,
         branch_node: bool = True,
@@ -940,8 +987,8 @@ of :class:`~b_asic.architecture.ProcessingElement`
             Whether to create a branch node for outputs with fan-out of two or higher.
         cluster : bool, default: True
             Whether to draw memories and PEs in separate clusters.
-        splines : str, default: "spline"
-            The type of interconnect to use for graph drawing.
+        splines : {"spline", "line", "ortho", "polyline", "curved"}, default: "spline"
+            Spline style, see https://graphviz.org/docs/attrs/splines/ for more info.
         io_cluster : bool, default: True
             Whether Inputs and Outputs are drawn inside an IO cluster. Only relevant
             if *cluster* is True.
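A minimal usage sketch of the new `Architecture.show` wrapper (not part of the patch; `arch` is an assumed, already-constructed `Architecture` instance, built from PEs and memories as in the tests below):

```python
# Hedged sketch: exercise the new show() method on a hypothetical `arch`.
arch.show()                            # render to PDF and open the system viewer
arch.show(fmt="png", splines="ortho")  # PNG output with orthogonal edge routing
```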
diff --git a/b_asic/resources.py b/b_asic/resources.py
index acdcd1a01fbef4da96a1f162de3c3542058bb094..87f8659e13ff09effad54a84b6c99ff8f13986e3 100644
--- a/b_asic/resources.py
+++ b/b_asic/resources.py
@@ -1,11 +1,12 @@
 import io
 import itertools
 import re
+import sys
 from collections import Counter, defaultdict
 from collections.abc import Iterable
 from functools import reduce
 from math import floor, log2
-from typing import Literal, TypeVar
+from typing import TYPE_CHECKING, Literal, TypeVar
 
 import matplotlib.pyplot as plt
 import networkx as nx
@@ -23,6 +24,9 @@ from b_asic.process import (
 )
 from b_asic.types import TypeName
 
+if TYPE_CHECKING:
+    from b_asic.architecture import ProcessingElement
+
 # Default latency coloring RGB tuple
 _LATENCY_COLOR = tuple(c / 255 for c in LATENCY_COLOR)
 _WARNING_COLOR = tuple(c / 255 for c in WARNING_COLOR)
@@ -912,6 +916,7 @@ class ProcessCollection:
         read_ports: int | None = None,
         write_ports: int | None = None,
         total_ports: int | None = None,
+        processing_elements: list["ProcessingElement"] | None = None,
     ) -> list["ProcessCollection"]:
         """
         Split based on concurrent read and write accesses.
@@ -926,6 +931,8 @@ class ProcessCollection:
 
             * "graph_color"
            * "left_edge"
+            * "min_pe_to_mem"
+            * "min_mem_to_pe"
 
         read_ports : int, optional
             The number of read ports used when splitting process collection based on
@@ -939,6 +946,9 @@ class ProcessCollection:
             The total number of ports used when splitting process collection based on
             memory variable access.
 
+        processing_elements : list[ProcessingElement], optional
+            The PEs in use; required if *heuristic* is "min_pe_to_mem" or "min_mem_to_pe".
+
         Returns
         -------
         A set of new ProcessCollection objects with the process splitting.
@@ -949,16 +959,40 @@ class ProcessCollection:
         if heuristic == "graph_color":
             return self._split_ports_graph_color(read_ports, write_ports, total_ports)
         elif heuristic == "left_edge":
-            return self.split_ports_sequentially(
+            return self._split_ports_sequentially(
+                read_ports,
+                write_ports,
+                total_ports,
+                sequence=sorted(self),
+            )
+        elif heuristic == "min_pe_to_mem":
+            if processing_elements is None:
+                raise ValueError(
+                    "processing_elements must be provided if heuristic = 'min_pe_to_mem'"
+                )
+            return self._split_ports_minimize_pe_to_memory_connections(
                 read_ports,
                 write_ports,
                 total_ports,
                 sequence=sorted(self),
+                processing_elements=processing_elements,
+            )
+        elif heuristic == "min_mem_to_pe":
+            if processing_elements is None:
+                raise ValueError(
+                    "processing_elements must be provided if heuristic = 'min_mem_to_pe'"
+                )
+            return self._split_ports_minimize_memory_to_pe_connections(
+                read_ports,
+                write_ports,
+                total_ports,
+                sequence=sorted(self),
+                processing_elements=processing_elements,
             )
         else:
             raise ValueError("Invalid heuristic provided.")
 
-    def split_ports_sequentially(
+    def _split_ports_sequentially(
         self,
         read_ports: int,
         write_ports: int,
@@ -995,36 +1029,6 @@ class ProcessCollection:
 
         A set of new ProcessCollection objects with the process splitting.
         """
-
-        def ports_collide(proc: Process, collection: ProcessCollection):
-            """
-            Predicate test if insertion of a process `proc` results in colliding ports
-            when inserted to `collection` based on the `read_ports`, `write_ports`, and
-            `total_ports`.
-            """
-
-            # Test the number of concurrent write accesses
-            collection_writes = defaultdict(int, collection.write_port_accesses())
-            if collection_writes[proc.start_time] >= write_ports:
-                return True
-
-            # Test the number of concurrent read accesses
-            collection_reads = defaultdict(int, collection.read_port_accesses())
-            for proc_read_time in proc.read_times:
-                if collection_reads[proc_read_time % self.schedule_time] >= read_ports:
-                    return True
-
-            # Test the number of total accesses
-            collection_total_accesses = defaultdict(
-                int, Counter(collection_writes) + Counter(collection_reads)
-            )
-            for access_time in [proc.start_time, *proc.read_times]:
-                if collection_total_accesses[access_time] >= total_ports:
-                    return True
-
-            # No collision detected
-            return False
-
-        # Make sure that processes from `sequence` and and `self` are equal
         if set(self.collection) != set(sequence):
             raise KeyError("processes in `sequence` must be equal to processes in self")
@@ -1032,12 +1036,13 @@ class ProcessCollection:
         for process in sequence:
             process_added = False
             for collection in collections:
-                if not ports_collide(process, collection):
+                if not self._ports_collide(
+                    process, collection, write_ports, read_ports, total_ports
+                ):
                     collection.add_process(process)
                     process_added = True
                     break
             if not process_added:
-                # Stuff the process in a new collection
                 collections.append(
                     ProcessCollection(
                         [process],
@@ -1045,9 +1050,137 @@ class ProcessCollection:
                         cyclic=self._cyclic,
                     )
                 )
-        # Return the list of created ProcessCollections
         return collections
 
+    def _split_ports_minimize_pe_to_memory_connections(
+        self,
+        read_ports: int,
+        write_ports: int,
+        total_ports: int,
+        sequence: list[Process],
+        processing_elements: list["ProcessingElement"],
+    ) -> list["ProcessCollection"]:
+        if set(self.collection) != set(sequence):
+            raise KeyError("processes in `sequence` must be equal to processes in self")
+
+        # Use the sequential split only to estimate the number of memories needed
+        num_of_memories = len(
+            self._split_ports_sequentially(
+                read_ports, write_ports, total_ports, sequence
+            )
+        )
+        collections: list[ProcessCollection] = [
+            ProcessCollection(
+                [],
+                schedule_time=self.schedule_time,
+                cyclic=self._cyclic,
+            )
+            for _ in range(num_of_memories)
+        ]
+
+        for process in sequence:
+            process_fits_in_collection = self._get_process_fits_in_collection(
+                process, collections, read_ports, write_ports, total_ports
+            )
+            best_collection = None
+            best_delta = sys.maxsize
+
+            for i, collection in enumerate(collections):
+                if process_fits_in_collection[i]:
+                    # Greedy choice: prefer the collection whose number of
+                    # connected PEs grows the least when the process is added
+                    count_1 = ProcessCollection._count_number_of_pes_connected(
+                        processing_elements, collection
+                    )
+                    tmp_collection = [*collection.collection, process]
+                    count_2 = ProcessCollection._count_number_of_pes_connected(
+                        processing_elements, tmp_collection
+                    )
+                    delta = count_2 - count_1
+                    if delta < best_delta:
+                        best_collection = collection
+                        best_delta = delta
+
+                elif not any(process_fits_in_collection):
+                    # No existing collection can take the process: append a new
+                    # empty one. The enumerate loop also visits this appended
+                    # collection, so it is picked up as the best candidate.
+                    collections.append(
+                        ProcessCollection(
+                            [],
+                            schedule_time=self.schedule_time,
+                            cyclic=self._cyclic,
+                        )
+                    )
+                    process_fits_in_collection = self._get_process_fits_in_collection(
+                        process, collections, read_ports, write_ports, total_ports
+                    )
+
+            best_collection.add_process(process)
+
+        # Drop collections that remained empty (popping by index while
+        # iterating would skip elements or raise IndexError)
+        return [collection for collection in collections if collection.collection]
+
+    def _split_ports_minimize_memory_to_pe_connections(
+        self,
+        read_ports: int,
+        write_ports: int,
+        total_ports: int,
+        sequence: list[Process],
+        processing_elements: list["ProcessingElement"],
+    ) -> list["ProcessCollection"]:
+        raise NotImplementedError()
+
+    def _get_process_fits_in_collection(
+        self, process, collections, read_ports, write_ports, total_ports
+    ) -> list[bool]:
+        return [
+            not self._ports_collide(
+                process, collection, write_ports, read_ports, total_ports
+            )
+            for collection in collections
+        ]
+
+    def _ports_collide(
+        self,
+        proc: Process,
+        collection: "ProcessCollection",
+        write_ports: int,
+        read_ports: int,
+        total_ports: int,
+    ) -> bool:
+        # Test the number of concurrent write accesses
+        collection_writes = defaultdict(int, collection.write_port_accesses())
+        if collection_writes[proc.start_time] >= write_ports:
+            return True
+
+        # Test the number of concurrent read accesses
+        collection_reads = defaultdict(int, collection.read_port_accesses())
+        for proc_read_time in proc.read_times:
+            if collection_reads[proc_read_time % self.schedule_time] >= read_ports:
+                return True
+
+        # Test the number of total accesses
+        collection_total_accesses = defaultdict(
+            int, Counter(collection_writes) + Counter(collection_reads)
+        )
+        for access_time in [proc.start_time, *proc.read_times]:
+            if collection_total_accesses[access_time] >= total_ports:
+                return True
+        return False
+
+    @staticmethod
+    def _count_number_of_pes_connected(
+        processing_elements: list["ProcessingElement"],
+        collection: "ProcessCollection",
+    ) -> int:
+        collection_process_names = {proc.name.split(".")[0] for proc in collection}
+        count = 0
+        for pe in processing_elements:
+            if any(
+                proc.name.split(".")[0] in collection_process_names
+                for proc in pe.collection
+            ):
+                count += 1
+        return count
+
     def _split_ports_graph_color(
         self,
         read_ports: int,
@@ -1203,12 +1336,10 @@ class ProcessCollection:
         assignment: list[ProcessCollection] = []
         for next_process in sorted(self):
             if next_process.execution_time > self.schedule_time:
-                # Can not assign process to any cell
                 raise ValueError(
                     f"{next_process} has execution time greater than the schedule time"
                 )
             elif next_process.execution_time == self.schedule_time:
-                # Always assign maximum lifetime process to new cell
                 assignment.append(
                     ProcessCollection(
                         (next_process,),
@@ -1216,7 +1347,6 @@ class ProcessCollection:
                         cyclic=self._cyclic,
                     )
                 )
-                continue  # Continue assigning next process
             else:
                 next_process_stop_time = (
                     next_process.start_time + next_process.execution_time
diff --git a/b_asic/scheduler.py b/b_asic/scheduler.py
index 767c38cfe7e4f2fe638ff101de6fb8655a7cc516..300ff625b8f4b6720214c7ea1bdb81108ade4502 100644
--- a/b_asic/scheduler.py
+++ b/b_asic/scheduler.py
@@ -616,8 +616,8 @@ class ListScheduler(Scheduler):
 
     def _op_is_schedulable(self, op: "Operation") -> bool:
         return (
-            self._op_satisfies_resource_constraints(op)
-            and self._op_satisfies_data_dependencies(op)
+            self._op_satisfies_data_dependencies(op)
+            and self._op_satisfies_resource_constraints(op)
             and self._op_satisfies_concurrent_writes(op)
             and self._op_satisfies_concurrent_reads(op)
         )
@@ -988,13 +988,13 @@ class RecursiveListScheduler(ListScheduler):
 
         self._schedule._schedule_time = self._schedule.get_max_end_time()
 
-        if saved_sched_time:
+        if saved_sched_time and saved_sched_time > self._schedule._schedule_time:
             self._schedule._schedule_time = saved_sched_time
         self._logger.debug("--- Scheduling of recursive loops completed ---")
 
     def _get_next_recursive_op(
         self, priority_table: list[tuple["GraphID", int, ...]]
-    ) -> "GraphID":
+    ) -> "Operation":
         sorted_table = sorted(priority_table, key=lambda row: row[1])
         return self._sfg.find_by_id(sorted_table[0][0])
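The "min_pe_to_mem" split added in resources.py above greedily assigns each memory variable to the feasible memory whose set of connected PEs grows the least. A standalone sketch of that selection rule, with illustrative names only (`fits` stands in for the port-collision test, `count_pes` for `_count_number_of_pes_connected`):

```python
# Illustrative sketch of the greedy choice behind "min_pe_to_mem"; not library API.
def pick_collection(process, collections, fits, count_pes):
    """Return the feasible collection whose PE-connection count grows least."""
    best, best_delta = None, float("inf")
    for collection, ok in zip(collections, fits):
        if not ok:
            continue  # process would collide with this collection's port budget
        delta = count_pes([*collection, process]) - count_pes(list(collection))
        if delta < best_delta:
            best, best_delta = collection, delta
    return best
```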
diff --git a/test/integration/test_sfg_to_architecture.py b/test/integration/test_sfg_to_architecture.py
index be645aed50a22f0dc9500e1f9d7079933981665f..6c183ad27da259c22510773f690618e8119edde8 100644
--- a/test/integration/test_sfg_to_architecture.py
+++ b/test/integration/test_sfg_to_architecture.py
@@ -91,7 +91,7 @@ def test_pe_constrained_schedule():
     # assert arch.schedule_time == schedule.schedule_time
 
 
-def test_pe_and_memory_constrained_chedule():
+def test_pe_and_memory_constrained_schedule():
     sfg = radix_2_dif_fft(points=16)
 
     sfg.set_latency_of_type_name(Butterfly.type_name(), 3)
@@ -154,3 +154,116 @@ def test_pe_and_memory_constrained_schedule():
 
     assert arch.direct_interconnects == direct
     assert arch.schedule_time == schedule.schedule_time
+
+
+def test_different_resource_algorithms():
+    POINTS = 32
+    sfg = radix_2_dif_fft(POINTS)
+    sfg.set_latency_of_type(Butterfly, 1)
+    sfg.set_latency_of_type(ConstantMultiplication, 3)
+    sfg.set_execution_time_of_type(Butterfly, 1)
+    sfg.set_execution_time_of_type(ConstantMultiplication, 1)
+
+    resources = {
+        Butterfly.type_name(): 2,
+        ConstantMultiplication.type_name(): 2,
+        Input.type_name(): 1,
+        Output.type_name(): 1,
+    }
+    schedule_1 = Schedule(
+        sfg,
+        scheduler=HybridScheduler(
+            resources, max_concurrent_reads=4, max_concurrent_writes=4
+        ),
+    )
+
+    operations = schedule_1.get_operations()
+    bfs = operations.get_by_type_name(Butterfly.type_name())
+    bfs = bfs.split_on_execution_time()
+    const_muls = operations.get_by_type_name(ConstantMultiplication.type_name())
+    const_muls = const_muls.split_on_execution_time()
+    inputs = operations.get_by_type_name(Input.type_name())
+    outputs = operations.get_by_type_name(Output.type_name())
+
+    bf_pe_1 = ProcessingElement(bfs[0], entity_name="bf1")
+    bf_pe_2 = ProcessingElement(bfs[1], entity_name="bf2")
+
+    mul_pe_1 = ProcessingElement(const_muls[0], entity_name="mul1")
+    mul_pe_2 = ProcessingElement(const_muls[1], entity_name="mul2")
+
+    pe_in = ProcessingElement(inputs, entity_name="input")
+    pe_out = ProcessingElement(outputs, entity_name="output")
+
+    processing_elements = [bf_pe_1, bf_pe_2, mul_pe_1, mul_pe_2, pe_in, pe_out]
+
+    mem_vars = schedule_1.get_memory_variables()
+    direct, mem_vars = mem_vars.split_on_length()
+
+    # LEFT-EDGE
+    mem_vars_set = mem_vars.split_on_ports(
+        read_ports=1,
+        write_ports=1,
+        total_ports=2,
+        heuristic="left_edge",
+        processing_elements=processing_elements,
+    )
+
+    memories = []
+    for i, mem in enumerate(mem_vars_set):
+        memory = Memory(mem, memory_type="RAM", entity_name=f"memory{i}")
+        memories.append(memory)
+        memory.assign("graph_color")
+
+    arch = Architecture(
+        processing_elements,
+        memories,
+        direct_interconnects=direct,
+    )
+    assert len(arch.processing_elements) == 6
+    assert len(arch.memories) == 6
+
+    # MIN-PE-TO-MEM
+    mem_vars_set = mem_vars.split_on_ports(
+        read_ports=1,
+        write_ports=1,
+        total_ports=2,
+        heuristic="min_pe_to_mem",
+        processing_elements=processing_elements,
+    )
+
+    memories = []
+    for i, mem in enumerate(mem_vars_set):
+        memory = Memory(mem, memory_type="RAM", entity_name=f"memory{i}")
+        memories.append(memory)
+        memory.assign("graph_color")
+
+    arch = Architecture(
+        processing_elements,
+        memories,
+        direct_interconnects=direct,
+    )
+    assert len(arch.processing_elements) == 6
+    assert len(arch.memories) == 6
+
+    # GRAPH COLORING
+    mem_vars_set = mem_vars.split_on_ports(
+        read_ports=1,
+        write_ports=1,
+        total_ports=2,
+        heuristic="graph_color",
+        processing_elements=processing_elements,
+    )
+
+    memories = []
+    for i, mem in enumerate(mem_vars_set):
memory_type="RAM", entity_name=f"memory{i}") + memories.append(memory) + memory.assign("graph_color") + + arch = Architecture( + processing_elements, + memories, + direct_interconnects=direct, + ) + assert len(arch.processing_elements) == 6 + assert len(arch.memories) == 4 diff --git a/test/unit/test_resources.py b/test/unit/test_resources.py index 43d84d84da3efec0a680a3fec3c484a6d187f201..d12f272eb7b89c75e381d1478c76be76807991be 100644 --- a/test/unit/test_resources.py +++ b/test/unit/test_resources.py @@ -49,7 +49,7 @@ class TestProcessCollectionPlainMemoryVariable: def test_split_sequence_raises(self, simple_collection: ProcessCollection): with pytest.raises(KeyError, match="processes in `sequence` must be"): - simple_collection.split_ports_sequentially( + simple_collection._split_ports_sequentially( read_ports=1, write_ports=1, total_ports=2, sequence=[] ) @@ -71,6 +71,22 @@ class TestProcessCollectionPlainMemoryVariable: ) assert len(split) == 2 + def test_split_memory_variable_raises(self, simple_collection: ProcessCollection): + with pytest.raises( + ValueError, + match="processing_elements must be provided if heuristic = 'min_pe_to_mem'", + ): + simple_collection.split_on_ports(heuristic="min_pe_to_mem", total_ports=1) + + with pytest.raises( + ValueError, + match="processing_elements must be provided if heuristic = 'min_mem_to_pe'", + ): + simple_collection.split_on_ports(heuristic="min_mem_to_pe", total_ports=1) + + with pytest.raises(ValueError, match="Invalid heuristic provided."): + simple_collection.split_on_ports(heuristic="foo", total_ports=1) + @matplotlib.testing.decorators.image_comparison( ["test_left_edge_cell_assignment.png"] )