From 1bff8c7f3d470d3a6ee857445263c55bb84a2461 Mon Sep 17 00:00:00 2001 From: Simon Bjurek <simbj106@student.liu.se> Date: Tue, 1 Apr 2025 14:48:33 +0200 Subject: [PATCH] fix entity name bugs --- b_asic/architecture.py | 16 ++++++++++ test/unit/test_architecture.py | 54 ++++++++++++++++++++++++++++++++-- 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/b_asic/architecture.py b/b_asic/architecture.py index 64965893..dab4677c 100644 --- a/b_asic/architecture.py +++ b/b_asic/architecture.py @@ -615,13 +615,29 @@ of :class:`~b_asic.architecture.ProcessingElement` direct_interconnects: ProcessCollection | None = None, ): super().__init__(entity_name) + + pe_names = [pe._entity_name for pe in processing_elements] + if None in pe_names: + raise ValueError( + "Entity names must be defined for all processing elements." + ) + if len(pe_names) != len(set(pe_names)): + raise ValueError("Entity names of processing elements needs to be unique.") self._processing_elements = ( [processing_elements] if isinstance(processing_elements, ProcessingElement) else list(processing_elements) ) + + mem_names = [mem._entity_name for mem in memories] + if None in mem_names: + raise ValueError("Entity names must be defined for all memories.") + if len(mem_names) != len(set(mem_names)): + raise ValueError("Entity names of memories needs to be unique.") self._memories = [memories] if isinstance(memories, Memory) else list(memories) + self._direct_interconnects = direct_interconnects + self._variable_input_port_to_resource: defaultdict[ InputPort, set[tuple[Resource, int]] ] = defaultdict(set) diff --git a/test/unit/test_architecture.py b/test/unit/test_architecture.py index e611c048..b9e59741 100644 --- a/test/unit/test_architecture.py +++ b/test/unit/test_architecture.py @@ -5,7 +5,7 @@ import pytest from b_asic.architecture import Architecture, Memory, ProcessingElement from b_asic.core_operations import Addition, ConstantMultiplication -from b_asic.process import PlainMemoryVariable +from b_asic.process import MemoryProcess, OperatorProcess, PlainMemoryVariable from b_asic.resources import ProcessCollection from b_asic.schedule import Schedule from b_asic.scheduler import ASAPScheduler @@ -181,6 +181,54 @@ def test_architecture(schedule_direct_form_iir_lp_filter: Schedule): ) +def test_architecture_not_unique_entity_names(): + pe_1 = ProcessingElement( + ProcessCollection([OperatorProcess(0, Addition(execution_time=1))], 1), + entity_name="foo", + ) + pe_2 = ProcessingElement( + ProcessCollection([OperatorProcess(0, Addition(execution_time=1))], 1), + entity_name="foo", + ) + with pytest.raises( + ValueError, match="Entity names of processing elements needs to be unique." + ): + Architecture([pe_1, pe_2], []) + + mem_1 = Memory( + ProcessCollection([MemoryProcess(0, [1])], 1), + memory_type="RAM", + entity_name="bar", + ) + mem_2 = Memory( + ProcessCollection([MemoryProcess(0, [1])], 1), + memory_type="RAM", + entity_name="bar", + ) + with pytest.raises( + ValueError, match="Entity names of memories needs to be unique." + ): + Architecture([], [mem_1, mem_2]) + + +def test_architecture_entity_names_not_set(): + pe = ProcessingElement( + ProcessCollection([OperatorProcess(0, Addition(execution_time=1))], 1), + ) + with pytest.raises( + ValueError, match="Entity names must be defined for all processing elements." + ): + Architecture([pe], []) + + mem = Memory( + ProcessCollection([MemoryProcess(0, [1])], 1), + ) + with pytest.raises( + ValueError, match="Entity names must be defined for all memories." + ): + Architecture([], [mem]) + + def test_move_process(schedule_direct_form_iir_lp_filter: Schedule): # Resources mvs = schedule_direct_form_iir_lp_filter.get_memory_variables() @@ -210,7 +258,9 @@ def test_move_process(schedule_direct_form_iir_lp_filter: Schedule): direct_conn, mvs = mvs.split_on_length() # Create Memories from the memory variables (split on length to get two memories) - memories: list[Memory] = [Memory(pc) for pc in mvs.split_on_length(6)] + memories: list[Memory] = [ + Memory(pc, entity_name=f"mem{i}") for i, pc in enumerate(mvs.split_on_length(6)) + ] # Create architecture architecture = Architecture( -- GitLab