diff --git a/.gitignore b/.gitignore index d251d2bfb551e0abf8cea3fe0911d1aeece9e2c0..d6034647a557c21f67a8a128c313d80164c2f0bb 100644 --- a/.gitignore +++ b/.gitignore @@ -115,3 +115,4 @@ TODO.txt *.log b_asic/_version.py docs_sphinx/_build/ +docs_sphinx/examples diff --git a/b_asic/process.py b/b_asic/process.py index 131ad5995b40d5bb9600b7a7795cbacefd8c082f..d48fa0ee5ff24b2ab3667918a30e71a055776706 100644 --- a/b_asic/process.py +++ b/b_asic/process.py @@ -2,7 +2,7 @@ B-ASIC classes representing resource usage. """ -from typing import Dict, Tuple +from typing import Dict, Optional, Tuple from b_asic.operation import Operation from b_asic.port import InputPort, OutputPort @@ -130,10 +130,14 @@ class PlainMemoryVariable(Process): write_time: int, write_port: int, reads: Dict[int, int], + name: Optional[str] = None, ): self._read_ports = tuple(reads.keys()) self._life_times = tuple(reads.values()) self._write_port = write_port + if name is None: + self._name = str(PlainMemoryVariable._name_cnt) + PlainMemoryVariable._name_cnt += 1 super().__init__( start_time=write_time, execution_time=max(self._life_times) ) @@ -149,3 +153,13 @@ class PlainMemoryVariable(Process): @property def write_port(self) -> int: return self._write_port + + @property + def name(self) -> str: + return self._name + + def __str__(self) -> str: + return self._name + + # Static counter for default names + _name_cnt = 0 diff --git a/b_asic/resources.py b/b_asic/resources.py new file mode 100644 index 0000000000000000000000000000000000000000..2c55e0af2354cd73b57844c3916fc180ffef2c92 --- /dev/null +++ b/b_asic/resources.py @@ -0,0 +1,339 @@ +from typing import Dict, List, Optional, Set, Tuple, Union + +import matplotlib.pyplot as plt +import networkx as nx +from matplotlib.axes import Axes +from matplotlib.ticker import MaxNLocator + +from b_asic.process import Process + + +def draw_exclusion_graph_coloring( + exclusion_graph: nx.Graph, + color_dict: Dict[Process, int], + ax: Optional[Axes] = None, + color_list: Optional[ + Union[List[str], List[Tuple[float, float, float]]] + ] = None, +): + """ + Use matplotlib.pyplot and networkx to draw a colored exclusion graph from the memory assigment + + .. code-block:: python + + _, ax = plt.subplots(1, 1) + collection = ProcessCollection(...) + exclusion_graph = collection.create_exclusion_graph_from_overlap() + color_dict = nx.greedy_color(exclusion_graph) + draw_exclusion_graph_coloring(exclusion_graph, color_dict, ax=ax[0]) + plt.show() + + Parameters + ---------- + exclusion_graph : nx.Graph + A nx.Graph exclusion graph object that is to be drawn. + + color_dict : dictionary + A color dictionary where keys are Process objects and where values are integers representing colors. These + dictionaries are automatically generated by :func:`networkx.algorithms.coloring.greedy_color`. + + ax : :class:`matplotlib.axes.Axes`, optional + A Matplotlib Axes object to draw the exclusion graph + + color_list : Optional[Union[List[str], List[Tuple[float,float,float]]]] + """ + COLOR_LIST = [ + '#aa0000', + '#00aa00', + '#0000ff', + '#ff00aa', + '#ffaa00', + '#00ffaa', + '#aaff00', + '#aa00ff', + '#00aaff', + '#ff0000', + '#00ff00', + '#0000aa', + '#aaaa00', + '#aa00aa', + '#00aaaa', + ] + node_color_dict = {} + if color_list is None: + node_color_dict = {k: COLOR_LIST[v] for k, v in color_dict.items()} + else: + node_color_dict = {k: color_list[v] for k, v in color_dict.items()} + node_color_list = [node_color_dict[node] for node in exclusion_graph] + nx.draw_networkx(exclusion_graph, node_color=node_color_list, ax=ax) + + +class ProcessCollection: + """ + Collection of one or more processes + + Parameters + ---------- + collection : set of :class:`~b_asic.process.Process` objects, optional + """ + + def __init__(self, collection: Optional[Set[Process]] = None): + if collection is None: + self._collection = set[Process]() + else: + self._collection = collection + + def add_process(self, process: Process): + """ + Add a new process to this process collection. + + Parameters + ---------- + process : Process + The process object to be added to the collection + """ + self._collection.add(process) + + def draw_lifetime_chart( + self, + schedule_time: int = 0, + ax: Optional[Axes] = None, + show_name: bool = True, + ): + """ + Use matplotlib.pyplot to generate a process variable lifetime chart from this process collection. + + Parameters + ---------- + schedule_time : int, default: 0 + Length of the time-axis in the generated graph. The time axis will span [0, schedule_time-1]. + If set to zero (which is the default), the ... + ax : :class:`matplotlib.axes.Axes`, optional + Matplotlib Axes object to draw this lifetime chart onto. If not provided (i.e., set to None), this will + return a new axes object on return. + show_name : bool, default: True + Show name of all processes in the lifetime chart. + + Returns + ------- + ax: Associated Matplotlib Axes (or array of Axes) object + """ + + # Setup the Axes object + if ax is None: + _, _ax = plt.subplots() + else: + _ax = ax + + # Draw the lifetime chart + PAD_L, PAD_R = 0.05, 0.05 + max_execution_time = max( + [process.execution_time for process in self._collection] + ) + schedule_time = ( + schedule_time + if schedule_time + else max(p.start_time + p.execution_time for p in self._collection) + ) + if max_execution_time > schedule_time: + # Schedule time needs to be greater than or equal to the maximum process life time + raise KeyError( + f'Error: Schedule time: {schedule_time} < Max execution time:' + f' {max_execution_time}' + ) + for i, process in enumerate( + sorted(self._collection, key=lambda p: str(p)) + ): + bar_start = process.start_time % schedule_time + bar_end = ( + process.start_time + process.execution_time + ) % schedule_time + if bar_end > bar_start: + _ax.broken_barh( + [(PAD_L + bar_start, bar_end - bar_start - PAD_L - PAD_R)], + (i + 0.55, 0.9), + ) + else: # bar_end < bar_start + if bar_end != 0: + _ax.broken_barh( + [ + ( + PAD_L + bar_start, + schedule_time - bar_start - PAD_L, + ) + ], + (i + 0.55, 0.9), + ) + _ax.broken_barh([(0, bar_end - PAD_R)], (i + 0.55, 0.9)) + else: + _ax.broken_barh( + [ + ( + PAD_L + bar_start, + schedule_time - bar_start - PAD_L - PAD_R, + ) + ], + (i + 0.55, 0.9), + ) + if show_name: + _ax.annotate( + str(process), + (bar_start + PAD_L + 0.025, i + 1.00), + va="center", + ) + _ax.grid(True) + _ax.set_title(f'Schedule time: {schedule_time}') + + _ax.xaxis.set_major_locator(MaxNLocator(integer=True)) + _ax.yaxis.set_major_locator(MaxNLocator(integer=True)) + return _ax + + def create_exclusion_graph_from_overlap( + self, add_name: bool = True + ) -> nx.Graph: + """ + Generate exclusion graph based on processes overlaping in time + + Parameters + ---------- + add_name : bool, default: True + Add name of all processes as a node attribute in the exclusion graph. + + Returns + ------- + An nx.Graph exclusion graph where nodes are processes and arcs + between two processes indicated overlap in time + """ + exclusion_graph = nx.Graph() + exclusion_graph.add_nodes_from(self._collection) + for process1 in self._collection: + for process2 in self._collection: + if process1 == process2: + continue + else: + t1 = set( + range( + process1.start_time, + process1.start_time + process1.execution_time, + ) + ) + t2 = set( + range( + process2.start_time, + process2.start_time + process2.execution_time, + ) + ) + if t1.intersection(t2): + exclusion_graph.add_edge(process1, process2) + return exclusion_graph + + def split( + self, + heuristic: str = "graph_color", + read_ports: Optional[int] = None, + write_ports: Optional[int] = None, + total_ports: Optional[int] = None, + ) -> Set["ProcessCollection"]: + """ + Split this process storage based on some heuristic. + + Parameters + ---------- + heuristic : str, default: "graph_color" + The heuristic used when spliting this ProcessCollection. + Valid options are: + * "graph_color" + * "..." + read_ports : int, optional + The number of read ports used when spliting process collection based on memory variable access. + write_ports : int, optional + The number of write ports used when spliting process collection based on memory variable access. + total_ports : int, optional + The total number of ports used when spliting process collection based on memory variable access. + + Returns + ------- + A set of new ProcessColleciton objects with the process spliting. + """ + if total_ports is None: + if read_ports is None or write_ports is None: + raise ValueError("inteligent quote") + else: + total_ports = read_ports + write_ports + else: + read_ports = total_ports if read_ports is None else read_ports + write_ports = total_ports if write_ports is None else write_ports + + if heuristic == "graph_color": + return self._split_graph_color( + read_ports, write_ports, total_ports + ) + else: + raise ValueError("Invalid heuristic provided") + + def _split_graph_color( + self, read_ports: int, write_ports: int, total_ports: int + ) -> Set["ProcessCollection"]: + """ + Parameters + ---------- + read_ports : int, optional + The number of read ports used when spliting process collection based on memory variable access. + write_ports : int, optional + The number of write ports used when spliting process collection based on memory variable access. + total_ports : int, optional + The total number of ports used when spliting process collection based on memory variable access. + """ + if read_ports != 1 or write_ports != 1: + raise ValueError( + "Spliting with read and write ports not equal to one with the" + " graph coloring heuristic does not make sense." + ) + if total_ports not in (1, 2): + raise ValueError( + "Total ports should be either 1 (non-concurent reads/writes)" + " or 2 (concurrent read/writes) for graph coloring heuristic." + ) + + # Create new exclusion graph. Nodes are Processes + exclusion_graph = nx.Graph() + exclusion_graph.add_nodes_from(self._collection) + + # Add exclusions (arcs) between processes in the exclusion graph + for node1 in exclusion_graph: + for node2 in exclusion_graph: + if node1 == node2: + continue + else: + node1_stop_time = node1.start_time + node1.execution_time + node2_stop_time = node2.start_time + node2.execution_time + if total_ports == 1: + # Single-port assignment + if node1.start_time == node2.start_time: + exclusion_graph.add_edge(node1, node2) + elif node1_stop_time == node2_stop_time: + exclusion_graph.add_edge(node1, node2) + elif node1.start_time == node2_stop_time: + exclusion_graph.add_edge(node1, node2) + elif node1_stop_time == node2.start_time: + exclusion_graph.add_edge(node1, node2) + else: + # Dual-port assignment + if node1.start_time == node2.start_time: + exclusion_graph.add_edge(node1, node2) + elif node1_stop_time == node2_stop_time: + exclusion_graph.add_edge(node1, node2) + + # Perform assignment + coloring = nx.coloring.greedy_color(exclusion_graph) + draw_exclusion_graph_coloring(exclusion_graph, coloring) + # process_collection_list = [ProcessCollection()]*(max(coloring.values()) + 1) + process_collection_list = [ + ProcessCollection() for _ in range(max(coloring.values()) + 1) + ] + for process, color in coloring.items(): + process_collection_list[color].add_process(process) + return { + process_collection + for process_collection in process_collection_list + } diff --git a/docs_sphinx/api/index.rst b/docs_sphinx/api/index.rst index 8ad0965a14a70f922e562a8380c97e3688ed1587..c3480c10839dfe8cc14fc216fc9fed6e18ef85e9 100644 --- a/docs_sphinx/api/index.rst +++ b/docs_sphinx/api/index.rst @@ -10,6 +10,7 @@ API operation.rst port.rst process.rst + resources.rst schedule.rst sfg_generators.rst signal.rst diff --git a/docs_sphinx/api/resources.rst b/docs_sphinx/api/resources.rst new file mode 100644 index 0000000000000000000000000000000000000000..d87fa73bfb74905c5528098c83fa86fca0e149d5 --- /dev/null +++ b/docs_sphinx/api/resources.rst @@ -0,0 +1,7 @@ +******************** +``b_asic.resources`` +******************** + +.. automodule:: b_asic.resources + :members: + :undoc-members: diff --git a/docs_sphinx/conf.py b/docs_sphinx/conf.py index a8ac592dbaa05be232ee823874298e1793cebca0..90b049c77ea82b912bad97401eb2f1b2e65d9763 100644 --- a/docs_sphinx/conf.py +++ b/docs_sphinx/conf.py @@ -39,6 +39,7 @@ intersphinx_mapping = { 'matplotlib': ('https://matplotlib.org/stable/', None), 'numpy': ('https://numpy.org/doc/stable/', None), 'PyQt5': ("https://www.riverbankcomputing.com/static/Docs/PyQt5", None), + 'networkx': ('https://networkx.org/documentation/stable', None), } numpydoc_show_class_members = False diff --git a/test/test_resources.py b/test/test_resources.py new file mode 100644 index 0000000000000000000000000000000000000000..fb41017d93722c002bf9356d27ea4688107aa41a --- /dev/null +++ b/test/test_resources.py @@ -0,0 +1,85 @@ +from time import sleep + +import matplotlib.pyplot as plt +import networkx as nx +import pytest + +from b_asic.process import PlainMemoryVariable +from b_asic.resources import ProcessCollection, draw_exclusion_graph_coloring + +NO_PORT = 0 + +# +# Tests to run +# +def test_run(): + # TestProcessCollectionPlainMemoryVariable().test_draw_proces_collection() + TestProcessCollectionPlainMemoryVariable().test_split_memory_variable() + + +class TestProcessCollectionPlainMemoryVariable: + def __init__(self) -> None: + self.collection = ProcessCollection( + { + PlainMemoryVariable(4, NO_PORT, {NO_PORT: 2}), + PlainMemoryVariable(2, NO_PORT, {NO_PORT: 6}), + PlainMemoryVariable(3, NO_PORT, {NO_PORT: 5}), + PlainMemoryVariable(6, NO_PORT, {NO_PORT: 2}), + PlainMemoryVariable(0, NO_PORT, {NO_PORT: 3}), + PlainMemoryVariable(0, NO_PORT, {NO_PORT: 2}), + PlainMemoryVariable(0, NO_PORT, {NO_PORT: 6}), + } + ) + + def test_draw_proces_collection(self): + _, ax = plt.subplots(1, 2) + self.collection.draw_lifetime_chart(schedule_time=8, ax=ax[0]) + exclusion_graph = self.collection.create_exclusion_graph_from_overlap() + color_dict = nx.coloring.greedy_color(exclusion_graph) + draw_exclusion_graph_coloring(exclusion_graph, color_dict, ax=ax[1]) + plt.show() + + def test_split_memory_variable(self): + collection_split = self.collection.split( + read_ports=1, write_ports=1, total_ports=2 + ) + _, ax = plt.subplots(1, len(collection_split) + 1) + # print(f'Length: {len(ax)}') + # assert(False) + self.collection.draw_lifetime_chart(ax=ax[0]) + ax[0].set_title("Original") + for idx, collection in enumerate(collection_split): + collection.draw_lifetime_chart(ax=ax[idx + 1]) + plt.show() + + +# def test_draw_process_collection(): +# collection = ProcessCollection({ +# PlainMemoryVariable(4, NO_PORT, {NO_PORT: 2}), +# PlainMemoryVariable(2, NO_PORT, {NO_PORT: 6}), +# PlainMemoryVariable(3, NO_PORT, {NO_PORT: 5}), +# PlainMemoryVariable(6, NO_PORT, {NO_PORT: 2}), +# PlainMemoryVariable(0, NO_PORT, {NO_PORT: 3}), +# PlainMemoryVariable(0, NO_PORT, {NO_PORT: 2}), +# PlainMemoryVariable(0, NO_PORT, {NO_PORT: 6}), +# }) +# _, ax = plt.subplots(1, 2) +# collection.draw_lifetime_chart(schedule_time=8, ax=ax[0]) +# exclusion_graph = collection.create_exclusion_graph_from_overlap() +# color_dict = nx.greedy_color(exclusion_graph) +# draw_exclusion_graph_coloring(exclusion_graph, color_dict, ax=ax[1]) +# plt.show() +# +# def test_memory_split(): + + +# @pytest.mark.mpl_image_compare(remove_text=True, style='mpl20') + +# def test_create_exclusion_graph_overlap(): +# collection = ProcessCollection({ +# PlainMemoryVariable(4, NO_PORT, {NO_PORT: 2}), +# PlainMemoryVariable(2, NO_PORT, {NO_PORT: 6}), +# }) +# exclusion_graph = collection.create_exclusion_graph_overlap() +# nx.draw(exclusion_graph) +# plt.show()