From fb329086ace991b7276c8988ecf5efafefec69d5 Mon Sep 17 00:00:00 2001
From: Simon Bjurek <simbj106@student.liu.se>
Date: Fri, 18 Apr 2025 17:04:47 +0200
Subject: [PATCH] Add mux reduction to joint ILP resource algorithms

---
 b_asic/architecture.py                       |   8 +-
 b_asic/resource_assigner.py                  | 367 +++++++++++++++++--
 b_asic/resources.py                          |  11 +-
 pyproject.toml                               |   2 +-
 test/integration/test_sfg_to_architecture.py |  46 ++-
 5 files changed, 390 insertions(+), 44 deletions(-)

diff --git a/b_asic/architecture.py b/b_asic/architecture.py
index 93ffff19..1f93b569 100644
--- a/b_asic/architecture.py
+++ b/b_asic/architecture.py
@@ -788,8 +788,8 @@ of :class:`~b_asic.architecture.ProcessingElement`
         Returns
         -------
         (dict, dict)
-            A dictionary with the ProcessingElements that are connected to the write and
-            read ports, respectively, with counts of the number of accesses.
+            A dictionary with the ProcessingElements that are connected to the read and
+            write ports, respectively, with counts of the number of accesses.
         """
         if isinstance(mem, str):
             mem = cast(Memory, self.resource_from_name(mem))
@@ -821,10 +821,10 @@ of :class:`~b_asic.architecture.ProcessingElement`
         Returns
         -------
         list
-            List of dictionaries indicating the sources for each import and the
+            List of dictionaries indicating the sources for each input port and the
             frequency of accesses.
         list
-            List of dictionaries indicating the sources for each outport and the
+            List of dictionaries indicating the destinations for each output port and the
             frequency of accesses.
         """
         if isinstance(pe, str):
