Skip to content
Snippets Groups Projects
Commit 24545b4f authored by Simon Bjurek's avatar Simon Bjurek
Browse files

Update ILP split_on_ports to consider ports instead of PEs

parent cfd0b5e4
No related branches found
No related tags found
1 merge request!514Update ILP split_on_port to consider ports
Pipeline #160987 passed
......@@ -109,27 +109,32 @@ def _sanitize_port_option(
return read_ports, write_ports, total_ports
def _get_source(
var: MemoryVariable, pes: list["ProcessingElement"]
) -> "ProcessingElement":
name = var.name.split(".")[0]
def _get_source_port(var: MemoryVariable, pes: list["ProcessingElement"]) -> str:
split_var = iter(var.name.split("."))
var_name = next(split_var)
port_index = int(next(split_var))
for pe in pes:
pe_names = [proc.name for proc in pe.collection]
if name in pe_names:
return pe
for process in pe:
if var_name == process.name:
for output_port in process.operation.outputs:
if output_port.index == port_index:
return f"{pe.entity_name}.out.{output_port.index}"
raise ValueError("Source could not be found for the given variable.")
def _get_destination(
var: MemoryVariable, pes: list["ProcessingElement"]
) -> "ProcessingElement":
name = var.name.split(".")[0]
def _get_destination_port(var: MemoryVariable, pes: list["ProcessingElement"]) -> str:
split_var = iter(var.name.split("."))
var_name = next(split_var)
port_index = int(next(split_var))
for pe in pes:
for process in pe.processes:
for process in pe:
for input_port in process.operation.inputs:
input_op = input_port.connected_source.operation
if input_op.graph_id == name:
return pe
if (
input_op.graph_id == var_name
and input_port.connected_source.index == port_index
):
return f"{pe.entity_name}.in.{input_port.index}"
raise ValueError("Destination could not be found for the given variable.")
......@@ -1517,6 +1522,12 @@ class ProcessCollection:
colors = range(amount_of_colors)
pe_out_ports = [
f"{pe.entity_name}.out.{port_index}"
for pe in processing_elements
for port_index in range(pe.output_count)
]
# minimize the amount of input muxes connecting PEs to memories
# by minimizing the amount of PEs connected to each memory
......@@ -1526,9 +1537,10 @@ class ProcessCollection:
# y[pe, color] - whether a color has nodes generated from a certain pe
x = LpVariable.dicts("x", (nodes, colors), cat=LpBinary)
c = LpVariable.dicts("c", colors, cat=LpBinary)
y = LpVariable.dicts("y", (processing_elements, colors), cat=LpBinary)
y = LpVariable.dicts("y", (pe_out_ports, colors), cat=LpBinary)
problem = LpProblem()
problem += lpSum(y[pe][i] for pe in processing_elements for i in colors)
problem += lpSum(y[port][i] for port in pe_out_ports for i in colors)
# constraints:
# (1) - nodes have exactly one color
......@@ -1547,9 +1559,9 @@ class ProcessCollection:
for color in colors:
problem += x[node][color] <= c[color]
for node in nodes:
pe = _get_source(node, processing_elements)
port = _get_source_port(node, processing_elements)
for color in colors:
problem += x[node][color] <= y[pe][color]
problem += x[node][color] <= y[port][color]
max_clique = next(nx.find_cliques(exclusion_graph))
for color, node in enumerate(max_clique):
problem += x[node][color] == c[color] == 1
......@@ -1601,6 +1613,12 @@ class ProcessCollection:
colors = range(amount_of_colors)
pe_in_ports = [
f"{pe.entity_name}.in.{port_index}"
for pe in processing_elements
for port_index in range(pe.input_count)
]
# minimize the amount of output muxes connecting PEs to memories
# by minimizing the amount of PEs connected to each memory
......@@ -1610,9 +1628,9 @@ class ProcessCollection:
# y[pe, color] - whether a color has nodes writing to a certain PE
x = LpVariable.dicts("x", (nodes, colors), cat=LpBinary)
c = LpVariable.dicts("c", colors, cat=LpBinary)
y = LpVariable.dicts("y", (processing_elements, colors), cat=LpBinary)
y = LpVariable.dicts("y", (pe_in_ports, colors), cat=LpBinary)
problem = LpProblem()
problem += lpSum(y[pe][i] for pe in processing_elements for i in colors)
problem += lpSum(y[port][i] for port in pe_in_ports for i in colors)
# constraints:
# (1) - nodes have exactly one color
......@@ -1631,9 +1649,9 @@ class ProcessCollection:
for color in colors:
problem += x[node][color] <= c[color]
for node in nodes:
pe = _get_destination(node, processing_elements)
port = _get_destination_port(node, processing_elements)
for color in colors:
problem += x[node][color] <= y[pe][color]
problem += x[node][color] <= y[port][color]
max_clique = next(nx.find_cliques(exclusion_graph))
for color, node in enumerate(max_clique):
problem += x[node][color] == c[color] == 1
......@@ -1685,6 +1703,17 @@ class ProcessCollection:
colors = range(amount_of_colors)
pe_in_ports = [
f"{pe.entity_name}.in.{port_index}"
for pe in processing_elements
for port_index in range(pe.input_count)
]
pe_out_ports = [
f"{pe.entity_name}.out.{port_index}"
for pe in processing_elements
for port_index in range(pe.output_count)
]
# minimize the amount of total muxes connecting PEs to memories
# by minimizing the amount of PEs connected to each memory (input & output)
......@@ -1695,11 +1724,11 @@ class ProcessCollection:
# z[pe, color] - whether a color has nodes writing to a certain PE
x = LpVariable.dicts("x", (nodes, colors), cat=LpBinary)
c = LpVariable.dicts("c", colors, cat=LpBinary)
y = LpVariable.dicts("y", (processing_elements, colors), cat=LpBinary)
z = LpVariable.dicts("z", (processing_elements, colors), cat=LpBinary)
y = LpVariable.dicts("y", (pe_out_ports, colors), cat=LpBinary)
z = LpVariable.dicts("z", (pe_in_ports, colors), cat=LpBinary)
problem = LpProblem()
problem += lpSum(
y[pe][i] + z[pe][i] for pe in processing_elements for i in colors
problem += lpSum([y[port][i] for port in pe_out_ports for i in colors]) + lpSum(
[z[port][i] for port in pe_in_ports for i in colors]
)
# constraints:
......@@ -1720,13 +1749,13 @@ class ProcessCollection:
for color in colors:
problem += x[node][color] <= c[color]
for node in nodes:
pe = _get_source(node, processing_elements)
port = _get_source_port(node, processing_elements)
for color in colors:
problem += x[node][color] <= y[pe][color]
problem += x[node][color] <= y[port][color]
for node in nodes:
pe = _get_destination(node, processing_elements)
port = _get_destination_port(node, processing_elements)
for color in colors:
problem += x[node][color] <= z[pe][color]
problem += x[node][color] <= z[port][color]
max_clique = next(nx.find_cliques(exclusion_graph))
for color, node in enumerate(max_clique):
problem += x[node][color] == c[color] == 1
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment