diff --git a/b_asic/process.py b/b_asic/process.py index d48fa0ee5ff24b2ab3667918a30e71a055776706..8d7e7620af5c95fe8ad719c20e2b1d782725db0f 100644 --- a/b_asic/process.py +++ b/b_asic/process.py @@ -20,11 +20,20 @@ class Process: Start time of process. execution_time : int Execution time (lifetime) of process. + name : str, optional + The name of the process. If not provided, generate a name. """ - def __init__(self, start_time: int, execution_time: int): + def __init__( + self, start_time: int, execution_time: int, name: Optional[str] = None + ): self._start_time = start_time self._execution_time = execution_time + if name is None: + self._name = f"Proc. {PlainMemoryVariable._name_cnt}" + PlainMemoryVariable._name_cnt += 1 + else: + self._name = name def __lt__(self, other): return self._start_time < other.start_time or ( @@ -42,6 +51,16 @@ class Process: """Return the execution time.""" return self._execution_time + @property + def name(self) -> str: + return self._name + + def __str__(self) -> str: + return self._name + + # Static counter for default names + _name_cnt = 0 + class OperatorProcess(Process): """ @@ -53,16 +72,27 @@ class OperatorProcess(Process): Start time of process. operation : Operation Operation that the process corresponds to. + name : str, optional + The name of the process. """ - def __init__(self, start_time: int, operation: Operation): + def __init__( + self, + start_time: int, + operation: Operation, + name: Optional[str] = None, + ): execution_time = operation.execution_time if execution_time is None: raise ValueError( "Operation {operation!r} does not have an execution time" " specified!" ) - super().__init__(start_time, execution_time) + super().__init__( + start_time, + execution_time, + name=name, + ) self._operation = operation @@ -80,6 +110,8 @@ class MemoryVariable(Process): reads : {InputPort: int, ...} Dictionary with the InputPorts that reads the memory variable and for how long after the *write_time* they will read. + name : str, optional + The name of the process. """ def __init__( @@ -87,12 +119,15 @@ class MemoryVariable(Process): write_time: int, write_port: OutputPort, reads: Dict[InputPort, int], + name: Optional[str] = None, ): self._read_ports = tuple(reads.keys()) self._life_times = tuple(reads.values()) self._write_port = write_port super().__init__( - start_time=write_time, execution_time=max(self._life_times) + start_time=write_time, + execution_time=max(self._life_times), + name=name, ) @property @@ -123,6 +158,8 @@ class PlainMemoryVariable(Process): reads : {int: int, ...} Dictionary where the key is the destination identifier and the value is the time after *write_time* that the memory variable is read. + name : str, optional + The name of the process. """ def __init__( @@ -135,11 +172,10 @@ class PlainMemoryVariable(Process): self._read_ports = tuple(reads.keys()) self._life_times = tuple(reads.values()) self._write_port = write_port - if name is None: - self._name = str(PlainMemoryVariable._name_cnt) - PlainMemoryVariable._name_cnt += 1 super().__init__( - start_time=write_time, execution_time=max(self._life_times) + start_time=write_time, + execution_time=max(self._life_times), + name=name, ) @property @@ -153,13 +189,3 @@ class PlainMemoryVariable(Process): @property def write_port(self) -> int: return self._write_port - - @property - def name(self) -> str: - return self._name - - def __str__(self) -> str: - return self._name - - # Static counter for default names - _name_cnt = 0 diff --git a/b_asic/research/interleaver.py b/b_asic/research/interleaver.py index 83eb4b3bb415c81bb767e0e6dcd1d124d536e589..b64a79b9bfe6a0709c19bc9de4427b263b79986d 100644 --- a/b_asic/research/interleaver.py +++ b/b_asic/research/interleaver.py @@ -6,6 +6,7 @@ import random from typing import Optional, Set from b_asic.process import PlainMemoryVariable +from b_asic.resources import ProcessCollection def _insert_delays(inputorder, outputorder, min_lifetime, cyclic): @@ -14,9 +15,7 @@ def _insert_delays(inputorder, outputorder, min_lifetime, cyclic): outputorder = [o - maxdiff + min_lifetime for o in outputorder] maxdelay = max(outputorder[i] - inputorder[i] for i in range(size)) if cyclic: - if maxdelay < size: - outputorder = [o % size for o in outputorder] - else: + if maxdelay >= size: inputorder = inputorder + [i + size for i in inputorder] outputorder = outputorder + [o + size for o in outputorder] return inputorder, outputorder @@ -24,7 +23,7 @@ def _insert_delays(inputorder, outputorder, min_lifetime, cyclic): def generate_random_interleaver( size: int, min_lifetime: int = 0, cyclic: bool = True -) -> Set[PlainMemoryVariable]: +) -> ProcessCollection: """ Generate a ProcessCollection with memory variable corresponding to a random interleaver with length *size*. @@ -48,15 +47,19 @@ def generate_random_interleaver( inputorder = list(range(size)) outputorder = inputorder[:] random.shuffle(outputorder) - print(inputorder, outputorder) inputorder, outputorder = _insert_delays( inputorder, outputorder, min_lifetime, cyclic ) - print(inputorder, outputorder) - return { - PlainMemoryVariable(inputorder[i], 0, {0: outputorder[i]}) - for i in range(size) - } + return ProcessCollection( + { + PlainMemoryVariable( + inputorder[i], 0, {0: outputorder[i] - inputorder[i]} + ) + for i in range(len(inputorder)) + }, + len(inputorder), + cyclic, + ) def generate_matrix_transposer( @@ -64,7 +67,7 @@ def generate_matrix_transposer( width: Optional[int] = None, min_lifetime: int = 0, cyclic: bool = True, -) -> Set[PlainMemoryVariable]: +) -> ProcessCollection: r""" Generate a ProcessCollection with memory variable corresponding to transposing a matrix of size *height* :math:`\times` *width*. If *width* is not provided, a @@ -101,12 +104,19 @@ def generate_matrix_transposer( for col in range(height): outputorder.append(col * width + row) - print(inputorder, outputorder) inputorder, outputorder = _insert_delays( inputorder, outputorder, min_lifetime, cyclic ) - print(inputorder, outputorder) - return { - PlainMemoryVariable(inputorder[i], 0, {0: outputorder[i]}) - for i in range(width * height) - } + return ProcessCollection( + { + PlainMemoryVariable( + inputorder[i], + 0, + {0: outputorder[i] - inputorder[i]}, + name=f"{inputorder[i]}", + ) + for i in range(len(inputorder)) + }, + len(inputorder), + cyclic, + ) diff --git a/b_asic/resources.py b/b_asic/resources.py index eaed928f9f5c8ad5b70059ae6371af885ada8833..c91047b240e027ddae810f48d8905bd8d65bac42 100644 --- a/b_asic/resources.py +++ b/b_asic/resources.py @@ -1,3 +1,4 @@ +import re from typing import Dict, List, Optional, Set, Tuple, Union import matplotlib.pyplot as plt @@ -8,6 +9,16 @@ from matplotlib.ticker import MaxNLocator from b_asic.process import Process +# From https://stackoverflow.com/questions/2669059/how-to-sort-alpha-numeric-set-in-python +def _sorted_nicely(to_be_sorted): + """Sort the given iterable in the way that humans expect.""" + convert = lambda text: int(text) if text.isdigit() else text + alphanum_key = lambda key: [ + convert(c) for c in re.split('([0-9]+)', str(key)) + ] + return sorted(to_be_sorted, key=alphanum_key) + + def draw_exclusion_graph_coloring( exclusion_graph: nx.Graph, color_dict: Dict[Process, int], @@ -79,14 +90,23 @@ class ProcessCollection: Parameters ---------- - collection : set of :class:`~b_asic.process.Process` objects, optional + collection : set of :class:`~b_asic.process.Process` objects + The Process objects forming this ProcessCollection. + schedule_time : int, default: 0 + Length of the time-axis in the generated graph. + cyclic : bool, default: False + If the processes operates cyclically, i.e., if time 0 == time *schedule_time*. """ - def __init__(self, collection: Optional[Set[Process]] = None): - if collection is None: - self._collection: Set[Process] = set() - else: - self._collection = collection + def __init__( + self, + collection: Set[Process], + schedule_time: int, + cyclic: bool = False, + ): + self._collection = collection + self._schedule_time = schedule_time + self._cyclic = cyclic def add_process(self, process: Process): """ @@ -101,7 +121,6 @@ class ProcessCollection: def draw_lifetime_chart( self, - schedule_time: int = 0, ax: Optional[Axes] = None, show_name: bool = True, ): @@ -110,9 +129,6 @@ class ProcessCollection: Parameters ---------- - schedule_time : int, default: 0 - Length of the time-axis in the generated graph. The time axis will span [0, schedule_time-1]. - If set to zero (which is the default), the ... ax : :class:`matplotlib.axes.Axes`, optional Matplotlib Axes object to draw this lifetime chart onto. If not provided (i.e., set to None), this will return a new axes object on return. @@ -133,26 +149,19 @@ class ProcessCollection: # Draw the lifetime chart PAD_L, PAD_R = 0.05, 0.05 max_execution_time = max( - [process.execution_time for process in self._collection] + process.execution_time for process in self._collection ) - schedule_time = ( - schedule_time - if schedule_time - else max(p.start_time + p.execution_time for p in self._collection) - ) - if max_execution_time > schedule_time: + if max_execution_time > self._schedule_time: # Schedule time needs to be greater than or equal to the maximum process life time raise KeyError( - f'Error: Schedule time: {schedule_time} < Max execution time:' - f' {max_execution_time}' + f'Error: Schedule time: {self._schedule_time} < Max execution' + f' time: {max_execution_time}' ) - for i, process in enumerate( - sorted(self._collection, key=lambda p: str(p)) - ): - bar_start = process.start_time % schedule_time + for i, process in enumerate(_sorted_nicely(self._collection)): + bar_start = process.start_time % self._schedule_time bar_end = ( process.start_time + process.execution_time - ) % schedule_time + ) % self._schedule_time if bar_end > bar_start: _ax.broken_barh( [(PAD_L + bar_start, bar_end - bar_start - PAD_L - PAD_R)], @@ -164,7 +173,7 @@ class ProcessCollection: [ ( PAD_L + bar_start, - schedule_time - bar_start - PAD_L, + self._schedule_time - bar_start - PAD_L, ) ], (i + 0.55, 0.9), @@ -175,7 +184,10 @@ class ProcessCollection: [ ( PAD_L + bar_start, - schedule_time - bar_start - PAD_L - PAD_R, + self._schedule_time + - bar_start + - PAD_L + - PAD_R, ) ], (i + 0.55, 0.9), @@ -190,6 +202,8 @@ class ProcessCollection: _ax.xaxis.set_major_locator(MaxNLocator(integer=True)) _ax.yaxis.set_major_locator(MaxNLocator(integer=True)) + _ax.set_xlim(0, self._schedule_time) + _ax.set_ylim(0.25, len(self._collection) + 0.75) return _ax def create_exclusion_graph_from_overlap( @@ -332,12 +346,14 @@ class ProcessCollection: coloring = nx.coloring.greedy_color(exclusion_graph) draw_exclusion_graph_coloring(exclusion_graph, coloring) # process_collection_list = [ProcessCollection()]*(max(coloring.values()) + 1) - process_collection_list = [ - ProcessCollection() for _ in range(max(coloring.values()) + 1) + process_collection_set_list = [ + set() for _ in range(max(coloring.values()) + 1) ] for process, color in coloring.items(): - process_collection_list[color].add_process(process) + process_collection_set_list[color].add(process) return { - process_collection - for process_collection in process_collection_list + ProcessCollection( + process_collection_set, self._schedule_time, self._cyclic + ) + for process_collection_set in process_collection_set_list } diff --git a/test/baseline/test_draw_matrix_transposer_4.png b/test/baseline/test_draw_matrix_transposer_4.png new file mode 100644 index 0000000000000000000000000000000000000000..962fc776b73b1096b8d6080fce8d88a380f52c60 Binary files /dev/null and b/test/baseline/test_draw_matrix_transposer_4.png differ diff --git a/test/baseline/test_draw_process_collection.png b/test/baseline/test_draw_process_collection.png index 0ab1996784015b6d0bc3e5bc68e0f3ae21c64d94..87a5bf7fd1ce3055f5cc993ab684c560ef06e4b5 100644 Binary files a/test/baseline/test_draw_process_collection.png and b/test/baseline/test_draw_process_collection.png differ diff --git a/test/fixtures/resources.py b/test/fixtures/resources.py index 6317e4cb7d873ddf0dc258708cc4d6e90034a205..61c8db254f09014ba55a482a8df28481bdc0d46b 100644 --- a/test/fixtures/resources.py +++ b/test/fixtures/resources.py @@ -16,12 +16,13 @@ def simple_collection(): PlainMemoryVariable(0, NO_PORT, {NO_PORT: 3}), PlainMemoryVariable(0, NO_PORT, {NO_PORT: 2}), PlainMemoryVariable(0, NO_PORT, {NO_PORT: 6}), - } + }, + 8, ) @pytest.fixture() -def collection(): +def cyclic_simple_collection(): NO_PORT = 0 return ProcessCollection( { @@ -32,5 +33,7 @@ def collection(): PlainMemoryVariable(0, NO_PORT, {NO_PORT: 3}), PlainMemoryVariable(0, NO_PORT, {NO_PORT: 2}), PlainMemoryVariable(0, NO_PORT, {NO_PORT: 6}), - } + }, + 6, + True, ) diff --git a/test/test_resources.py b/test/test_resources.py index 67f20cc92fe43276229af1d8a5e93e5409d712f1..10e401adf520a9b837c864bb20ab2d2742327237 100644 --- a/test/test_resources.py +++ b/test/test_resources.py @@ -2,8 +2,11 @@ import matplotlib.pyplot as plt import networkx as nx import pytest -from b_asic.process import PlainMemoryVariable -from b_asic.resources import ProcessCollection, draw_exclusion_graph_coloring +from b_asic.research.interleaver import ( + generate_matrix_transposer, + generate_random_interleaver, +) +from b_asic.resources import draw_exclusion_graph_coloring class TestProcessCollectionPlainMemoryVariable: @@ -15,7 +18,7 @@ class TestProcessCollectionPlainMemoryVariable: def test_draw_proces_collection(self, simple_collection): _, ax = plt.subplots(1, 2) - simple_collection.draw_lifetime_chart(schedule_time=8, ax=ax[0]) + simple_collection.draw_lifetime_chart(ax=ax[0]) exclusion_graph = ( simple_collection.create_exclusion_graph_from_overlap() ) @@ -27,3 +30,26 @@ class TestProcessCollectionPlainMemoryVariable: read_ports=1, write_ports=1, total_ports=2 ) assert len(collection_split) == 3 + + @pytest.mark.mpl_image_compare(style='mpl20') + def test_draw_matrix_transposer_4(self): + fig, ax = plt.subplots() + generate_matrix_transposer(4).draw_lifetime_chart(ax=ax) + return fig + + def test_generate_random_interleaver(self): + return + for _ in range(10): + for size in range(5, 20, 5): + assert ( + len( + generate_random_interleaver(size).split( + read_ports=1, write_ports=1 + ) + ) + == 1 + ) + assert ( + len(generate_random_interleaver(size).split(total_ports=1)) + == 2 + )