From 4822cfd6238f528f09ba724fae4557239cefed54 Mon Sep 17 00:00:00 2001
From: Simon Bjurek <simbj106@student.liu.se>
Date: Tue, 1 Apr 2025 14:48:33 +0200
Subject: [PATCH] fix entity name bugs

---
 b_asic/architecture.py         | 18 +++++++++++-
 b_asic/signal_flow_graph.py    |  6 ++--
 test/unit/test_architecture.py | 54 ++++++++++++++++++++++++++++++++--
 3 files changed, 72 insertions(+), 6 deletions(-)

diff --git a/b_asic/architecture.py b/b_asic/architecture.py
index 64965893..f105ef0a 100644
--- a/b_asic/architecture.py
+++ b/b_asic/architecture.py
@@ -615,13 +615,29 @@ of :class:`~b_asic.architecture.ProcessingElement`
         direct_interconnects: ProcessCollection | None = None,
     ):
         super().__init__(entity_name)
+
+        pe_names = [pe._entity_name for pe in processing_elements]
+        if None in pe_names:
+            raise ValueError(
+                "Entity names must be defined for all processing elements."
+            )
+        if len(pe_names) != len(set(pe_names)):
+            raise ValueError("Entity names of processing elements needs to be unique.")
         self._processing_elements = (
             [processing_elements]
             if isinstance(processing_elements, ProcessingElement)
             else list(processing_elements)
         )
+
+        mem_names = [mem._entity_name for mem in memories]
+        if None in mem_names:
+            raise ValueError("Entity names must be defined for all memories.")
+        if len(mem_names) != len(set(mem_names)):
+            raise ValueError("Entity names of memories needs to be unique.")
         self._memories = [memories] if isinstance(memories, Memory) else list(memories)
+
         self._direct_interconnects = direct_interconnects
+
         self._variable_input_port_to_resource: defaultdict[
             InputPort, set[tuple[Resource, int]]
         ] = defaultdict(set)
@@ -975,7 +991,7 @@ of :class:`~b_asic.architecture.ProcessingElement`
         self,
         branch_node: bool = True,
         cluster: bool = True,
-        splines: str = "spline",
+        splines: Literal["spline", "line", "ortho", "polyline", "curved"] = "spline",
         io_cluster: bool = True,
         multiplexers: bool = True,
         colored: bool = True,
diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py
index 5b8d7c68..9b95fed5 100644
--- a/b_asic/signal_flow_graph.py
+++ b/b_asic/signal_flow_graph.py
@@ -14,7 +14,7 @@ from io import StringIO
 from math import ceil
 from numbers import Number
 from queue import PriorityQueue
-from typing import ClassVar, Optional, Union, cast
+from typing import ClassVar, Literal, Optional, Union, cast
 
 import numpy as np
 from graphviz import Digraph
@@ -1616,7 +1616,7 @@ class SFG(AbstractOperation):
         engine: str | None = None,
         branch_node: bool = True,
         port_numbering: bool = True,
-        splines: str = "spline",
+        splines: Literal["spline", "line", "ortho", "polyline", "curved"] = "spline",
     ) -> Digraph:
         """
         Return a Digraph of the SFG.
@@ -1729,7 +1729,7 @@ class SFG(AbstractOperation):
         engine: str | None = None,
         branch_node: bool = True,
         port_numbering: bool = True,
-        splines: str = "spline",
+        splines: Literal["spline", "line", "ortho", "polyline", "curved"] = "spline",
     ) -> None:
         """
         Display a visual representation of the SFG using the default system viewer.
diff --git a/test/unit/test_architecture.py b/test/unit/test_architecture.py
index e611c048..b9e59741 100644
--- a/test/unit/test_architecture.py
+++ b/test/unit/test_architecture.py
@@ -5,7 +5,7 @@ import pytest
 
 from b_asic.architecture import Architecture, Memory, ProcessingElement
 from b_asic.core_operations import Addition, ConstantMultiplication
-from b_asic.process import PlainMemoryVariable
+from b_asic.process import MemoryProcess, OperatorProcess, PlainMemoryVariable
 from b_asic.resources import ProcessCollection
 from b_asic.schedule import Schedule
 from b_asic.scheduler import ASAPScheduler
@@ -181,6 +181,54 @@ def test_architecture(schedule_direct_form_iir_lp_filter: Schedule):
     )
 
 
+def test_architecture_not_unique_entity_names():
+    pe_1 = ProcessingElement(
+        ProcessCollection([OperatorProcess(0, Addition(execution_time=1))], 1),
+        entity_name="foo",
+    )
+    pe_2 = ProcessingElement(
+        ProcessCollection([OperatorProcess(0, Addition(execution_time=1))], 1),
+        entity_name="foo",
+    )
+    with pytest.raises(
+        ValueError, match="Entity names of processing elements needs to be unique."
+    ):
+        Architecture([pe_1, pe_2], [])
+
+    mem_1 = Memory(
+        ProcessCollection([MemoryProcess(0, [1])], 1),
+        memory_type="RAM",
+        entity_name="bar",
+    )
+    mem_2 = Memory(
+        ProcessCollection([MemoryProcess(0, [1])], 1),
+        memory_type="RAM",
+        entity_name="bar",
+    )
+    with pytest.raises(
+        ValueError, match="Entity names of memories needs to be unique."
+    ):
+        Architecture([], [mem_1, mem_2])
+
+
+def test_architecture_entity_names_not_set():
+    pe = ProcessingElement(
+        ProcessCollection([OperatorProcess(0, Addition(execution_time=1))], 1),
+    )
+    with pytest.raises(
+        ValueError, match="Entity names must be defined for all processing elements."
+    ):
+        Architecture([pe], [])
+
+    mem = Memory(
+        ProcessCollection([MemoryProcess(0, [1])], 1),
+    )
+    with pytest.raises(
+        ValueError, match="Entity names must be defined for all memories."
+    ):
+        Architecture([], [mem])
+
+
 def test_move_process(schedule_direct_form_iir_lp_filter: Schedule):
     # Resources
     mvs = schedule_direct_form_iir_lp_filter.get_memory_variables()
@@ -210,7 +258,9 @@ def test_move_process(schedule_direct_form_iir_lp_filter: Schedule):
     direct_conn, mvs = mvs.split_on_length()
 
     # Create Memories from the memory variables (split on length to get two memories)
-    memories: list[Memory] = [Memory(pc) for pc in mvs.split_on_length(6)]
+    memories: list[Memory] = [
+        Memory(pc, entity_name=f"mem{i}") for i, pc in enumerate(mvs.split_on_length(6))
+    ]
 
     # Create architecture
     architecture = Architecture(
-- 
GitLab