diff --git a/.gitignore b/.gitignore index d251d2bfb551e0abf8cea3fe0911d1aeece9e2c0..d6034647a557c21f67a8a128c313d80164c2f0bb 100644 --- a/.gitignore +++ b/.gitignore @@ -115,3 +115,4 @@ TODO.txt *.log b_asic/_version.py docs_sphinx/_build/ +docs_sphinx/examples diff --git a/README.md b/README.md index f66666903b202acd0a09b6d72a3fbbc1ada0c86c..f2df4a12b14542c1d95bfd581c0bba8f07ba2da3 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ The following packages are required in order to build the library: - [NumPy](https://numpy.org/) - [QtPy](https://github.com/spyder-ide/qtpy) - [setuptools_scm](https://github.com/pypa/setuptools_scm/) + - [NetworkX](https://networkx.org/) - Qt 5 or 6, with Python bindings, one of: - pyside2 - pyqt5 diff --git a/b_asic/process.py b/b_asic/process.py index 131ad5995b40d5bb9600b7a7795cbacefd8c082f..d48fa0ee5ff24b2ab3667918a30e71a055776706 100644 --- a/b_asic/process.py +++ b/b_asic/process.py @@ -2,7 +2,7 @@ B-ASIC classes representing resource usage. """ -from typing import Dict, Tuple +from typing import Dict, Optional, Tuple from b_asic.operation import Operation from b_asic.port import InputPort, OutputPort @@ -130,10 +130,14 @@ class PlainMemoryVariable(Process): write_time: int, write_port: int, reads: Dict[int, int], + name: Optional[str] = None, ): self._read_ports = tuple(reads.keys()) self._life_times = tuple(reads.values()) self._write_port = write_port + if name is None: + self._name = str(PlainMemoryVariable._name_cnt) + PlainMemoryVariable._name_cnt += 1 super().__init__( start_time=write_time, execution_time=max(self._life_times) ) @@ -149,3 +153,13 @@ class PlainMemoryVariable(Process): @property def write_port(self) -> int: return self._write_port + + @property + def name(self) -> str: + return self._name + + def __str__(self) -> str: + return self._name + + # Static counter for default names + _name_cnt = 0 diff --git a/b_asic/resources.py b/b_asic/resources.py new file mode 100644 index 0000000000000000000000000000000000000000..eaed928f9f5c8ad5b70059ae6371af885ada8833 --- /dev/null +++ b/b_asic/resources.py @@ -0,0 +1,343 @@ +from typing import Dict, List, Optional, Set, Tuple, Union + +import matplotlib.pyplot as plt +import networkx as nx +from matplotlib.axes import Axes +from matplotlib.ticker import MaxNLocator + +from b_asic.process import Process + + +def draw_exclusion_graph_coloring( + exclusion_graph: nx.Graph, + color_dict: Dict[Process, int], + ax: Optional[Axes] = None, + color_list: Optional[ + Union[List[str], List[Tuple[float, float, float]]] + ] = None, +): + """ + Use matplotlib.pyplot and networkx to draw a colored exclusion graph from the memory assigment + + .. code-block:: python + + _, ax = plt.subplots(1, 1) + collection = ProcessCollection(...) + exclusion_graph = collection.create_exclusion_graph_from_overlap() + color_dict = nx.greedy_color(exclusion_graph) + draw_exclusion_graph_coloring(exclusion_graph, color_dict, ax=ax[0]) + plt.show() + + Parameters + ---------- + exclusion_graph : nx.Graph + A nx.Graph exclusion graph object that is to be drawn. + + color_dict : dictionary + A color dictionary where keys are Process objects and where values are integers representing colors. These + dictionaries are automatically generated by :func:`networkx.algorithms.coloring.greedy_color`. + + ax : :class:`matplotlib.axes.Axes`, optional + A Matplotlib Axes object to draw the exclusion graph + + color_list : Optional[Union[List[str], List[Tuple[float,float,float]]]] + """ + COLOR_LIST = [ + '#aa0000', + '#00aa00', + '#0000ff', + '#ff00aa', + '#ffaa00', + '#00ffaa', + '#aaff00', + '#aa00ff', + '#00aaff', + '#ff0000', + '#00ff00', + '#0000aa', + '#aaaa00', + '#aa00aa', + '#00aaaa', + ] + node_color_dict = {} + if color_list is None: + node_color_dict = {k: COLOR_LIST[v] for k, v in color_dict.items()} + else: + node_color_dict = {k: color_list[v] for k, v in color_dict.items()} + node_color_list = [node_color_dict[node] for node in exclusion_graph] + nx.draw_networkx( + exclusion_graph, + node_color=node_color_list, + ax=ax, + pos=nx.spring_layout(exclusion_graph, seed=1), + ) + + +class ProcessCollection: + """ + Collection of one or more processes + + Parameters + ---------- + collection : set of :class:`~b_asic.process.Process` objects, optional + """ + + def __init__(self, collection: Optional[Set[Process]] = None): + if collection is None: + self._collection: Set[Process] = set() + else: + self._collection = collection + + def add_process(self, process: Process): + """ + Add a new process to this process collection. + + Parameters + ---------- + process : Process + The process object to be added to the collection + """ + self._collection.add(process) + + def draw_lifetime_chart( + self, + schedule_time: int = 0, + ax: Optional[Axes] = None, + show_name: bool = True, + ): + """ + Use matplotlib.pyplot to generate a process variable lifetime chart from this process collection. + + Parameters + ---------- + schedule_time : int, default: 0 + Length of the time-axis in the generated graph. The time axis will span [0, schedule_time-1]. + If set to zero (which is the default), the ... + ax : :class:`matplotlib.axes.Axes`, optional + Matplotlib Axes object to draw this lifetime chart onto. If not provided (i.e., set to None), this will + return a new axes object on return. + show_name : bool, default: True + Show name of all processes in the lifetime chart. + + Returns + ------- + ax: Associated Matplotlib Axes (or array of Axes) object + """ + + # Setup the Axes object + if ax is None: + _, _ax = plt.subplots() + else: + _ax = ax + + # Draw the lifetime chart + PAD_L, PAD_R = 0.05, 0.05 + max_execution_time = max( + [process.execution_time for process in self._collection] + ) + schedule_time = ( + schedule_time + if schedule_time + else max(p.start_time + p.execution_time for p in self._collection) + ) + if max_execution_time > schedule_time: + # Schedule time needs to be greater than or equal to the maximum process life time + raise KeyError( + f'Error: Schedule time: {schedule_time} < Max execution time:' + f' {max_execution_time}' + ) + for i, process in enumerate( + sorted(self._collection, key=lambda p: str(p)) + ): + bar_start = process.start_time % schedule_time + bar_end = ( + process.start_time + process.execution_time + ) % schedule_time + if bar_end > bar_start: + _ax.broken_barh( + [(PAD_L + bar_start, bar_end - bar_start - PAD_L - PAD_R)], + (i + 0.55, 0.9), + ) + else: # bar_end < bar_start + if bar_end != 0: + _ax.broken_barh( + [ + ( + PAD_L + bar_start, + schedule_time - bar_start - PAD_L, + ) + ], + (i + 0.55, 0.9), + ) + _ax.broken_barh([(0, bar_end - PAD_R)], (i + 0.55, 0.9)) + else: + _ax.broken_barh( + [ + ( + PAD_L + bar_start, + schedule_time - bar_start - PAD_L - PAD_R, + ) + ], + (i + 0.55, 0.9), + ) + if show_name: + _ax.annotate( + str(process), + (bar_start + PAD_L + 0.025, i + 1.00), + va="center", + ) + _ax.grid(True) + + _ax.xaxis.set_major_locator(MaxNLocator(integer=True)) + _ax.yaxis.set_major_locator(MaxNLocator(integer=True)) + return _ax + + def create_exclusion_graph_from_overlap( + self, add_name: bool = True + ) -> nx.Graph: + """ + Generate exclusion graph based on processes overlaping in time + + Parameters + ---------- + add_name : bool, default: True + Add name of all processes as a node attribute in the exclusion graph. + + Returns + ------- + An nx.Graph exclusion graph where nodes are processes and arcs + between two processes indicated overlap in time + """ + exclusion_graph = nx.Graph() + exclusion_graph.add_nodes_from(self._collection) + for process1 in self._collection: + for process2 in self._collection: + if process1 == process2: + continue + else: + t1 = set( + range( + process1.start_time, + process1.start_time + process1.execution_time, + ) + ) + t2 = set( + range( + process2.start_time, + process2.start_time + process2.execution_time, + ) + ) + if t1.intersection(t2): + exclusion_graph.add_edge(process1, process2) + return exclusion_graph + + def split( + self, + heuristic: str = "graph_color", + read_ports: Optional[int] = None, + write_ports: Optional[int] = None, + total_ports: Optional[int] = None, + ) -> Set["ProcessCollection"]: + """ + Split this process storage based on some heuristic. + + Parameters + ---------- + heuristic : str, default: "graph_color" + The heuristic used when spliting this ProcessCollection. + Valid options are: + * "graph_color" + * "..." + read_ports : int, optional + The number of read ports used when spliting process collection based on memory variable access. + write_ports : int, optional + The number of write ports used when spliting process collection based on memory variable access. + total_ports : int, optional + The total number of ports used when spliting process collection based on memory variable access. + + Returns + ------- + A set of new ProcessColleciton objects with the process spliting. + """ + if total_ports is None: + if read_ports is None or write_ports is None: + raise ValueError("inteligent quote") + else: + total_ports = read_ports + write_ports + else: + read_ports = total_ports if read_ports is None else read_ports + write_ports = total_ports if write_ports is None else write_ports + + if heuristic == "graph_color": + return self._split_graph_color( + read_ports, write_ports, total_ports + ) + else: + raise ValueError("Invalid heuristic provided") + + def _split_graph_color( + self, read_ports: int, write_ports: int, total_ports: int + ) -> Set["ProcessCollection"]: + """ + Parameters + ---------- + read_ports : int, optional + The number of read ports used when spliting process collection based on memory variable access. + write_ports : int, optional + The number of write ports used when spliting process collection based on memory variable access. + total_ports : int, optional + The total number of ports used when spliting process collection based on memory variable access. + """ + if read_ports != 1 or write_ports != 1: + raise ValueError( + "Spliting with read and write ports not equal to one with the" + " graph coloring heuristic does not make sense." + ) + if total_ports not in (1, 2): + raise ValueError( + "Total ports should be either 1 (non-concurent reads/writes)" + " or 2 (concurrent read/writes) for graph coloring heuristic." + ) + + # Create new exclusion graph. Nodes are Processes + exclusion_graph = nx.Graph() + exclusion_graph.add_nodes_from(self._collection) + + # Add exclusions (arcs) between processes in the exclusion graph + for node1 in exclusion_graph: + for node2 in exclusion_graph: + if node1 == node2: + continue + else: + node1_stop_time = node1.start_time + node1.execution_time + node2_stop_time = node2.start_time + node2.execution_time + if total_ports == 1: + # Single-port assignment + if node1.start_time == node2.start_time: + exclusion_graph.add_edge(node1, node2) + elif node1_stop_time == node2_stop_time: + exclusion_graph.add_edge(node1, node2) + elif node1.start_time == node2_stop_time: + exclusion_graph.add_edge(node1, node2) + elif node1_stop_time == node2.start_time: + exclusion_graph.add_edge(node1, node2) + else: + # Dual-port assignment + if node1.start_time == node2.start_time: + exclusion_graph.add_edge(node1, node2) + elif node1_stop_time == node2_stop_time: + exclusion_graph.add_edge(node1, node2) + + # Perform assignment + coloring = nx.coloring.greedy_color(exclusion_graph) + draw_exclusion_graph_coloring(exclusion_graph, coloring) + # process_collection_list = [ProcessCollection()]*(max(coloring.values()) + 1) + process_collection_list = [ + ProcessCollection() for _ in range(max(coloring.values()) + 1) + ] + for process, color in coloring.items(): + process_collection_list[color].add_process(process) + return { + process_collection + for process_collection in process_collection_list + } diff --git a/docs_sphinx/api/index.rst b/docs_sphinx/api/index.rst index 8ad0965a14a70f922e562a8380c97e3688ed1587..c3480c10839dfe8cc14fc216fc9fed6e18ef85e9 100644 --- a/docs_sphinx/api/index.rst +++ b/docs_sphinx/api/index.rst @@ -10,6 +10,7 @@ API operation.rst port.rst process.rst + resources.rst schedule.rst sfg_generators.rst signal.rst diff --git a/docs_sphinx/api/resources.rst b/docs_sphinx/api/resources.rst new file mode 100644 index 0000000000000000000000000000000000000000..d87fa73bfb74905c5528098c83fa86fca0e149d5 --- /dev/null +++ b/docs_sphinx/api/resources.rst @@ -0,0 +1,7 @@ +******************** +``b_asic.resources`` +******************** + +.. automodule:: b_asic.resources + :members: + :undoc-members: diff --git a/docs_sphinx/conf.py b/docs_sphinx/conf.py index a8ac592dbaa05be232ee823874298e1793cebca0..90b049c77ea82b912bad97401eb2f1b2e65d9763 100644 --- a/docs_sphinx/conf.py +++ b/docs_sphinx/conf.py @@ -39,6 +39,7 @@ intersphinx_mapping = { 'matplotlib': ('https://matplotlib.org/stable/', None), 'numpy': ('https://numpy.org/doc/stable/', None), 'PyQt5': ("https://www.riverbankcomputing.com/static/Docs/PyQt5", None), + 'networkx': ('https://networkx.org/documentation/stable', None), } numpydoc_show_class_members = False diff --git a/pyproject.toml b/pyproject.toml index fef65a9d736b856afafcac85958ff0c18ecc5dd3..be37e866e99b80a006e3680bd42109d9ebfaf709 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "graphviz>=0.19", "matplotlib", "setuptools_scm[toml]>=6.2", + "networkx", ] classifiers = [ "Intended Audience :: Education", diff --git a/requirements.txt b/requirements.txt index 343973832ddda6f448d7a07353e9cf7e7a96b0cc..1a591a913fc52ebd44e66d27fd3a3694ae9baaf8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ qtpy graphviz>=0.19 matplotlib setuptools_scm[toml]>=6.2 +networkx diff --git a/test/baseline/test_draw_process_collection.png b/test/baseline/test_draw_process_collection.png new file mode 100644 index 0000000000000000000000000000000000000000..0ab1996784015b6d0bc3e5bc68e0f3ae21c64d94 Binary files /dev/null and b/test/baseline/test_draw_process_collection.png differ diff --git a/test/conftest.py b/test/conftest.py index 179a82c24e7ed3f3fb375fdfa3088a33eac1ef31..138fefe00e2bcb08efdbd238f297181febffe186 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -2,6 +2,7 @@ import os from distutils import dir_util from test.fixtures.operation_tree import * from test.fixtures.port import * +from test.fixtures.resources import * from test.fixtures.schedule import * from test.fixtures.signal import signal, signals from test.fixtures.signal_flow_graph import * diff --git a/test/fixtures/resources.py b/test/fixtures/resources.py new file mode 100644 index 0000000000000000000000000000000000000000..6317e4cb7d873ddf0dc258708cc4d6e90034a205 --- /dev/null +++ b/test/fixtures/resources.py @@ -0,0 +1,36 @@ +import pytest + +from b_asic.process import PlainMemoryVariable +from b_asic.resources import ProcessCollection + + +@pytest.fixture() +def simple_collection(): + NO_PORT = 0 + return ProcessCollection( + { + PlainMemoryVariable(4, NO_PORT, {NO_PORT: 2}), + PlainMemoryVariable(2, NO_PORT, {NO_PORT: 6}), + PlainMemoryVariable(3, NO_PORT, {NO_PORT: 5}), + PlainMemoryVariable(6, NO_PORT, {NO_PORT: 2}), + PlainMemoryVariable(0, NO_PORT, {NO_PORT: 3}), + PlainMemoryVariable(0, NO_PORT, {NO_PORT: 2}), + PlainMemoryVariable(0, NO_PORT, {NO_PORT: 6}), + } + ) + + +@pytest.fixture() +def collection(): + NO_PORT = 0 + return ProcessCollection( + { + PlainMemoryVariable(4, NO_PORT, {NO_PORT: 2}), + PlainMemoryVariable(2, NO_PORT, {NO_PORT: 6}), + PlainMemoryVariable(3, NO_PORT, {NO_PORT: 5}), + PlainMemoryVariable(6, NO_PORT, {NO_PORT: 2}), + PlainMemoryVariable(0, NO_PORT, {NO_PORT: 3}), + PlainMemoryVariable(0, NO_PORT, {NO_PORT: 2}), + PlainMemoryVariable(0, NO_PORT, {NO_PORT: 6}), + } + ) diff --git a/test/test_resources.py b/test/test_resources.py new file mode 100644 index 0000000000000000000000000000000000000000..67f20cc92fe43276229af1d8a5e93e5409d712f1 --- /dev/null +++ b/test/test_resources.py @@ -0,0 +1,29 @@ +import matplotlib.pyplot as plt +import networkx as nx +import pytest + +from b_asic.process import PlainMemoryVariable +from b_asic.resources import ProcessCollection, draw_exclusion_graph_coloring + + +class TestProcessCollectionPlainMemoryVariable: + @pytest.mark.mpl_image_compare(style='mpl20') + def test_draw_process_collection(self, simple_collection): + fig, ax = plt.subplots() + simple_collection.draw_lifetime_chart(ax=ax) + return fig + + def test_draw_proces_collection(self, simple_collection): + _, ax = plt.subplots(1, 2) + simple_collection.draw_lifetime_chart(schedule_time=8, ax=ax[0]) + exclusion_graph = ( + simple_collection.create_exclusion_graph_from_overlap() + ) + color_dict = nx.coloring.greedy_color(exclusion_graph) + draw_exclusion_graph_coloring(exclusion_graph, color_dict, ax=ax[1]) + + def test_split_memory_variable(self, simple_collection): + collection_split = simple_collection.split( + read_ports=1, write_ports=1, total_ports=2 + ) + assert len(collection_split) == 3