diff --git a/b_asic/resource_assigner.py b/b_asic/resource_assigner.py
index 762c2526..8abcb11f 100644
--- a/b_asic/resource_assigner.py
+++ b/b_asic/resource_assigner.py
@@ -1,3 +1,5 @@
+from typing import Literal, cast
+
 import networkx as nx
 from pulp import (
     GUROBI,
@@ -11,6 +13,8 @@ from pulp import (
 )
 
 from b_asic.architecture import Memory, ProcessingElement
+from b_asic.operation import Operation
+from b_asic.port import OutputPort
 from b_asic.process import Process
 from b_asic.resources import ProcessCollection
 from b_asic.types import TypeName
@@ -19,6 +23,10 @@ from b_asic.types import TypeName
 def assign_processing_elements_and_memories(
     operations: ProcessCollection,
     memory_variables: ProcessCollection,
+    strategy: Literal[
+        "ilp_graph_color",
+        "ilp_min_total_mux",
+    ] = "ilp_graph_color",
     resources: dict[TypeName, int] | None = None,
     max_memories: int | None = None,
     memory_read_ports: int | None = None,
@@ -58,8 +66,7 @@ def assign_processing_elements_and_memories(
         memory variable access.
 
     solver : PuLP MIP solver object, optional
-        Only used if strategy is an ILP method.
-        Valid options are
+        Valid options are:
 
         * PULP_CBC_CMD() - preinstalled
         * GUROBI() - license required, but likely faster
@@ -69,10 +76,13 @@ def assign_processing_elements_and_memories(
     A tuple containing one list of assigned PEs and one list of assigned memories.
     """
     operation_groups = operations.split_on_type_name()
+    direct, mem_vars = memory_variables.split_on_length()
 
     operations_set, memory_variable_set = _split_operations_and_variables(
         operation_groups,
-        memory_variables,
+        mem_vars,
+        direct,
+        strategy,
         resources,
         max_memories,
         memory_read_ports,
@@ -92,12 +102,17 @@ def assign_processing_elements_and_memories(
         for i, mem in enumerate(memory_variable_set)
     ]
 
-    return processing_elements, memories
+    return processing_elements, memories, direct
 
 
 def _split_operations_and_variables(
     operation_groups: dict[TypeName, ProcessCollection],
     memory_variables: ProcessCollection,
+    direct_variables: ProcessCollection,
+    strategy: Literal[
+        "ilp_graph_color",
+        "ilp_min_total_mux",
+    ] = "ilp_graph_color",
     resources: dict[TypeName, int] | None = None,
     max_memories: int | None = None,
     memory_read_ports: int | None = None,
@@ -121,17 +136,23 @@ def _split_operations_and_variables(
     # generate the exclusion graphs along with a color upper bound for PEs
     pe_exclusion_graphs = []
     pe_colors = []
+    pe_operations = []
     for group in operation_groups.values():
         pe_ex_graph = group.create_exclusion_graph_from_execution_time()
         pe_exclusion_graphs.append(pe_ex_graph)
-        pe_op_type = next(iter(group)).operation.type_name()
-        if pe_op_type in resources:
-            coloring = nx.coloring.greedy_color(
-                pe_ex_graph, strategy="saturation_largest_first"
-            )
-            pe_colors.append(range(len(set(coloring.values()))))
+        operation = next(iter(group)).operation
+        pe_operations.append(operation)
+        if strategy == "ilp_graph_color":
+            if not resources or operation.type_name() not in resources:
+                coloring = nx.coloring.greedy_color(
+                    pe_ex_graph, strategy="saturation_largest_first"
+                )
+                pe_colors.append(range(len(set(coloring.values()))))
+            else:
+                pe_colors.append(range(resources[operation.type_name()]))
         else:
-            pe_colors.append(range(resources[pe_op_type]))
+            pe_colors.append(list(range(resources[operation.type_name()])))
+        print(operation.type_name())
 
     # generate the exclusion graphs along with a color upper bound for memories
     mem_exclusion_graph = memory_variables.create_exclusion_graph_from_ports(
@@ -144,10 +165,25 @@ def _split_operations_and_variables(
         max_memories = len(set(coloring.values()))
     mem_colors = range(max_memories)
 
-    # color the graphs concurrently using ILP to minimize the total amount of resources
-    pe_x, mem_x = _ilp_coloring(
-        pe_exclusion_graphs, mem_exclusion_graph, mem_colors, pe_colors, solver
-    )
+    if strategy == "ilp_graph_color":
+        # color the graphs concurrently using ILP to minimize the total amount of resources
+        pe_x, mem_x = _ilp_coloring(
+            pe_exclusion_graphs, mem_exclusion_graph, mem_colors, pe_colors, solver
+        )
+    elif strategy == "ilp_min_total_mux":
+        # color the graphs concurrently using ILP to minimize the amount of multiplexers
+        # given the amount of resources and memories
+        pe_x, mem_x = _ilp_coloring_min_mux(
+            pe_exclusion_graphs,
+            mem_exclusion_graph,
+            mem_colors,
+            pe_colors,
+            pe_operations,
+            direct_variables,
+            solver,
+        )
+    else:
+        raise ValueError(f"Invalid strategy '{strategy}'")
 
     # assign memories based on coloring
     mem_process_collections = _get_assignment_from_coloring(
@@ -167,12 +203,12 @@ def _split_operations_and_variables(
 
 
 def _ilp_coloring(
-    pe_exclusion_graphs,
-    mem_exclusion_graph,
-    mem_colors,
-    pe_colors,
-    solver,
-):
+    pe_exclusion_graphs: list[nx.Graph],
+    mem_exclusion_graph: nx.Graph,
+    mem_colors: list[int],
+    pe_colors: list[list[int]],
+    solver: PULP_CBC_CMD | GUROBI | None = None,
+) -> tuple[dict, dict]:
     mem_graph_nodes = list(mem_exclusion_graph.nodes())
     mem_graph_edges = list(mem_exclusion_graph.edges())
 
@@ -186,8 +222,8 @@ def _ilp_coloring(
     #       colored in a certain color
     #   pe_c[i, color] whether color is used in the i:th PE exclusion graph
 
-    mem_x = LpVariable.dicts("x", (mem_graph_nodes, mem_colors), cat=LpBinary)
-    mem_c = LpVariable.dicts("c", mem_colors, cat=LpBinary)
+    mem_x = LpVariable.dicts("mem_x", (mem_graph_nodes, mem_colors), cat=LpBinary)
+    mem_c = LpVariable.dicts("mem_c", mem_colors, cat=LpBinary)
 
     pe_x = {}
     for i, pe_exclusion_graph in enumerate(pe_exclusion_graphs):
@@ -195,13 +231,15 @@ def _ilp_coloring(
         for node in list(pe_exclusion_graph.nodes()):
             pe_x[i][node] = {}
             for color in pe_colors[i]:
-                pe_x[i][node][color] = LpVariable(f"x_{i}_{node}_{color}", cat=LpBinary)
+                pe_x[i][node][color] = LpVariable(
+                    f"pe_x_{i}_{node}_{color}", cat=LpBinary
+                )
 
     pe_c = {}
     for i in range(len(pe_exclusion_graphs)):
         pe_c[i] = {}
         for color in pe_colors[i]:
-            pe_c[i][color] = LpVariable(f"x_{i}_{color}", cat=LpBinary)
+            pe_c[i][color] = LpVariable(f"pe_c_{i}_{color}", cat=LpBinary)
 
     problem = LpProblem()
     problem += lpSum(mem_c[color] for color in mem_colors) + lpSum(
@@ -265,6 +303,229 @@ def _ilp_coloring(
     return pe_x, mem_x
 
 
+def _ilp_coloring_min_mux(
+    pe_exclusion_graphs: list[nx.Graph],
+    mem_exclusion_graph: nx.Graph,
+    mem_colors: list[int],
+    pe_colors: list[list[int]],
+    pe_operations: list[Operation],
+    direct: ProcessCollection,
+    solver: PULP_CBC_CMD | GUROBI | None = None,
+) -> tuple[dict, dict]:
+    mem_graph_nodes = list(mem_exclusion_graph.nodes())
+    mem_graph_edges = list(mem_exclusion_graph.edges())
+
+    pe_ops = [
+        op
+        for i in range(len(pe_exclusion_graphs))
+        for op in list(pe_exclusion_graphs[i].nodes())
+    ]
+
+    pe_in_port_indices = [list(range(op.input_count)) for op in pe_operations]
+    pe_out_port_indices = [list(range(op.output_count)) for op in pe_operations]
+
+    # specify the ILP problem of minimizing the amount of resources
+
+    # binary variables:
+    #   mem_x[node, color] - whether node in memory exclusion graph is colored
+    #       in a certain color
+    #   mem_c[color] - whether color is used in the memory exclusion graph
+    #   pe_x[i, node, color] - whether node in the i:th PE exclusion graph is
+    #       colored in a certain color
+    #   pe_c[i, color] whether color is used in the i:th PE exclusion graph
+    #   a[i, j, k, l] - whether the k:th output port of the j:th PE in the i:th graph
+    #       writes to the the l:th memory
+    #   b[i, j, k, l] - whether the i:th memory
+    #       writes to the l:th input port of the k:th PE in the j:th PE exclusion graph
+    #   c[i, j, k, l, m, n] - whether the k:th output port of the j:th PE in the i:th PE exclusion graph
+    #       writes to the n:th input port of the m:th PE in the l:th PE exclusion graph
+
+    mem_x = LpVariable.dicts("mem_x", (mem_graph_nodes, mem_colors), cat=LpBinary)
+    mem_c = LpVariable.dicts("mem_c", mem_colors, cat=LpBinary)
+
+    pe_x = {}
+    for i, pe_exclusion_graph in enumerate(pe_exclusion_graphs):
+        pe_x[i] = {}
+        for node in list(pe_exclusion_graph.nodes()):
+            pe_x[i][node] = {}
+            for color in pe_colors[i]:
+                pe_x[i][node][color] = LpVariable(
+                    f"pe_x_{i}_{node}_{color}", cat=LpBinary
+                )
+
+    pe_c = {}
+    for i in range(len(pe_exclusion_graphs)):
+        pe_c[i] = {}
+        for color in pe_colors[i]:
+            pe_c[i][color] = LpVariable(f"pe_c_{i}_{color}", cat=LpBinary)
+
+    a = {}
+    for i in range(len(pe_exclusion_graphs)):
+        a[i] = {}
+        for j in pe_colors[i]:
+            a[i][j] = {}
+            for k in pe_out_port_indices[i]:
+                a[i][j][k] = {}
+                for l in mem_colors:
+                    a[i][j][k][l] = LpVariable(f"a_{i}_{j}_{k}_{l}", cat=LpBinary)
+
+    b = {}
+    for i in mem_colors:
+        b[i] = {}
+        for j in range(len(pe_exclusion_graphs)):
+            b[i][j] = {}
+            for k in pe_colors[j]:
+                b[i][j][k] = {}
+                for l in pe_in_port_indices[j]:
+                    b[i][j][k][l] = LpVariable(f"b_{i}_{j}_{k}_{l}", cat=LpBinary)
+
+    c = {}
+    for i in range(len(pe_exclusion_graphs)):
+        c[i] = {}
+        for j in pe_colors[i]:
+            c[i][j] = {}
+            for k in pe_out_port_indices[i]:
+                c[i][j][k] = {}
+                for l in range(len(pe_exclusion_graphs)):
+                    c[i][j][k][l] = {}
+                    for m in pe_colors[l]:
+                        c[i][j][k][l][m] = {}
+                        for n in pe_in_port_indices[l]:
+                            c[i][j][k][l][m][n] = LpVariable(
+                                f"c_{i}_{j}_{k}_{l}_{m}_{n}", cat=LpBinary
+                            )
+
+    problem = LpProblem()
+    problem += (
+        lpSum(
+            [
+                a[i][j][k][l]
+                for i in range(len(pe_exclusion_graphs))
+                for j in pe_colors[i]
+                for k in pe_out_port_indices[i]
+                for l in mem_colors
+            ]
+        )
+        + lpSum(
+            [
+                b[i][j][k][l]
+                for i in mem_colors
+                for j in range(len(pe_exclusion_graphs))
+                for k in pe_colors[j]
+                for l in pe_in_port_indices[j]
+            ]
+        )
+        + lpSum(
+            [
+                c[i][j][k][l][m][n]
+                for i in range(len(pe_exclusion_graphs))
+                for j in pe_colors[i]
+                for k in pe_out_port_indices[i]
+                for l in range(len(pe_exclusion_graphs))
+                for m in pe_colors[l]
+                for n in pe_in_port_indices[l]
+            ]
+        )
+    )
+
+    # coloring constraints for the memory variable exclusion graph
+    for node in mem_graph_nodes:
+        problem += lpSum(mem_x[node][i] for i in mem_colors) == 1
+    for u, v in mem_graph_edges:
+        for color in mem_colors:
+            problem += mem_x[u][color] + mem_x[v][color] <= 1
+    for node in mem_graph_nodes:
+        for color in mem_colors:
+            problem += mem_x[node][color] <= mem_c[color]
+
+    # connect assignment to "a"
+    for i in range(len(pe_exclusion_graphs)):
+        pe_nodes = list(pe_exclusion_graphs[i].nodes())
+        for pe_node in pe_nodes:
+            for k in pe_out_port_indices[i]:
+                mem_node = _get_mem_node(pe_node, k, mem_exclusion_graph)
+                if mem_node is not None:
+                    for j in pe_colors[i]:
+                        for l in mem_colors:
+                            problem += a[i][j][k][l] >= (
+                                pe_x[i][pe_node][j] + mem_x[mem_node][l] - 1
+                            )
+
+    # connect assignment to "b"
+    for mem_node in mem_graph_nodes:
+        for j in range(len(pe_exclusion_graphs)):
+            pe_pairs = _get_pe_nodes(mem_node, pe_exclusion_graphs[j])
+            for pair in pe_pairs:
+                for k in pe_colors[j]:
+                    for i in mem_colors:
+                        problem += b[i][j][k][pair[1]] >= (
+                            pe_x[j][pair[0]][k] + mem_x[mem_node][i] - 1
+                        )
+
+    # connect assignment to "c"
+    for i in range(len(pe_exclusion_graphs)):
+        pe_nodes_1 = list(pe_exclusion_graphs[i].nodes())
+        for j in range(len(pe_exclusion_graphs)):
+            for pe_node_1 in pe_nodes_1:
+                for l in pe_in_port_indices[j]:
+                    for k in pe_out_port_indices[i]:
+                        pe_nodes_2 = _get_pe_to_pe_connection(
+                            pe_node_1, direct, pe_ops, l, k
+                        )
+                        for pe_node_2 in pe_nodes_2:
+                            for pe_color_1 in pe_colors[i]:
+                                if pe_node_2 in pe_exclusion_graphs[j].nodes():
+                                    for pe_color_2 in pe_colors[j]:
+                                        problem += c[i][pe_color_1][k][j][pe_color_2][
+                                            l
+                                        ] >= (
+                                            pe_x[i][pe_node_1][pe_color_1]
+                                            + pe_x[j][pe_node_2][pe_color_2]
+                                            - 1
+                                        )
+
+    # speed
+    max_clique = next(nx.find_cliques(mem_exclusion_graph))
+    for color, node in enumerate(max_clique):
+        problem += mem_x[node][color] == mem_c[color] == 1
+    for color in mem_colors:
+        problem += mem_c[color] <= lpSum(mem_x[node][color] for node in mem_graph_nodes)
+    for color in mem_colors[:-1]:
+        problem += mem_c[color + 1] <= mem_c[color]
+
+    for i, pe_exclusion_graph in enumerate(pe_exclusion_graphs):
+        # coloring constraints for PE exclusion graphs
+        nodes = list(pe_exclusion_graph.nodes())
+        edges = list(pe_exclusion_graph.edges())
+        for node in nodes:
+            problem += lpSum(pe_x[i][node][color] for color in pe_colors[i]) == 1
+        for u, v in edges:
+            for color in pe_colors[i]:
+                problem += pe_x[i][u][color] + pe_x[i][v][color] <= 1
+        for node in nodes:
+            for color in pe_colors[i]:
+                problem += pe_x[i][node][color] <= pe_c[i][color]
+        # speed
+        max_clique = next(nx.find_cliques(pe_exclusion_graphs[i]))
+        for color, node in enumerate(max_clique):
+            problem += pe_x[i][node][color] == pe_c[i][color] == 1
+        for color in pe_colors[i]:
+            problem += pe_c[i][color] <= lpSum(pe_x[i][node][color] for node in nodes)
+        for color in pe_colors[i][:-1]:
+            problem += pe_c[i][color + 1] <= pe_c[i][color]
+
+    if solver is None:
+        solver = PULP_CBC_CMD()
+
+    status = problem.solve(solver)
+
+    if status != LpStatusOptimal:
+        raise ValueError(
+            "Optimal solution could not be found via ILP, use another method."
+        )
+    return pe_x, mem_x
+
+
 def _get_assignment_from_coloring(
     exclusion_graph: nx.Graph,
     x: dict[Process, dict[int, LpVariable]],
@@ -287,3 +548,59 @@ def _get_assignment_from_coloring(
         assignment[cell].add_process(process)
 
     return list(assignment.values())
+
+
+def _get_mem_node(
+    pe_node: Process, pe_port_index: int, mem_nodes: list[Process]
+) -> tuple[Process, int] | tuple[None, None]:
+    for mem_process in mem_nodes:
+        split_name = iter(mem_process.name.split("."))
+        var_name = next(split_name)
+        port_index = int(next(split_name))
+        if var_name == pe_node.name and pe_port_index == port_index:
+            return mem_process
+
+
+def _get_pe_nodes(
+    mem_node: Process, pe_nodes: list[Process]
+) -> tuple[Process, int] | tuple[None, None]:
+    nodes = []
+    split_var = iter(mem_node.name.split("."))
+    var_name = next(split_var)
+    port_index = int(next(split_var))
+    for pe_process in pe_nodes:
+        for input_port in pe_process.operation.inputs:
+            input_op = input_port.connected_source.operation
+            if (
+                input_op.graph_id == var_name
+                and input_port.connected_source.index == port_index
+            ):
+                nodes.append((pe_process, input_port.index))
+    return nodes
+
+
+def _get_pe_to_pe_connection(
+    pe_node: Process,
+    direct_variables: ProcessCollection,
+    other_pe_nodes: list[Process],
+    pe_in_port_index: int,
+    pe_out_port_index: int,
+) -> tuple[Process, int, int] | tuple[None, None, None]:
+    nodes = []
+    for direct_var in direct_variables:
+        split_var = iter(direct_var.name.split("."))
+        var_name = next(split_var)
+        port_index = int(next(split_var))
+
+        if var_name == pe_node.name and pe_out_port_index == port_index:
+            for output_port in pe_node.operation.outputs:
+                port = cast(OutputPort, output_port)
+                if port.index == port_index:
+                    for output_signal in port.signals:
+                        if output_signal.destination.index == pe_in_port_index:
+                            op = output_signal.destination_operation
+
+                            for other_pe_node in other_pe_nodes:
+                                if other_pe_node.name == op.graph_id:
+                                    nodes.append(other_pe_node)
+    return nodes
diff --git a/b_asic/resources.py b/b_asic/resources.py
index 59663967..12e572df 100644
--- a/b_asic/resources.py
+++ b/b_asic/resources.py
@@ -834,7 +834,6 @@ class ProcessCollection:
             }
             node1_start_time = node1.start_time % self.schedule_time
             if total_ports == 1 and node1.start_time in node1_stop_times:
-                print(node1.start_time, node1_stop_times)
                 raise ValueError("Cannot read and write in same cycle.")
             for node2 in exclusion_graph:
                 if node1 == node2:
@@ -956,7 +955,7 @@ class ProcessCollection:
             Node ordering strategy passed to
             :func:`networkx.algorithms.coloring.greedy_color`.
             This parameter is only considered if *strategy* is set to 'greedy_graph_color'.
-            One of
+            Valid options are:
 
             * 'largest_first'
             * 'random_sequential'
@@ -1017,7 +1016,7 @@ class ProcessCollection:
         ----------
         strategy : str, default: "left_edge"
             The strategy used when splitting this :class:`ProcessCollection`.
-            Valid options are
+            Valid options are:
 
             * "ilp_graph_color" - ILP-based optimal graph coloring
             * "ilp_min_input_mux" - ILP-based optimal graph coloring minimizing the number of input multiplexers
@@ -1050,7 +1049,7 @@ class ProcessCollection:
 
         solver : PuLP MIP solver object, optional
             Only used if strategy is an ILP method.
-            Valid options are
+            Valid options are:
 
             * PULP_CBC_CMD() - preinstalled with the package
             * GUROBI() - required licence but likely faster
@@ -1789,8 +1788,8 @@ class ProcessCollection:
         #   (1) - nodes have exactly one color
         #   (2) - adjacent nodes cannot have the same color
         #   (3) - only permit assignments if color is used
-        #   (4) - if node is colored then enable the PE which generates that node
-        #   (5) - if node is colored then enable the PE reads from that node
+        #   (4) - if node is colored then enable the PE port which writes to that node
+        #   (5) - if node is colored then enable the PE port which reads from that node
         #   (6) - reduce solution space by assigning colors to the largest clique
         #   (7 & 8) - reduce solution space by ignoring the symmetry caused
         #       by cycling the graph colors
diff --git a/pyproject.toml b/pyproject.toml
index 80fa5ce8..6e0cc01d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -116,7 +116,7 @@ select = [
   "I",
   "G004"
 ]
-ignore = ["F403", "B008", "B021", "B006", "UP038", "RUF023", "A005"]
+ignore = ["F403", "B008", "B021", "B006", "UP038", "RUF023", "A005", "E741"]
 
 [tool.typos]
 default.extend-identifiers = { ba = "ba", addd0 = "addd0", inout = "inout", ArChItEctUrE = "ArChItEctUrE" }
diff --git a/test/integration/test_sfg_to_architecture.py b/test/integration/test_sfg_to_architecture.py
index 52174a07..92b29c83 100644
--- a/test/integration/test_sfg_to_architecture.py
+++ b/test/integration/test_sfg_to_architecture.py
@@ -508,11 +508,10 @@ def test_joint_resource_assignment():
         ),
     )
 
-    direct, mem_vars = schedule.get_memory_variables().split_on_length()
-    pes, mems = assign_processing_elements_and_memories(
+    pes, mems, direct = assign_processing_elements_and_memories(
         schedule.get_operations(),
-        mem_vars,
-        resources,
+        schedule.get_memory_variables(),
+        resources=resources,
         max_memories=3,
         memory_read_ports=1,
         memory_write_ports=1,
@@ -531,11 +530,10 @@ def test_joint_resource_assignment():
         ),
     )
 
-    direct, mem_vars = schedule.get_memory_variables().split_on_length()
-    pes, mems = assign_processing_elements_and_memories(
+    pes, mems, direct = assign_processing_elements_and_memories(
         schedule.get_operations(),
-        mem_vars,
-        resources,
+        schedule.get_memory_variables(),
+        resources=resources,
         max_memories=4,
         memory_read_ports=1,
         memory_write_ports=1,
@@ -545,3 +543,35 @@ def test_joint_resource_assignment():
     arch = Architecture(pes, mems, direct_interconnects=direct)
     assert len(arch.processing_elements) == 5
     assert len(arch.memories) == 4
+
+
+def test_joint_resource_assignment_mux_reduction():
+    POINTS = 32
+    sfg = radix_2_dif_fft(POINTS)
+    sfg.set_latency_of_type_name("bfly", 1)
+    sfg.set_latency_of_type_name("cmul", 3)
+    sfg.set_execution_time_of_type_name("bfly", 1)
+    sfg.set_execution_time_of_type_name("cmul", 1)
+
+    resources = {"bfly": 1, "cmul": 1, "in": 1, "out": 1}
+    schedule = Schedule(
+        sfg,
+        scheduler=HybridScheduler(
+            resources, max_concurrent_reads=3, max_concurrent_writes=3
+        ),
+    )
+
+    pes, mems, direct = assign_processing_elements_and_memories(
+        schedule.get_operations(),
+        schedule.get_memory_variables(),
+        strategy="ilp_min_total_mux",
+        resources=resources,
+        max_memories=3,
+        memory_read_ports=1,
+        memory_write_ports=1,
+        memory_total_ports=2,
+    )
+
+    arch = Architecture(pes, mems, direct_interconnects=direct)
+    assert len(arch.processing_elements) == 4
+    assert len(arch.memories) == 3
-- 
GitLab