diff --git a/b_asic/architecture.py b/b_asic/architecture.py index 815b00873e2abf85164a6154c7d9bc4d6a0a4218..649658937e115778ed880a7d362ab07fbe2a2e85 100644 --- a/b_asic/architecture.py +++ b/b_asic/architecture.py @@ -929,7 +929,7 @@ of :class:`~b_asic.architecture.ProcessingElement` fmt: str | None = None, branch_node: bool = True, cluster: bool = True, - splines: str = "spline", + splines: Literal["spline", "line", "ortho", "polyline", "curved"] = "spline", io_cluster: bool = True, multiplexers: bool = True, colored: bool = True, diff --git a/b_asic/resources.py b/b_asic/resources.py index 87f8659e13ff09effad54a84b6c99ff8f13986e3..85d555fac220edbdf33c3c5f3a21321c2b7dfd28 100644 --- a/b_asic/resources.py +++ b/b_asic/resources.py @@ -946,7 +946,7 @@ class ProcessCollection: The total number of ports used when splitting process collection based on memory variable access. - processing_elements : list[ProcessingElements], optional + processing_elements : list of ProcessingElement, optional The currently used PEs, only required if heuristic = "min_mem_to_pe". Returns @@ -959,7 +959,7 @@ class ProcessCollection: if heuristic == "graph_color": return self._split_ports_graph_color(read_ports, write_ports, total_ports) elif heuristic == "left_edge": - return self._split_ports_sequentially( + return self.split_ports_sequentially( read_ports, write_ports, total_ports, @@ -992,7 +992,7 @@ class ProcessCollection: else: raise ValueError("Invalid heuristic provided.") - def _split_ports_sequentially( + def split_ports_sequentially( self, read_ports: int, write_ports: int, @@ -1065,7 +1065,7 @@ class ProcessCollection: raise KeyError("processes in `sequence` must be equal to processes in self") num_of_memories = len( - self._split_ports_sequentially( + self.split_ports_sequentially( read_ports, write_ports, total_ports, sequence ) ) @@ -1087,11 +1087,11 @@ class ProcessCollection: for i, collection in enumerate(collections): if process_fits_in_collection[i]: - count_1 = ProcessCollection._count_number_of_pes_connected( + count_1 = ProcessCollection._count_number_of_pes_read_from( processing_elements, collection ) tmp_collection = [*collection.collection, process] - count_2 = ProcessCollection._count_number_of_pes_connected( + count_2 = ProcessCollection._count_number_of_pes_read_from( processing_elements, tmp_collection ) delta = count_2 - count_1 @@ -1113,9 +1113,9 @@ class ProcessCollection: best_collection.add_process(process) - for i in range(len(collections)): - if not collections[i].collection: - collections.pop(i) + collections = [ + collection for collection in collections if collection.collection + ] return collections def _split_ports_minimize_memory_to_pe_connections( @@ -1126,7 +1126,63 @@ class ProcessCollection: sequence: list[Process], processing_elements: list["ProcessingElement"], ) -> list["ProcessCollection"]: - raise NotImplementedError() + + if set(self.collection) != set(sequence): + raise KeyError("processes in `sequence` must be equal to processes in self") + + num_of_memories = len( + self.split_ports_sequentially( + read_ports, write_ports, total_ports, sequence + ) + ) + collections: list[ProcessCollection] = [ + ProcessCollection( + [], + schedule_time=self.schedule_time, + cyclic=self._cyclic, + ) + for _ in range(num_of_memories) + ] + + for process in sequence: + process_fits_in_collection = self._get_process_fits_in_collection( + process, collections, read_ports, write_ports, total_ports + ) + best_collection = None + best_delta = sys.maxsize + + for i, collection in enumerate(collections): + if process_fits_in_collection[i]: + + count_1 = ProcessCollection._count_number_of_pes_written_to( + processing_elements, collection + ) + tmp_collection = [*collection.collection, process] + count_2 = ProcessCollection._count_number_of_pes_written_to( + processing_elements, tmp_collection + ) + delta = count_2 - count_1 + if delta < best_delta: + best_collection = collection + best_delta = delta + + elif not any(process_fits_in_collection): + collections.append( + ProcessCollection( + [], + schedule_time=self.schedule_time, + cyclic=self._cyclic, + ) + ) + process_fits_in_collection = self._get_process_fits_in_collection( + process, collections, read_ports, write_ports, total_ports + ) + best_collection.add_process(process) + + collections = [ + collection for collection in collections if collection.collection + ] + return collections def _get_process_fits_in_collection( self, process, collections, write_ports, read_ports, total_ports @@ -1167,7 +1223,7 @@ class ProcessCollection: return False @staticmethod - def _count_number_of_pes_connected( + def _count_number_of_pes_read_from( processing_elements: list["ProcessingElement"], collection: "ProcessCollection", ) -> int: @@ -1181,6 +1237,24 @@ class ProcessCollection: count += 1 return count + @staticmethod + def _count_number_of_pes_written_to( + processing_elements: list["ProcessingElement"], + collection: "ProcessCollection", + ) -> int: + collection_process_names = {proc.name.split(".")[0] for proc in collection} + count = 0 + for pe in processing_elements: + for process in pe.processes: + for input in process.operation.inputs: + input_op = input.connected_source.operation + if input_op.graph_id in collection_process_names: + count += 1 + break + if count != 0: + break + return count + def _split_ports_graph_color( self, read_ports: int, diff --git a/test/integration/test_sfg_to_architecture.py b/test/integration/test_sfg_to_architecture.py index 6c183ad27da259c22510773f690618e8119edde8..8266b7b9756688f54b900b1f325bc5b18ec4eb01 100644 --- a/test/integration/test_sfg_to_architecture.py +++ b/test/integration/test_sfg_to_architecture.py @@ -245,6 +245,29 @@ def test_different_resource_algorithms(): assert len(arch.processing_elements) == 6 assert len(arch.memories) == 6 + # MIN-MEM-TO-PE + mem_vars_set = mem_vars.split_on_ports( + read_ports=1, + write_ports=1, + total_ports=2, + heuristic="min_mem_to_pe", + processing_elements=processing_elements, + ) + + memories = [] + for i, mem in enumerate(mem_vars_set): + memory = Memory(mem, memory_type="RAM", entity_name=f"memory{i}") + memories.append(memory) + memory.assign("graph_color") + + arch = Architecture( + processing_elements, + memories, + direct_interconnects=direct, + ) + assert len(arch.processing_elements) == 6 + assert len(arch.memories) == 6 + # GRAPH COLORING mem_vars_set = mem_vars.split_on_ports( read_ports=1, diff --git a/test/unit/test_resources.py b/test/unit/test_resources.py index d12f272eb7b89c75e381d1478c76be76807991be..e6c01c83f5fb509cb88f5c441ad327052bdf9070 100644 --- a/test/unit/test_resources.py +++ b/test/unit/test_resources.py @@ -49,7 +49,7 @@ class TestProcessCollectionPlainMemoryVariable: def test_split_sequence_raises(self, simple_collection: ProcessCollection): with pytest.raises(KeyError, match="processes in `sequence` must be"): - simple_collection._split_ports_sequentially( + simple_collection.split_ports_sequentially( read_ports=1, write_ports=1, total_ports=2, sequence=[] )