diff --git a/b_asic/resources.py b/b_asic/resources.py index 2c55e0af2354cd73b57844c3916fc180ffef2c92..eaed928f9f5c8ad5b70059ae6371af885ada8833 100644 --- a/b_asic/resources.py +++ b/b_asic/resources.py @@ -65,7 +65,12 @@ def draw_exclusion_graph_coloring( else: node_color_dict = {k: color_list[v] for k, v in color_dict.items()} node_color_list = [node_color_dict[node] for node in exclusion_graph] - nx.draw_networkx(exclusion_graph, node_color=node_color_list, ax=ax) + nx.draw_networkx( + exclusion_graph, + node_color=node_color_list, + ax=ax, + pos=nx.spring_layout(exclusion_graph, seed=1), + ) class ProcessCollection: @@ -79,7 +84,7 @@ class ProcessCollection: def __init__(self, collection: Optional[Set[Process]] = None): if collection is None: - self._collection = set[Process]() + self._collection: Set[Process] = set() else: self._collection = collection @@ -182,7 +187,6 @@ class ProcessCollection: va="center", ) _ax.grid(True) - _ax.set_title(f'Schedule time: {schedule_time}') _ax.xaxis.set_major_locator(MaxNLocator(integer=True)) _ax.yaxis.set_major_locator(MaxNLocator(integer=True)) diff --git a/test/baseline/test_draw_process_collection.png b/test/baseline/test_draw_process_collection.png new file mode 100644 index 0000000000000000000000000000000000000000..0ab1996784015b6d0bc3e5bc68e0f3ae21c64d94 Binary files /dev/null and b/test/baseline/test_draw_process_collection.png differ diff --git a/test/conftest.py b/test/conftest.py index 179a82c24e7ed3f3fb375fdfa3088a33eac1ef31..138fefe00e2bcb08efdbd238f297181febffe186 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -2,6 +2,7 @@ import os from distutils import dir_util from test.fixtures.operation_tree import * from test.fixtures.port import * +from test.fixtures.resources import * from test.fixtures.schedule import * from test.fixtures.signal import signal, signals from test.fixtures.signal_flow_graph import * diff --git a/test/fixtures/resources.py b/test/fixtures/resources.py new file mode 100644 index 0000000000000000000000000000000000000000..6317e4cb7d873ddf0dc258708cc4d6e90034a205 --- /dev/null +++ b/test/fixtures/resources.py @@ -0,0 +1,36 @@ +import pytest + +from b_asic.process import PlainMemoryVariable +from b_asic.resources import ProcessCollection + + +@pytest.fixture() +def simple_collection(): + NO_PORT = 0 + return ProcessCollection( + { + PlainMemoryVariable(4, NO_PORT, {NO_PORT: 2}), + PlainMemoryVariable(2, NO_PORT, {NO_PORT: 6}), + PlainMemoryVariable(3, NO_PORT, {NO_PORT: 5}), + PlainMemoryVariable(6, NO_PORT, {NO_PORT: 2}), + PlainMemoryVariable(0, NO_PORT, {NO_PORT: 3}), + PlainMemoryVariable(0, NO_PORT, {NO_PORT: 2}), + PlainMemoryVariable(0, NO_PORT, {NO_PORT: 6}), + } + ) + + +@pytest.fixture() +def collection(): + NO_PORT = 0 + return ProcessCollection( + { + PlainMemoryVariable(4, NO_PORT, {NO_PORT: 2}), + PlainMemoryVariable(2, NO_PORT, {NO_PORT: 6}), + PlainMemoryVariable(3, NO_PORT, {NO_PORT: 5}), + PlainMemoryVariable(6, NO_PORT, {NO_PORT: 2}), + PlainMemoryVariable(0, NO_PORT, {NO_PORT: 3}), + PlainMemoryVariable(0, NO_PORT, {NO_PORT: 2}), + PlainMemoryVariable(0, NO_PORT, {NO_PORT: 6}), + } + ) diff --git a/test/test_resources.py b/test/test_resources.py index fb41017d93722c002bf9356d27ea4688107aa41a..67f20cc92fe43276229af1d8a5e93e5409d712f1 100644 --- a/test/test_resources.py +++ b/test/test_resources.py @@ -1,5 +1,3 @@ -from time import sleep - import matplotlib.pyplot as plt import networkx as nx import pytest @@ -7,79 +5,25 @@ import pytest from b_asic.process import PlainMemoryVariable from b_asic.resources import ProcessCollection, draw_exclusion_graph_coloring -NO_PORT = 0 - -# -# Tests to run -# -def test_run(): - # TestProcessCollectionPlainMemoryVariable().test_draw_proces_collection() - TestProcessCollectionPlainMemoryVariable().test_split_memory_variable() - class TestProcessCollectionPlainMemoryVariable: - def __init__(self) -> None: - self.collection = ProcessCollection( - { - PlainMemoryVariable(4, NO_PORT, {NO_PORT: 2}), - PlainMemoryVariable(2, NO_PORT, {NO_PORT: 6}), - PlainMemoryVariable(3, NO_PORT, {NO_PORT: 5}), - PlainMemoryVariable(6, NO_PORT, {NO_PORT: 2}), - PlainMemoryVariable(0, NO_PORT, {NO_PORT: 3}), - PlainMemoryVariable(0, NO_PORT, {NO_PORT: 2}), - PlainMemoryVariable(0, NO_PORT, {NO_PORT: 6}), - } - ) + @pytest.mark.mpl_image_compare(style='mpl20') + def test_draw_process_collection(self, simple_collection): + fig, ax = plt.subplots() + simple_collection.draw_lifetime_chart(ax=ax) + return fig - def test_draw_proces_collection(self): + def test_draw_proces_collection(self, simple_collection): _, ax = plt.subplots(1, 2) - self.collection.draw_lifetime_chart(schedule_time=8, ax=ax[0]) - exclusion_graph = self.collection.create_exclusion_graph_from_overlap() + simple_collection.draw_lifetime_chart(schedule_time=8, ax=ax[0]) + exclusion_graph = ( + simple_collection.create_exclusion_graph_from_overlap() + ) color_dict = nx.coloring.greedy_color(exclusion_graph) draw_exclusion_graph_coloring(exclusion_graph, color_dict, ax=ax[1]) - plt.show() - def test_split_memory_variable(self): - collection_split = self.collection.split( + def test_split_memory_variable(self, simple_collection): + collection_split = simple_collection.split( read_ports=1, write_ports=1, total_ports=2 ) - _, ax = plt.subplots(1, len(collection_split) + 1) - # print(f'Length: {len(ax)}') - # assert(False) - self.collection.draw_lifetime_chart(ax=ax[0]) - ax[0].set_title("Original") - for idx, collection in enumerate(collection_split): - collection.draw_lifetime_chart(ax=ax[idx + 1]) - plt.show() - - -# def test_draw_process_collection(): -# collection = ProcessCollection({ -# PlainMemoryVariable(4, NO_PORT, {NO_PORT: 2}), -# PlainMemoryVariable(2, NO_PORT, {NO_PORT: 6}), -# PlainMemoryVariable(3, NO_PORT, {NO_PORT: 5}), -# PlainMemoryVariable(6, NO_PORT, {NO_PORT: 2}), -# PlainMemoryVariable(0, NO_PORT, {NO_PORT: 3}), -# PlainMemoryVariable(0, NO_PORT, {NO_PORT: 2}), -# PlainMemoryVariable(0, NO_PORT, {NO_PORT: 6}), -# }) -# _, ax = plt.subplots(1, 2) -# collection.draw_lifetime_chart(schedule_time=8, ax=ax[0]) -# exclusion_graph = collection.create_exclusion_graph_from_overlap() -# color_dict = nx.greedy_color(exclusion_graph) -# draw_exclusion_graph_coloring(exclusion_graph, color_dict, ax=ax[1]) -# plt.show() -# -# def test_memory_split(): - - -# @pytest.mark.mpl_image_compare(remove_text=True, style='mpl20') - -# def test_create_exclusion_graph_overlap(): -# collection = ProcessCollection({ -# PlainMemoryVariable(4, NO_PORT, {NO_PORT: 2}), -# PlainMemoryVariable(2, NO_PORT, {NO_PORT: 6}), -# }) -# exclusion_graph = collection.create_exclusion_graph_overlap() -# nx.draw(exclusion_graph) -# plt.show() + assert len(collection_split) == 3