diff --git a/b_asic/resources.py b/b_asic/resources.py index 2ae8e50a28f1b881d66858d91daa0294bcf8c248..e09af3adc476df75af6c128e016554b10bd6a284 100644 --- a/b_asic/resources.py +++ b/b_asic/resources.py @@ -109,27 +109,32 @@ def _sanitize_port_option( return read_ports, write_ports, total_ports -def _get_source( - var: MemoryVariable, pes: list["ProcessingElement"] -) -> "ProcessingElement": - name = var.name.split(".")[0] +def _get_source_port(var: MemoryVariable, pes: list["ProcessingElement"]) -> str: + split_var = iter(var.name.split(".")) + var_name = next(split_var) + port_index = int(next(split_var)) for pe in pes: - pe_names = [proc.name for proc in pe.collection] - if name in pe_names: - return pe + for process in pe: + if var_name == process.name: + for output_port in process.operation.outputs: + if output_port.index == port_index: + return f"{pe.entity_name}.out.{output_port.index}" raise ValueError("Source could not be found for the given variable.") -def _get_destination( - var: MemoryVariable, pes: list["ProcessingElement"] -) -> "ProcessingElement": - name = var.name.split(".")[0] +def _get_destination_port(var: MemoryVariable, pes: list["ProcessingElement"]) -> str: + split_var = iter(var.name.split(".")) + var_name = next(split_var) + port_index = int(next(split_var)) for pe in pes: - for process in pe.processes: + for process in pe: for input_port in process.operation.inputs: input_op = input_port.connected_source.operation - if input_op.graph_id == name: - return pe + if ( + input_op.graph_id == var_name + and input_port.connected_source.index == port_index + ): + return f"{pe.entity_name}.in.{input_port.index}" raise ValueError("Destination could not be found for the given variable.") @@ -1517,6 +1522,12 @@ class ProcessCollection: colors = range(amount_of_colors) + pe_out_ports = [ + f"{pe.entity_name}.out.{port_index}" + for pe in processing_elements + for port_index in range(pe.output_count) + ] + # minimize the amount of input muxes connecting PEs to memories # by minimizing the amount of PEs connected to each memory @@ -1526,9 +1537,10 @@ class ProcessCollection: # y[pe, color] - whether a color has nodes generated from a certain pe x = LpVariable.dicts("x", (nodes, colors), cat=LpBinary) c = LpVariable.dicts("c", colors, cat=LpBinary) - y = LpVariable.dicts("y", (processing_elements, colors), cat=LpBinary) + y = LpVariable.dicts("y", (pe_out_ports, colors), cat=LpBinary) + problem = LpProblem() - problem += lpSum(y[pe][i] for pe in processing_elements for i in colors) + problem += lpSum(y[port][i] for port in pe_out_ports for i in colors) # constraints: # (1) - nodes have exactly one color @@ -1547,9 +1559,9 @@ class ProcessCollection: for color in colors: problem += x[node][color] <= c[color] for node in nodes: - pe = _get_source(node, processing_elements) + port = _get_source_port(node, processing_elements) for color in colors: - problem += x[node][color] <= y[pe][color] + problem += x[node][color] <= y[port][color] max_clique = next(nx.find_cliques(exclusion_graph)) for color, node in enumerate(max_clique): problem += x[node][color] == c[color] == 1 @@ -1601,6 +1613,12 @@ class ProcessCollection: colors = range(amount_of_colors) + pe_in_ports = [ + f"{pe.entity_name}.in.{port_index}" + for pe in processing_elements + for port_index in range(pe.input_count) + ] + # minimize the amount of output muxes connecting PEs to memories # by minimizing the amount of PEs connected to each memory @@ -1610,9 +1628,9 @@ class ProcessCollection: # y[pe, color] - whether a color has nodes writing to a certain PE x = LpVariable.dicts("x", (nodes, colors), cat=LpBinary) c = LpVariable.dicts("c", colors, cat=LpBinary) - y = LpVariable.dicts("y", (processing_elements, colors), cat=LpBinary) + y = LpVariable.dicts("y", (pe_in_ports, colors), cat=LpBinary) problem = LpProblem() - problem += lpSum(y[pe][i] for pe in processing_elements for i in colors) + problem += lpSum(y[port][i] for port in pe_in_ports for i in colors) # constraints: # (1) - nodes have exactly one color @@ -1631,9 +1649,9 @@ class ProcessCollection: for color in colors: problem += x[node][color] <= c[color] for node in nodes: - pe = _get_destination(node, processing_elements) + port = _get_destination_port(node, processing_elements) for color in colors: - problem += x[node][color] <= y[pe][color] + problem += x[node][color] <= y[port][color] max_clique = next(nx.find_cliques(exclusion_graph)) for color, node in enumerate(max_clique): problem += x[node][color] == c[color] == 1 @@ -1685,6 +1703,17 @@ class ProcessCollection: colors = range(amount_of_colors) + pe_in_ports = [ + f"{pe.entity_name}.in.{port_index}" + for pe in processing_elements + for port_index in range(pe.input_count) + ] + pe_out_ports = [ + f"{pe.entity_name}.out.{port_index}" + for pe in processing_elements + for port_index in range(pe.output_count) + ] + # minimize the amount of total muxes connecting PEs to memories # by minimizing the amount of PEs connected to each memory (input & output) @@ -1695,11 +1724,11 @@ class ProcessCollection: # z[pe, color] - whether a color has nodes writing to a certain PE x = LpVariable.dicts("x", (nodes, colors), cat=LpBinary) c = LpVariable.dicts("c", colors, cat=LpBinary) - y = LpVariable.dicts("y", (processing_elements, colors), cat=LpBinary) - z = LpVariable.dicts("z", (processing_elements, colors), cat=LpBinary) + y = LpVariable.dicts("y", (pe_out_ports, colors), cat=LpBinary) + z = LpVariable.dicts("z", (pe_in_ports, colors), cat=LpBinary) problem = LpProblem() - problem += lpSum( - y[pe][i] + z[pe][i] for pe in processing_elements for i in colors + problem += lpSum([y[port][i] for port in pe_out_ports for i in colors]) + lpSum( + [z[port][i] for port in pe_in_ports for i in colors] ) # constraints: @@ -1720,13 +1749,13 @@ class ProcessCollection: for color in colors: problem += x[node][color] <= c[color] for node in nodes: - pe = _get_source(node, processing_elements) + port = _get_source_port(node, processing_elements) for color in colors: - problem += x[node][color] <= y[pe][color] + problem += x[node][color] <= y[port][color] for node in nodes: - pe = _get_destination(node, processing_elements) + port = _get_destination_port(node, processing_elements) for color in colors: - problem += x[node][color] <= z[pe][color] + problem += x[node][color] <= z[port][color] max_clique = next(nx.find_cliques(exclusion_graph)) for color, node in enumerate(max_clique): problem += x[node][color] == c[color] == 1