From 19ac7b917224ba4cd5be0eefd391da6548d1cd39 Mon Sep 17 00:00:00 2001
From: Simon Bjurek <simbj106@student.liu.se>
Date: Mon, 7 Apr 2025 09:33:50 +0200
Subject: [PATCH] Update ILP resouce allocation to minimize PE -> mem muxes

---
 b_asic/resources.py                          | 156 ++++++++++++++++---
 test/integration/test_sfg_to_architecture.py |  28 +++-
 test/unit/test_resources.py                  |  30 +++-
 3 files changed, 186 insertions(+), 28 deletions(-)

diff --git a/b_asic/resources.py b/b_asic/resources.py
index df2b34df..160fd8d9 100644
--- a/b_asic/resources.py
+++ b/b_asic/resources.py
@@ -917,6 +917,7 @@ class ProcessCollection:
         write_ports: int | None = None,
         total_ports: int | None = None,
         processing_elements: list["ProcessingElement"] | None = None,
+        amount_of_sets: int | None = None,
     ) -> list["ProcessCollection"]:
         """
         Split based on concurrent read and write accesses.
@@ -930,11 +931,12 @@ class ProcessCollection:
             Valid options are:
 
             * "ilp_graph_color"
+            * "ilp_min_input_mux"
             * "greedy_graph_color"
             * "equitable_graph_color"
             * "left_edge"
-            * "min_pe_to_mem"
-            * "min_mem_to_pe"
+            * "left_edge_min_pe_to_mem"
+            * "left_edge_min_mem_to_pe"
 
         read_ports : int, optional
             The number of read ports used when splitting process collection based on
@@ -949,7 +951,11 @@ class ProcessCollection:
             memory variable access.
 
         processing_elements : list of ProcessingElement, optional
-            The currently used PEs, only required if heuristic = "min_mem_to_pe".
+            The currently used PEs, only required if heuristic = "min_mem_to_pe",
+            "ilp_graph_color" or "ilp_min_input_mux".
+
+        amount_of_sets : int, optional
+            amount of sets to split to, only required if heuristics = "ilp_min_input_mux".
 
         Returns
         -------
@@ -962,6 +968,22 @@ class ProcessCollection:
             return self._split_ports_ilp_graph_color(
                 read_ports, write_ports, total_ports
             )
+        elif heuristic == "ilp_min_input_mux":
+            if processing_elements is None:
+                raise ValueError(
+                    "processing_elements must be provided if heuristic = 'ilp_min_input_mux'"
+                )
+            if amount_of_sets is None:
+                raise ValueError(
+                    "amount_of_sets must be provided if heuristic = 'ilp_min_input_mux'"
+                )
+            return self._split_ports_ilp_min_input_mux_graph_color(
+                read_ports,
+                write_ports,
+                total_ports,
+                processing_elements,
+                amount_of_sets,
+            )
         elif heuristic == "greedy_graph_color":
             return self._split_ports_greedy_graph_color(
                 read_ports, write_ports, total_ports
@@ -977,24 +999,24 @@ class ProcessCollection:
                 total_ports,
                 sequence=sorted(self),
             )
-        elif heuristic == "min_pe_to_mem":
+        elif heuristic == "left_edge_min_pe_to_mem":
             if processing_elements is None:
                 raise ValueError(
-                    "processing_elements must be provided if heuristic = 'min_pe_to_mem'"
+                    "processing_elements must be provided if heuristic = 'left_edge_min_pe_to_mem'"
                 )
-            return self._split_ports_minimize_pe_to_memory_connections(
+            return self._split_ports_sequentially_minimize_pe_to_memory_connections(
                 read_ports,
                 write_ports,
                 total_ports,
                 sequence=sorted(self),
                 processing_elements=processing_elements,
             )
-        elif heuristic == "min_mem_to_pe":
+        elif heuristic == "left_edge_min_mem_to_pe":
             if processing_elements is None:
                 raise ValueError(
-                    "processing_elements must be provided if heuristic = 'min_mem_to_pe'"
+                    "processing_elements must be provided if heuristic = 'left_edge_min_mem_to_pe'"
                 )
-            return self._split_ports_minimize_memory_to_pe_connections(
+            return self._split_ports_sequentially_minimize_memory_to_pe_connections(
                 read_ports,
                 write_ports,
                 total_ports,
@@ -1064,7 +1086,7 @@ class ProcessCollection:
                 )
         return collections
 
-    def _split_ports_minimize_pe_to_memory_connections(
+    def _split_ports_sequentially_minimize_pe_to_memory_connections(
         self,
         read_ports: int,
         write_ports: int,
@@ -1130,7 +1152,7 @@ class ProcessCollection:
         ]
         return collections
 
-    def _split_ports_minimize_memory_to_pe_connections(
+    def _split_ports_sequentially_minimize_memory_to_pe_connections(
         self,
         read_ports: int,
         write_ports: int,
@@ -1351,17 +1373,17 @@ class ProcessCollection:
         coloring = nx.coloring.greedy_color(
             exclusion_graph, strategy="saturation_largest_first"
         )
-        max_colors = len(set(coloring.values()))
+        colors = range(len(set(coloring.values())))
+
+        # find the minimal amount of colors (memories)
 
         # binary variables:
         #   x[node, color] - whether node is colored in a certain color
         #   c[color] - whether color is used
-        x = LpVariable.dicts("x", (nodes, range(max_colors)), cat=LpBinary)
-        c = LpVariable.dicts("c", range(max_colors), cat=LpBinary)
-
-        # create the problem, objective function - minimize the number of colors used
+        x = LpVariable.dicts("x", (nodes, colors), cat=LpBinary)
+        c = LpVariable.dicts("c", colors, cat=LpBinary)
         problem = LpProblem()
-        problem += lpSum(c[i] for i in range(max_colors))
+        problem += lpSum(c[i] for i in colors)
 
         # constraints:
         #   1 - nodes have exactly one color
@@ -1369,12 +1391,12 @@ class ProcessCollection:
         #   3 - only permit assignments if color is used
         #   4 - reduce solution space by setting the color of one node
         for node in nodes:
-            problem += lpSum(x[node][i] for i in range(max_colors)) == 1
+            problem += lpSum(x[node][i] for i in colors) == 1
         for u, v in edges:
-            for color in range(max_colors):
-                problem += x[u][color] + x[v][color] <= c[color]
+            for color in colors:
+                problem += x[u][color] + x[v][color] <= 1
         for node in nodes:
-            for color in range(max_colors):
+            for color in colors:
                 problem += x[node][color] <= c[color]
         problem += x[nodes[0]][0] == c[0] == 1
 
@@ -1387,7 +1409,97 @@ class ProcessCollection:
 
         node_colors = {}
         for node in nodes:
-            for i in range(max_colors):
+            for i in colors:
+                if value(x[node][i]) == 1:
+                    node_colors[node] = i
+
+        # reduce the solution by removing unused colors
+        sorted_unique_values = sorted(set(node_colors.values()))
+        coloring_mapping = {val: i for i, val in enumerate(sorted_unique_values)}
+        minimal_coloring = {
+            key: coloring_mapping[node_colors[key]] for key in node_colors
+        }
+
+        return self._split_from_graph_coloring(minimal_coloring)
+
+    def _split_ports_ilp_min_input_mux_graph_color(
+        self,
+        read_ports: int,
+        write_ports: int,
+        total_ports: int,
+        processing_elements: list["ProcessingElement"],
+        amount_of_colors: int,
+    ) -> list["ProcessCollection"]:
+        from pulp import (
+            LpBinary,
+            LpProblem,
+            LpStatusOptimal,
+            LpVariable,
+            lpSum,
+            value,
+        )
+
+        # create new exclusion graph. Nodes are Processes
+        exclusion_graph = self.create_exclusion_graph_from_ports(
+            read_ports, write_ports, total_ports
+        )
+        nodes = list(exclusion_graph.nodes())
+        edges = list(exclusion_graph.edges())
+
+        colors = range(amount_of_colors)
+
+        # minimize the amount of input muxes connecting PEs to memories
+        # by minimizing the amount of PEs connected to each memory
+
+        # binary variables:
+        #   x[node, color] - whether node is colored in a certain color
+        #   c[color] - whether color is used
+        #   y[pe, color] - whether a color has nodes generated from a certain pe
+
+        x = LpVariable.dicts("x", (nodes, colors), cat=LpBinary)
+        c = LpVariable.dicts("c", colors, cat=LpBinary)
+        y = LpVariable.dicts("y", (processing_elements, colors), cat=LpBinary)
+        problem = LpProblem()
+        problem += lpSum(y[pe][i] for pe in processing_elements for i in colors)
+
+        def _get_source(
+            var: MemoryVariable, pes: list["ProcessingElement"]
+        ) -> "ProcessingElement":
+            name = var.name.split(".")[0]
+            for pe in pes:
+                pe_names = [proc.name for proc in pe.collection]
+                if name in pe_names:
+                    return pe
+            raise ValueError("Source could not be found for the given variable.")
+
+        # constraints:
+        #   1 - nodes have exactly one color
+        #   2 - adjacent nodes cannot have the same color
+        #   3 - only permit assignments if color is used
+        #   4 - if node is colored then enable the PE which generates that node (variable)
+        for node in nodes:
+            problem += lpSum(x[node][i] for i in colors) == 1
+        for u, v in edges:
+            for color in colors:
+                problem += x[u][color] + x[v][color] <= 1
+        for node in nodes:
+            for color in colors:
+                problem += x[node][color] <= c[color]
+        for node in nodes:
+            pe = _get_source(node, processing_elements)
+            for color in colors:
+                problem += x[node][color] <= y[pe][color]
+
+        status = problem.solve()
+
+        if status != LpStatusOptimal:
+            raise ValueError(
+                "Optimal solution could not be found via ILP, use another method."
+            )
+
+        node_colors = {}
+        for node in nodes:
+            for i in colors:
                 if value(x[node][i]) == 1:
                     node_colors[node] = i
 
diff --git a/test/integration/test_sfg_to_architecture.py b/test/integration/test_sfg_to_architecture.py
index b530cd3c..2dbc6602 100644
--- a/test/integration/test_sfg_to_architecture.py
+++ b/test/integration/test_sfg_to_architecture.py
@@ -227,7 +227,7 @@ def test_different_resource_algorithms():
         read_ports=1,
         write_ports=1,
         total_ports=2,
-        heuristic="min_pe_to_mem",
+        heuristic="left_edge_min_pe_to_mem",
         processing_elements=processing_elements,
     )
 
@@ -250,7 +250,7 @@ def test_different_resource_algorithms():
         read_ports=1,
         write_ports=1,
         total_ports=2,
-        heuristic="min_mem_to_pe",
+        heuristic="left_edge_min_mem_to_pe",
         processing_elements=processing_elements,
     )
 
@@ -371,3 +371,27 @@ def test_different_resource_algorithms():
     )
     assert len(arch.processing_elements) == 5
     assert len(arch.memories) == 4
+
+    # ILP COLOR MIN INPUT MUX
+    mem_vars_set = mem_vars.split_on_ports(
+        read_ports=1,
+        write_ports=1,
+        total_ports=2,
+        heuristic="ilp_min_input_mux",
+        processing_elements=processing_elements,
+        amount_of_sets=4,
+    )
+
+    memories = []
+    for i, mem in enumerate(mem_vars_set):
+        memory = Memory(mem, memory_type="RAM", entity_name=f"memory{i}")
+        memories.append(memory)
+        memory.assign("graph_color")
+
+    arch = Architecture(
+        processing_elements,
+        memories,
+        direct_interconnects=direct,
+    )
+    assert len(arch.processing_elements) == 5
+    assert len(arch.memories) == 4
diff --git a/test/unit/test_resources.py b/test/unit/test_resources.py
index c54cca94..ac8bcc53 100644
--- a/test/unit/test_resources.py
+++ b/test/unit/test_resources.py
@@ -74,15 +74,37 @@ class TestProcessCollectionPlainMemoryVariable:
     def test_split_memory_variable_raises(self, simple_collection: ProcessCollection):
         with pytest.raises(
             ValueError,
-            match="processing_elements must be provided if heuristic = 'min_pe_to_mem'",
+            match="processing_elements must be provided if heuristic = 'ilp_min_input_mux'",
         ):
-            simple_collection.split_on_ports(heuristic="min_pe_to_mem", total_ports=1)
+            simple_collection.split_on_ports(
+                heuristic="ilp_min_input_mux", total_ports=1
+            )
+
+        with pytest.raises(
+            ValueError,
+            match="amount_of_sets must be provided if heuristic = 'ilp_min_input_mux'",
+        ):
+            simple_collection.split_on_ports(
+                heuristic="ilp_min_input_mux",
+                total_ports=1,
+                processing_elements=[],
+            )
 
         with pytest.raises(
             ValueError,
-            match="processing_elements must be provided if heuristic = 'min_mem_to_pe'",
+            match="processing_elements must be provided if heuristic = 'left_edge_min_pe_to_mem'",
         ):
-            simple_collection.split_on_ports(heuristic="min_mem_to_pe", total_ports=1)
+            simple_collection.split_on_ports(
+                heuristic="left_edge_min_pe_to_mem", total_ports=1
+            )
+
+        with pytest.raises(
+            ValueError,
+            match="processing_elements must be provided if heuristic = 'left_edge_min_mem_to_pe'",
+        ):
+            simple_collection.split_on_ports(
+                heuristic="left_edge_min_mem_to_pe", total_ports=1
+            )
 
         with pytest.raises(ValueError, match="Invalid heuristic provided."):
             simple_collection.split_on_ports(heuristic="foo", total_ports=1)
-- 
GitLab