Skip to content
Snippets Groups Projects
Commit 8a028e82 authored by Mikael Henriksson's avatar Mikael Henriksson :runner:
Browse files

resources.py: add split_ports_sequentially() and left-edge based split_on_ports()

parent 3f32dcde
No related branches found
No related tags found
1 merge request!430Add `split_ports_sequentially()`, left-edge based `split_on_ports()`, and always default to the left-edge heuristic
import io import io
import re import re
from collections import Counter from collections import Counter, defaultdict
from functools import reduce from functools import reduce
from typing import Dict, Iterable, List, Optional, Tuple, TypeVar, Union from typing import Dict, Iterable, List, Optional, Tuple, TypeVar, Union
...@@ -886,7 +886,7 @@ class ProcessCollection: ...@@ -886,7 +886,7 @@ class ProcessCollection:
def split_on_ports( def split_on_ports(
self, self,
heuristic: str = "graph_color", heuristic: str = "left_edge",
read_ports: Optional[int] = None, read_ports: Optional[int] = None,
write_ports: Optional[int] = None, write_ports: Optional[int] = None,
total_ports: Optional[int] = None, total_ports: Optional[int] = None,
...@@ -903,7 +903,7 @@ class ProcessCollection: ...@@ -903,7 +903,7 @@ class ProcessCollection:
Valid options are: Valid options are:
* "graph_color" * "graph_color"
* "..." * "left_edge"
read_ports : int, optional read_ports : int, optional
The number of read ports used when splitting process collection based on The number of read ports used when splitting process collection based on
...@@ -926,9 +926,105 @@ class ProcessCollection: ...@@ -926,9 +926,105 @@ class ProcessCollection:
) )
if heuristic == "graph_color": if heuristic == "graph_color":
return self._split_ports_graph_color(read_ports, write_ports, total_ports) return self._split_ports_graph_color(read_ports, write_ports, total_ports)
elif heuristic == "left_edge":
return self.split_ports_sequentially(
read_ports,
write_ports,
total_ports,
sequence=sorted(self),
)
else: else:
raise ValueError("Invalid heuristic provided.") raise ValueError("Invalid heuristic provided.")
def split_ports_sequentially(
self,
read_ports: int,
write_ports: int,
total_ports: int,
sequence: List[Process],
) -> List["ProcessCollection"]:
"""
Split this collection into multiple new collections by sequentially assigning
processes in the order of `sequence`.
This method takes the processes from `sequence`, in order, and assignes them to
to multiple new `ProcessCollection` based on port collisions in a first-come
first-served manner. The first `Process` in `sequence` is assigned first, and
the last `Proccess` in `sequence is assigned last.
Parameters
----------
read_ports : int
The number of read ports used when splitting process collection based on
memory variable access.
write_ports : int
The number of write ports used when splitting process collection based on
memory variable access.
total_ports : int
The total number of ports used when splitting process collection based on
memory variable access.
sequence: list of `Process`
A list of the processes used to determine the order in which processes are
assigned.
Returns
-------
A set of new ProcessCollection objects with the process splitting.
"""
def ports_collide(proc: Process, collection: ProcessCollection):
"""
Predicate test if insertion of a process `proc` results in colliding ports
when inserted to `collection` based on the `read_ports`, `write_ports`, and
`total_ports`.
"""
# Test the number of concurrent write accesses
collection_writes = defaultdict(int, collection.write_port_accesses())
if collection_writes[proc.start_time] >= write_ports:
return True
# Test the number of concurrent read accesses
collection_reads = defaultdict(int, collection.read_port_accesses())
for proc_read_time in proc.read_times:
if collection_reads[proc_read_time % self.schedule_time] >= read_ports:
return True
# Test the number of total accesses
collection_total_accesses = defaultdict(
int, Counter(collection_writes) + Counter(collection_reads)
)
for access_time in [proc.start_time, *proc.read_times]:
if collection_total_accesses[access_time] >= total_ports:
return True
# No collision detected
return False
# Make sure that processes from `sequence` and and `self` are equal
if set(self.collection) != set(sequence):
raise KeyError("processes in `sequence` must be equal to processes in self")
collections: List[ProcessCollection] = []
for process in sequence:
process_added = False
for collection in collections:
if not ports_collide(process, collection):
collection.add_process(process)
process_added = True
break
if not process_added:
# Stuff the process in a new collection
collections.append(
ProcessCollection(
[process],
schedule_time=self.schedule_time,
cyclic=self._cyclic,
)
)
# Return the list of created ProcessCollections
return collections
def _split_ports_graph_color( def _split_ports_graph_color(
self, self,
read_ports: int, read_ports: int,
......
...@@ -157,10 +157,7 @@ def test_architecture(schedule_direct_form_iir_lp_filter: Schedule): ...@@ -157,10 +157,7 @@ def test_architecture(schedule_direct_form_iir_lp_filter: Schedule):
# Graph representation # Graph representation
# Parts are non-deterministic, but this first part seems OK # Parts are non-deterministic, but this first part seems OK
s = ( s = 'digraph {\n\tnode [shape=box]\n\tsplines=spline\n\tsubgraph cluster_memories'
'digraph {\n\tnode [shape=box]\n\tsplines=spline\n\tsubgraph'
' cluster_memories'
)
assert architecture._digraph().source.startswith(s) assert architecture._digraph().source.startswith(s)
s = 'digraph {\n\tnode [shape=box]\n\tsplines=spline\n\tMEM0' s = 'digraph {\n\tnode [shape=box]\n\tsplines=spline\n\tMEM0'
assert architecture._digraph(cluster=False).source.startswith(s) assert architecture._digraph(cluster=False).source.startswith(s)
...@@ -229,9 +226,9 @@ def test_move_process(schedule_direct_form_iir_lp_filter: Schedule): ...@@ -229,9 +226,9 @@ def test_move_process(schedule_direct_form_iir_lp_filter: Schedule):
architecture.move_process('in0.0', memories[1], memories[0]) architecture.move_process('in0.0', memories[1], memories[0])
assert memories[0].collection.from_name('in0.0') assert memories[0].collection.from_name('in0.0')
assert processing_elements[1].collection.from_name('add0')
architecture.move_process('add0', processing_elements[1], processing_elements[0])
assert processing_elements[0].collection.from_name('add0') assert processing_elements[0].collection.from_name('add0')
architecture.move_process('add0', processing_elements[0], processing_elements[1])
assert processing_elements[1].collection.from_name('add0')
# Processes leave the resources they have moved from # Processes leave the resources they have moved from
with pytest.raises(KeyError): with pytest.raises(KeyError):
...@@ -239,7 +236,7 @@ def test_move_process(schedule_direct_form_iir_lp_filter: Schedule): ...@@ -239,7 +236,7 @@ def test_move_process(schedule_direct_form_iir_lp_filter: Schedule):
with pytest.raises(KeyError): with pytest.raises(KeyError):
memories[1].collection.from_name('in0.0') memories[1].collection.from_name('in0.0')
with pytest.raises(KeyError): with pytest.raises(KeyError):
processing_elements[1].collection.from_name('add0') processing_elements[0].collection.from_name('add0')
# Processes can only be moved when the source and destination process-types match # Processes can only be moved when the source and destination process-types match
with pytest.raises(TypeError, match="cmul3.0 not of type"): with pytest.raises(TypeError, match="cmul3.0 not of type"):
......
import re import re
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import pytest
import matplotlib.testing.decorators import matplotlib.testing.decorators
import pytest
from b_asic.core_operations import ConstantMultiplication from b_asic.core_operations import ConstantMultiplication
from b_asic.process import PlainMemoryVariable from b_asic.process import PlainMemoryVariable
...@@ -14,25 +14,57 @@ from b_asic.resources import ProcessCollection, _ForwardBackwardTable ...@@ -14,25 +14,57 @@ from b_asic.resources import ProcessCollection, _ForwardBackwardTable
class TestProcessCollectionPlainMemoryVariable: class TestProcessCollectionPlainMemoryVariable:
@matplotlib.testing.decorators.image_comparison(['test_draw_process_collection.png']) @matplotlib.testing.decorators.image_comparison(
['test_draw_process_collection.png']
)
def test_draw_process_collection(self, simple_collection): def test_draw_process_collection(self, simple_collection):
fig, ax = plt.subplots() fig, ax = plt.subplots()
simple_collection.plot(ax=ax, show_markers=False) simple_collection.plot(ax=ax, show_markers=False)
return fig return fig
@matplotlib.testing.decorators.image_comparison(['test_draw_matrix_transposer_4.png']) @matplotlib.testing.decorators.image_comparison(
['test_draw_matrix_transposer_4.png']
)
def test_draw_matrix_transposer_4(self): def test_draw_matrix_transposer_4(self):
fig, ax = plt.subplots() fig, ax = plt.subplots()
generate_matrix_transposer(4).plot(ax=ax) # type: ignore generate_matrix_transposer(4).plot(ax=ax) # type: ignore
return fig return fig
def test_split_memory_variable(self, simple_collection: ProcessCollection): def test_split_memory_variable_graph_color(
self, simple_collection: ProcessCollection
):
collection_split = simple_collection.split_on_ports( collection_split = simple_collection.split_on_ports(
heuristic="graph_color", read_ports=1, write_ports=1, total_ports=2 heuristic="graph_color", read_ports=1, write_ports=1, total_ports=2
) )
assert len(collection_split) == 3 assert len(collection_split) == 3
@matplotlib.testing.decorators.image_comparison(['test_left_edge_cell_assignment.png']) def test_split_sequence_raises(self, simple_collection: ProcessCollection):
with pytest.raises(KeyError, match="processes in `sequence` must be"):
simple_collection.split_ports_sequentially(
read_ports=1, write_ports=1, total_ports=2, sequence=[]
)
def test_split_memory_variable_left_edge(
self, simple_collection: ProcessCollection
):
split = simple_collection.split_on_ports(
heuristic="left_edge", read_ports=1, write_ports=1, total_ports=2
)
assert len(split) == 3
split = simple_collection.split_on_ports(
heuristic="left_edge", read_ports=1, write_ports=2, total_ports=2
)
assert len(split) == 3
split = simple_collection.split_on_ports(
heuristic="left_edge", read_ports=2, write_ports=2, total_ports=2
)
assert len(split) == 2
@matplotlib.testing.decorators.image_comparison(
['test_left_edge_cell_assignment.png']
)
def test_left_edge_cell_assignment(self, simple_collection: ProcessCollection): def test_left_edge_cell_assignment(self, simple_collection: ProcessCollection):
fig, ax = plt.subplots(1, 2) fig, ax = plt.subplots(1, 2)
assignment = list(simple_collection._left_edge_assignment()) assignment = list(simple_collection._left_edge_assignment())
...@@ -158,7 +190,9 @@ class TestProcessCollectionPlainMemoryVariable: ...@@ -158,7 +190,9 @@ class TestProcessCollectionPlainMemoryVariable:
assert len(simple_collection) == 7 assert len(simple_collection) == 7
assert new_proc not in simple_collection assert new_proc not in simple_collection
@matplotlib.testing.decorators.image_comparison(['test_max_min_lifetime_bar_plot.png']) @matplotlib.testing.decorators.image_comparison(
['test_max_min_lifetime_bar_plot.png']
)
def test_max_min_lifetime_bar_plot(self): def test_max_min_lifetime_bar_plot(self):
fig, ax = plt.subplots() fig, ax = plt.subplots()
collection = ProcessCollection( collection = ProcessCollection(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment