Skip to content
Snippets Groups Projects
Commit cf6e5808 authored by Oscar Gustafsson's avatar Oscar Gustafsson :bicyclist:
Browse files

Fix resource-related issues

parent 0cd0723a
No related branches found
No related tags found
1 merge request!192Fix resource-related issues
Pipeline #89737 passed
...@@ -20,11 +20,20 @@ class Process: ...@@ -20,11 +20,20 @@ class Process:
Start time of process. Start time of process.
execution_time : int execution_time : int
Execution time (lifetime) of process. Execution time (lifetime) of process.
name : str, optional
The name of the process. If not provided, generate a name.
""" """
def __init__(self, start_time: int, execution_time: int): def __init__(
self, start_time: int, execution_time: int, name: Optional[str] = None
):
self._start_time = start_time self._start_time = start_time
self._execution_time = execution_time self._execution_time = execution_time
if name is None:
self._name = f"Proc. {PlainMemoryVariable._name_cnt}"
PlainMemoryVariable._name_cnt += 1
else:
self._name = name
def __lt__(self, other): def __lt__(self, other):
return self._start_time < other.start_time or ( return self._start_time < other.start_time or (
...@@ -42,6 +51,16 @@ class Process: ...@@ -42,6 +51,16 @@ class Process:
"""Return the execution time.""" """Return the execution time."""
return self._execution_time return self._execution_time
@property
def name(self) -> str:
return self._name
def __str__(self) -> str:
return self._name
# Static counter for default names
_name_cnt = 0
class OperatorProcess(Process): class OperatorProcess(Process):
""" """
...@@ -53,16 +72,27 @@ class OperatorProcess(Process): ...@@ -53,16 +72,27 @@ class OperatorProcess(Process):
Start time of process. Start time of process.
operation : Operation operation : Operation
Operation that the process corresponds to. Operation that the process corresponds to.
name : str, optional
The name of the process.
""" """
def __init__(self, start_time: int, operation: Operation): def __init__(
self,
start_time: int,
operation: Operation,
name: Optional[str] = None,
):
execution_time = operation.execution_time execution_time = operation.execution_time
if execution_time is None: if execution_time is None:
raise ValueError( raise ValueError(
"Operation {operation!r} does not have an execution time" "Operation {operation!r} does not have an execution time"
" specified!" " specified!"
) )
super().__init__(start_time, execution_time) super().__init__(
start_time,
execution_time,
name=name,
)
self._operation = operation self._operation = operation
...@@ -80,6 +110,8 @@ class MemoryVariable(Process): ...@@ -80,6 +110,8 @@ class MemoryVariable(Process):
reads : {InputPort: int, ...} reads : {InputPort: int, ...}
Dictionary with the InputPorts that reads the memory variable and Dictionary with the InputPorts that reads the memory variable and
for how long after the *write_time* they will read. for how long after the *write_time* they will read.
name : str, optional
The name of the process.
""" """
def __init__( def __init__(
...@@ -87,12 +119,15 @@ class MemoryVariable(Process): ...@@ -87,12 +119,15 @@ class MemoryVariable(Process):
write_time: int, write_time: int,
write_port: OutputPort, write_port: OutputPort,
reads: Dict[InputPort, int], reads: Dict[InputPort, int],
name: Optional[str] = None,
): ):
self._read_ports = tuple(reads.keys()) self._read_ports = tuple(reads.keys())
self._life_times = tuple(reads.values()) self._life_times = tuple(reads.values())
self._write_port = write_port self._write_port = write_port
super().__init__( super().__init__(
start_time=write_time, execution_time=max(self._life_times) start_time=write_time,
execution_time=max(self._life_times),
name=name,
) )
@property @property
...@@ -123,6 +158,8 @@ class PlainMemoryVariable(Process): ...@@ -123,6 +158,8 @@ class PlainMemoryVariable(Process):
reads : {int: int, ...} reads : {int: int, ...}
Dictionary where the key is the destination identifier and the value Dictionary where the key is the destination identifier and the value
is the time after *write_time* that the memory variable is read. is the time after *write_time* that the memory variable is read.
name : str, optional
The name of the process.
""" """
def __init__( def __init__(
...@@ -135,11 +172,10 @@ class PlainMemoryVariable(Process): ...@@ -135,11 +172,10 @@ class PlainMemoryVariable(Process):
self._read_ports = tuple(reads.keys()) self._read_ports = tuple(reads.keys())
self._life_times = tuple(reads.values()) self._life_times = tuple(reads.values())
self._write_port = write_port self._write_port = write_port
if name is None:
self._name = str(PlainMemoryVariable._name_cnt)
PlainMemoryVariable._name_cnt += 1
super().__init__( super().__init__(
start_time=write_time, execution_time=max(self._life_times) start_time=write_time,
execution_time=max(self._life_times),
name=name,
) )
@property @property
...@@ -153,13 +189,3 @@ class PlainMemoryVariable(Process): ...@@ -153,13 +189,3 @@ class PlainMemoryVariable(Process):
@property @property
def write_port(self) -> int: def write_port(self) -> int:
return self._write_port return self._write_port
@property
def name(self) -> str:
return self._name
def __str__(self) -> str:
return self._name
# Static counter for default names
_name_cnt = 0
...@@ -6,6 +6,7 @@ import random ...@@ -6,6 +6,7 @@ import random
from typing import Optional, Set from typing import Optional, Set
from b_asic.process import PlainMemoryVariable from b_asic.process import PlainMemoryVariable
from b_asic.resources import ProcessCollection
def _insert_delays(inputorder, outputorder, min_lifetime, cyclic): def _insert_delays(inputorder, outputorder, min_lifetime, cyclic):
...@@ -14,9 +15,7 @@ def _insert_delays(inputorder, outputorder, min_lifetime, cyclic): ...@@ -14,9 +15,7 @@ def _insert_delays(inputorder, outputorder, min_lifetime, cyclic):
outputorder = [o - maxdiff + min_lifetime for o in outputorder] outputorder = [o - maxdiff + min_lifetime for o in outputorder]
maxdelay = max(outputorder[i] - inputorder[i] for i in range(size)) maxdelay = max(outputorder[i] - inputorder[i] for i in range(size))
if cyclic: if cyclic:
if maxdelay < size: if maxdelay >= size:
outputorder = [o % size for o in outputorder]
else:
inputorder = inputorder + [i + size for i in inputorder] inputorder = inputorder + [i + size for i in inputorder]
outputorder = outputorder + [o + size for o in outputorder] outputorder = outputorder + [o + size for o in outputorder]
return inputorder, outputorder return inputorder, outputorder
...@@ -24,7 +23,7 @@ def _insert_delays(inputorder, outputorder, min_lifetime, cyclic): ...@@ -24,7 +23,7 @@ def _insert_delays(inputorder, outputorder, min_lifetime, cyclic):
def generate_random_interleaver( def generate_random_interleaver(
size: int, min_lifetime: int = 0, cyclic: bool = True size: int, min_lifetime: int = 0, cyclic: bool = True
) -> Set[PlainMemoryVariable]: ) -> ProcessCollection:
""" """
Generate a ProcessCollection with memory variable corresponding to a random Generate a ProcessCollection with memory variable corresponding to a random
interleaver with length *size*. interleaver with length *size*.
...@@ -48,15 +47,19 @@ def generate_random_interleaver( ...@@ -48,15 +47,19 @@ def generate_random_interleaver(
inputorder = list(range(size)) inputorder = list(range(size))
outputorder = inputorder[:] outputorder = inputorder[:]
random.shuffle(outputorder) random.shuffle(outputorder)
print(inputorder, outputorder)
inputorder, outputorder = _insert_delays( inputorder, outputorder = _insert_delays(
inputorder, outputorder, min_lifetime, cyclic inputorder, outputorder, min_lifetime, cyclic
) )
print(inputorder, outputorder) return ProcessCollection(
return { {
PlainMemoryVariable(inputorder[i], 0, {0: outputorder[i]}) PlainMemoryVariable(
for i in range(size) inputorder[i], 0, {0: outputorder[i] - inputorder[i]}
} )
for i in range(len(inputorder))
},
len(inputorder),
cyclic,
)
def generate_matrix_transposer( def generate_matrix_transposer(
...@@ -64,7 +67,7 @@ def generate_matrix_transposer( ...@@ -64,7 +67,7 @@ def generate_matrix_transposer(
width: Optional[int] = None, width: Optional[int] = None,
min_lifetime: int = 0, min_lifetime: int = 0,
cyclic: bool = True, cyclic: bool = True,
) -> Set[PlainMemoryVariable]: ) -> ProcessCollection:
r""" r"""
Generate a ProcessCollection with memory variable corresponding to transposing a Generate a ProcessCollection with memory variable corresponding to transposing a
matrix of size *height* :math:`\times` *width*. If *width* is not provided, a matrix of size *height* :math:`\times` *width*. If *width* is not provided, a
...@@ -101,12 +104,19 @@ def generate_matrix_transposer( ...@@ -101,12 +104,19 @@ def generate_matrix_transposer(
for col in range(height): for col in range(height):
outputorder.append(col * width + row) outputorder.append(col * width + row)
print(inputorder, outputorder)
inputorder, outputorder = _insert_delays( inputorder, outputorder = _insert_delays(
inputorder, outputorder, min_lifetime, cyclic inputorder, outputorder, min_lifetime, cyclic
) )
print(inputorder, outputorder) return ProcessCollection(
return { {
PlainMemoryVariable(inputorder[i], 0, {0: outputorder[i]}) PlainMemoryVariable(
for i in range(width * height) inputorder[i],
} 0,
{0: outputorder[i] - inputorder[i]},
name=f"{inputorder[i]}",
)
for i in range(len(inputorder))
},
len(inputorder),
cyclic,
)
import re
from typing import Dict, List, Optional, Set, Tuple, Union from typing import Dict, List, Optional, Set, Tuple, Union
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
...@@ -8,6 +9,16 @@ from matplotlib.ticker import MaxNLocator ...@@ -8,6 +9,16 @@ from matplotlib.ticker import MaxNLocator
from b_asic.process import Process from b_asic.process import Process
# From https://stackoverflow.com/questions/2669059/how-to-sort-alpha-numeric-set-in-python
def _sorted_nicely(to_be_sorted):
"""Sort the given iterable in the way that humans expect."""
convert = lambda text: int(text) if text.isdigit() else text
alphanum_key = lambda key: [
convert(c) for c in re.split('([0-9]+)', str(key))
]
return sorted(to_be_sorted, key=alphanum_key)
def draw_exclusion_graph_coloring( def draw_exclusion_graph_coloring(
exclusion_graph: nx.Graph, exclusion_graph: nx.Graph,
color_dict: Dict[Process, int], color_dict: Dict[Process, int],
...@@ -79,14 +90,23 @@ class ProcessCollection: ...@@ -79,14 +90,23 @@ class ProcessCollection:
Parameters Parameters
---------- ----------
collection : set of :class:`~b_asic.process.Process` objects, optional collection : set of :class:`~b_asic.process.Process` objects
The Process objects forming this ProcessCollection.
schedule_time : int, default: 0
Length of the time-axis in the generated graph.
cyclic : bool, default: False
If the processes operates cyclically, i.e., if time 0 == time *schedule_time*.
""" """
def __init__(self, collection: Optional[Set[Process]] = None): def __init__(
if collection is None: self,
self._collection: Set[Process] = set() collection: Set[Process],
else: schedule_time: int,
self._collection = collection cyclic: bool = False,
):
self._collection = collection
self._schedule_time = schedule_time
self._cyclic = cyclic
def add_process(self, process: Process): def add_process(self, process: Process):
""" """
...@@ -101,7 +121,6 @@ class ProcessCollection: ...@@ -101,7 +121,6 @@ class ProcessCollection:
def draw_lifetime_chart( def draw_lifetime_chart(
self, self,
schedule_time: int = 0,
ax: Optional[Axes] = None, ax: Optional[Axes] = None,
show_name: bool = True, show_name: bool = True,
): ):
...@@ -110,9 +129,6 @@ class ProcessCollection: ...@@ -110,9 +129,6 @@ class ProcessCollection:
Parameters Parameters
---------- ----------
schedule_time : int, default: 0
Length of the time-axis in the generated graph. The time axis will span [0, schedule_time-1].
If set to zero (which is the default), the ...
ax : :class:`matplotlib.axes.Axes`, optional ax : :class:`matplotlib.axes.Axes`, optional
Matplotlib Axes object to draw this lifetime chart onto. If not provided (i.e., set to None), this will Matplotlib Axes object to draw this lifetime chart onto. If not provided (i.e., set to None), this will
return a new axes object on return. return a new axes object on return.
...@@ -133,26 +149,19 @@ class ProcessCollection: ...@@ -133,26 +149,19 @@ class ProcessCollection:
# Draw the lifetime chart # Draw the lifetime chart
PAD_L, PAD_R = 0.05, 0.05 PAD_L, PAD_R = 0.05, 0.05
max_execution_time = max( max_execution_time = max(
[process.execution_time for process in self._collection] process.execution_time for process in self._collection
) )
schedule_time = ( if max_execution_time > self._schedule_time:
schedule_time
if schedule_time
else max(p.start_time + p.execution_time for p in self._collection)
)
if max_execution_time > schedule_time:
# Schedule time needs to be greater than or equal to the maximum process life time # Schedule time needs to be greater than or equal to the maximum process life time
raise KeyError( raise KeyError(
f'Error: Schedule time: {schedule_time} < Max execution time:' f'Error: Schedule time: {self._schedule_time} < Max execution'
f' {max_execution_time}' f' time: {max_execution_time}'
) )
for i, process in enumerate( for i, process in enumerate(_sorted_nicely(self._collection)):
sorted(self._collection, key=lambda p: str(p)) bar_start = process.start_time % self._schedule_time
):
bar_start = process.start_time % schedule_time
bar_end = ( bar_end = (
process.start_time + process.execution_time process.start_time + process.execution_time
) % schedule_time ) % self._schedule_time
if bar_end > bar_start: if bar_end > bar_start:
_ax.broken_barh( _ax.broken_barh(
[(PAD_L + bar_start, bar_end - bar_start - PAD_L - PAD_R)], [(PAD_L + bar_start, bar_end - bar_start - PAD_L - PAD_R)],
...@@ -164,7 +173,7 @@ class ProcessCollection: ...@@ -164,7 +173,7 @@ class ProcessCollection:
[ [
( (
PAD_L + bar_start, PAD_L + bar_start,
schedule_time - bar_start - PAD_L, self._schedule_time - bar_start - PAD_L,
) )
], ],
(i + 0.55, 0.9), (i + 0.55, 0.9),
...@@ -175,7 +184,10 @@ class ProcessCollection: ...@@ -175,7 +184,10 @@ class ProcessCollection:
[ [
( (
PAD_L + bar_start, PAD_L + bar_start,
schedule_time - bar_start - PAD_L - PAD_R, self._schedule_time
- bar_start
- PAD_L
- PAD_R,
) )
], ],
(i + 0.55, 0.9), (i + 0.55, 0.9),
...@@ -190,6 +202,8 @@ class ProcessCollection: ...@@ -190,6 +202,8 @@ class ProcessCollection:
_ax.xaxis.set_major_locator(MaxNLocator(integer=True)) _ax.xaxis.set_major_locator(MaxNLocator(integer=True))
_ax.yaxis.set_major_locator(MaxNLocator(integer=True)) _ax.yaxis.set_major_locator(MaxNLocator(integer=True))
_ax.set_xlim(0, self._schedule_time)
_ax.set_ylim(0.25, len(self._collection) + 0.75)
return _ax return _ax
def create_exclusion_graph_from_overlap( def create_exclusion_graph_from_overlap(
...@@ -332,12 +346,14 @@ class ProcessCollection: ...@@ -332,12 +346,14 @@ class ProcessCollection:
coloring = nx.coloring.greedy_color(exclusion_graph) coloring = nx.coloring.greedy_color(exclusion_graph)
draw_exclusion_graph_coloring(exclusion_graph, coloring) draw_exclusion_graph_coloring(exclusion_graph, coloring)
# process_collection_list = [ProcessCollection()]*(max(coloring.values()) + 1) # process_collection_list = [ProcessCollection()]*(max(coloring.values()) + 1)
process_collection_list = [ process_collection_set_list = [
ProcessCollection() for _ in range(max(coloring.values()) + 1) set() for _ in range(max(coloring.values()) + 1)
] ]
for process, color in coloring.items(): for process, color in coloring.items():
process_collection_list[color].add_process(process) process_collection_set_list[color].add(process)
return { return {
process_collection ProcessCollection(
for process_collection in process_collection_list process_collection_set, self._schedule_time, self._cyclic
)
for process_collection_set in process_collection_set_list
} }
test/baseline/test_draw_matrix_transposer_4.png

20.9 KiB

test/baseline/test_draw_process_collection.png

10.1 KiB | W: | H:

test/baseline/test_draw_process_collection.png

14.7 KiB | W: | H:

test/baseline/test_draw_process_collection.png
test/baseline/test_draw_process_collection.png
test/baseline/test_draw_process_collection.png
test/baseline/test_draw_process_collection.png
  • 2-up
  • Swipe
  • Onion skin
...@@ -16,12 +16,13 @@ def simple_collection(): ...@@ -16,12 +16,13 @@ def simple_collection():
PlainMemoryVariable(0, NO_PORT, {NO_PORT: 3}), PlainMemoryVariable(0, NO_PORT, {NO_PORT: 3}),
PlainMemoryVariable(0, NO_PORT, {NO_PORT: 2}), PlainMemoryVariable(0, NO_PORT, {NO_PORT: 2}),
PlainMemoryVariable(0, NO_PORT, {NO_PORT: 6}), PlainMemoryVariable(0, NO_PORT, {NO_PORT: 6}),
} },
8,
) )
@pytest.fixture() @pytest.fixture()
def collection(): def cyclic_simple_collection():
NO_PORT = 0 NO_PORT = 0
return ProcessCollection( return ProcessCollection(
{ {
...@@ -32,5 +33,7 @@ def collection(): ...@@ -32,5 +33,7 @@ def collection():
PlainMemoryVariable(0, NO_PORT, {NO_PORT: 3}), PlainMemoryVariable(0, NO_PORT, {NO_PORT: 3}),
PlainMemoryVariable(0, NO_PORT, {NO_PORT: 2}), PlainMemoryVariable(0, NO_PORT, {NO_PORT: 2}),
PlainMemoryVariable(0, NO_PORT, {NO_PORT: 6}), PlainMemoryVariable(0, NO_PORT, {NO_PORT: 6}),
} },
6,
True,
) )
...@@ -2,8 +2,11 @@ import matplotlib.pyplot as plt ...@@ -2,8 +2,11 @@ import matplotlib.pyplot as plt
import networkx as nx import networkx as nx
import pytest import pytest
from b_asic.process import PlainMemoryVariable from b_asic.research.interleaver import (
from b_asic.resources import ProcessCollection, draw_exclusion_graph_coloring generate_matrix_transposer,
generate_random_interleaver,
)
from b_asic.resources import draw_exclusion_graph_coloring
class TestProcessCollectionPlainMemoryVariable: class TestProcessCollectionPlainMemoryVariable:
...@@ -15,7 +18,7 @@ class TestProcessCollectionPlainMemoryVariable: ...@@ -15,7 +18,7 @@ class TestProcessCollectionPlainMemoryVariable:
def test_draw_proces_collection(self, simple_collection): def test_draw_proces_collection(self, simple_collection):
_, ax = plt.subplots(1, 2) _, ax = plt.subplots(1, 2)
simple_collection.draw_lifetime_chart(schedule_time=8, ax=ax[0]) simple_collection.draw_lifetime_chart(ax=ax[0])
exclusion_graph = ( exclusion_graph = (
simple_collection.create_exclusion_graph_from_overlap() simple_collection.create_exclusion_graph_from_overlap()
) )
...@@ -27,3 +30,26 @@ class TestProcessCollectionPlainMemoryVariable: ...@@ -27,3 +30,26 @@ class TestProcessCollectionPlainMemoryVariable:
read_ports=1, write_ports=1, total_ports=2 read_ports=1, write_ports=1, total_ports=2
) )
assert len(collection_split) == 3 assert len(collection_split) == 3
@pytest.mark.mpl_image_compare(style='mpl20')
def test_draw_matrix_transposer_4(self):
fig, ax = plt.subplots()
generate_matrix_transposer(4).draw_lifetime_chart(ax=ax)
return fig
def test_generate_random_interleaver(self):
return
for _ in range(10):
for size in range(5, 20, 5):
assert (
len(
generate_random_interleaver(size).split(
read_ports=1, write_ports=1
)
)
== 1
)
assert (
len(generate_random_interleaver(size).split(total_ports=1))
== 2
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment