Commit 6c6089a5 authored by Mikael Henriksson
fix port-split bug #244 and process collection drawing bug #243

parent dd3e83dd
Merge request !348: fix port-split bug #244 and process collection drawing bug #243
Pipeline #96683 passed
@@ -2,10 +2,11 @@
B-ASIC architecture classes.
"""
from collections import defaultdict
from typing import Dict, List, Optional, Set, Tuple, cast
from typing import Dict, Iterator, List, Optional, Set, Tuple, cast
from graphviz import Digraph
from b_asic.port import InputPort, OutputPort
from b_asic.process import MemoryVariable, OperatorProcess, PlainMemoryVariable
from b_asic.resources import ProcessCollection
@@ -192,6 +193,10 @@ class Memory(Resource):
)
self._memory_type = memory_type
def __iter__(self) -> Iterator[MemoryVariable]:
# Add information about the iterator type
return cast(Iterator[MemoryVariable], iter(self._collection))
class Architecture:
"""
@@ -220,10 +225,10 @@ class Architecture:
self._memories = memories
self._entity_name = entity_name
self._direct_interconnects = direct_interconnects
self._variable_inport_to_resource = {}
self._variable_outport_to_resource = {}
self._operation_inport_to_resource = {}
self._operation_outport_to_resource = {}
self._variable_inport_to_resource: Dict[InputPort, Resource] = {}
self._variable_outport_to_resource: Dict[OutputPort, Resource] = {}
self._operation_inport_to_resource: Dict[InputPort, Resource] = {}
self._operation_outport_to_resource: Dict[OutputPort, Resource] = {}
self._build_dicts()
@@ -240,7 +245,6 @@ class Architecture:
for memory in self.memories:
for mv in memory:
mv = cast(MemoryVariable, mv)
for read_port in mv.read_ports:
self._variable_inport_to_resource[read_port] = memory
self._variable_outport_to_resource[mv.write_port] = memory
@@ -262,7 +266,6 @@ class Architecture:
memory_write_ports = set()
for memory in self.memories:
for mv in memory:
mv = cast(MemoryVariable, mv)
memory_write_ports.add(mv.write_port)
memory_read_ports.update(mv.read_ports)
if self._direct_interconnects:
@@ -382,10 +385,16 @@ class Architecture:
def _digraph(self) -> Digraph:
dg = Digraph(node_attr={'shape': 'record'})
for mem in self._memories:
dg.node(mem._entity_name, mem._struct_def())
for pe in self._processing_elements:
dg.node(pe._entity_name, pe._struct_def())
for i, mem in enumerate(self._memories):
if mem._entity_name is not None:
dg.node(mem._entity_name, mem._struct_def())
else:
dg.node(f"MEM-{i}", mem._struct_def())
for i, pe in enumerate(self._processing_elements):
if pe._entity_name is not None:
dg.node(pe._entity_name, pe._struct_def())
else:
dg.node(f"PE-{i}", pe._struct_def())
for pe in self._processing_elements:
inputs, outputs = self.get_interconnects_for_pe(pe)
for i, inp in enumerate(inputs):
@@ -396,7 +405,7 @@ class Architecture:
for o, outp in enumerate(outputs):
for dest, cnt in outp.items():
dg.edge(
f"{pe._entity_name}:out{0}", dest._entity_name, label=f"{cnt}"
f"{pe._entity_name}:out{o}", dest._entity_name, label=f"{cnt}"
)
return dg
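The last edge call now uses the loop index o instead of the literal 0, so a processing element with more than one output gets a separate edge per output port. In Graphviz, "name:port" addresses a field of a record-shaped node; a standalone sketch of that pattern, with illustrative node names:

from graphviz import Digraph

dg = Digraph(node_attr={'shape': 'record'})
dg.node("pe0", "{PE0|{<out0> out0|<out1> out1}}")  # record node with two output ports
dg.node("mem0", "MEM0")
dg.edge("pe0:out0", "mem0", label="2")  # two variables routed from output 0
dg.edge("pe0:out1", "mem0", label="1")  # one variable routed from output 1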
@@ -55,6 +55,10 @@ class Process:
def __repr__(self) -> str:
return f"Process({self.start_time}, {self.execution_time}, {self.name!r})"
@property
def read_times(self) -> Tuple[int, ...]:
return (self.start_time + self.execution_time,)
class OperatorProcess(Process):
"""
@@ -159,6 +163,10 @@ class MemoryVariable(Process):
f" {reads!r}, {self.name!r})"
)
@property
def read_times(self) -> Tuple[int, ...]:
return tuple(self.start_time + read for read in self._life_times)
class PlainMemoryVariable(Process):
"""
@@ -225,5 +233,9 @@ class PlainMemoryVariable(Process):
f" {reads!r}, {self.name!r})"
)
@property
def read_times(self) -> Tuple[int, ...]:
return tuple(self.start_time + read for read in self._life_times)
# Static counter for default names
_name_cnt = 0
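The new read_times property gives every instant at which a process is read: for a plain Process it is simply start_time + execution_time, while for MemoryVariable and PlainMemoryVariable it reports one time per read. A small sketch, assuming the PlainMemoryVariable(write_time, write_port, reads) signature, where reads maps a read port to the number of cycles after the write at which that port reads:

from b_asic.process import PlainMemoryVariable

mv = PlainMemoryVariable(3, 0, {0: 2, 1: 5})  # written at t=3, read 2 and 5 cycles later
print(mv.read_times)  # (5, 8)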
@@ -531,20 +531,30 @@ class ProcessCollection:
color=marker_color,
zorder=10,
)
_ax.scatter( # type: ignore
x=bar_end,
y=bar_row + 1,
marker=marker_read,
color=marker_color,
zorder=10,
)
for end_time in process.read_times:
end_time = (
end_time
if end_time == self._schedule_time
else end_time % self._schedule_time
)
_ax.scatter( # type: ignore
x=end_time,
y=bar_row + 1,
marker=marker_read,
color=marker_color,
zorder=10,
)
if process.execution_time > self.schedule_time:
# Execution time longer than schedule time, draw with warning color
_ax.broken_barh( # type: ignore
[(0, self.schedule_time)],
(bar_row + 0.55, 0.9),
color=_WARNING_COLOR,
)
elif bar_end >= bar_start:
elif process.execution_time == 0:
# Execution time zero, don't draw the bar
pass
elif bar_end > bar_start:
_ax.broken_barh( # type: ignore
[(PAD_L + bar_start, bar_end - bar_start - PAD_L - PAD_R)],
(bar_row + 0.55, 0.9),
@@ -679,28 +689,39 @@ class ProcessCollection:
exclusion_graph = nx.Graph()
exclusion_graph.add_nodes_from(self._collection)
for node1 in exclusion_graph:
# node1_stop_time = (node1.start_time + node1.execution_time) % self.schedule_time
node1_stop_times = tuple(
read_time % self.schedule_time for read_time in node1.read_times
)
if total_ports == 1 and node1.start_time in node1_stop_times:
print(node1.start_time, node1_stop_times)
raise ValueError("Cannot read and write in same cycle.")
for node2 in exclusion_graph:
if node1 == node2:
continue
else:
node1_stop_time = node1.start_time + node1.execution_time
node2_stop_time = node2.start_time + node2.execution_time
if total_ports == 1:
# Single-port assignment
if node1.start_time == node2.start_time:
exclusion_graph.add_edge(node1, node2)
elif node1_stop_time == node2_stop_time:
exclusion_graph.add_edge(node1, node2)
elif node1.start_time == node2_stop_time:
exclusion_graph.add_edge(node1, node2)
elif node1_stop_time == node2.start_time:
exclusion_graph.add_edge(node1, node2)
else:
# Dual-port assignment
if node1.start_time == node2.start_time:
exclusion_graph.add_edge(node1, node2)
elif node1_stop_time == node2_stop_time:
exclusion_graph.add_edge(node1, node2)
# node2_stop_time = (node2.start_time + node2.execution_time) % self.schedule_time
node2_stop_times = tuple(
read_time % self.schedule_time for read_time in node2.read_times
)
for node1_stop_time in node1_stop_times:
for node2_stop_time in node2_stop_times:
if total_ports == 1:
# Single-port assignment
if node1.start_time == node2.start_time:
exclusion_graph.add_edge(node1, node2)
elif node1_stop_time == node2_stop_time:
exclusion_graph.add_edge(node1, node2)
elif node1.start_time == node2_stop_time:
exclusion_graph.add_edge(node1, node2)
elif node1_stop_time == node2.start_time:
exclusion_graph.add_edge(node1, node2)
else:
# Dual-port assignment
if node1.start_time == node2.start_time:
exclusion_graph.add_edge(node1, node2)
elif node1_stop_time == node2_stop_time:
exclusion_graph.add_edge(node1, node2)
return exclusion_graph
def create_exclusion_graph_from_execution_time(self) -> nx.Graph:
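The rewritten loop compares every wrapped read time of node1 against every wrapped read time of node2, so a multi-read variable conflicts with another as soon as any pair of access times collides under the port constraint, and a single-port memory now rejects a variable that must be read in the same cycle it is written. A rough illustration, assuming the ProcessCollection(collection, schedule_time) constructor and the PlainMemoryVariable signature used above:

from b_asic.process import PlainMemoryVariable
from b_asic.resources import ProcessCollection

# The read of a coincides with the write of b: over one shared port they
# cannot share a memory, but with separate read and write ports they can.
a = PlainMemoryVariable(0, 0, {0: 2})  # written at 0, read at 2
b = PlainMemoryVariable(2, 0, {0: 3})  # written at 2, read at 5
collection = ProcessCollection({a, b}, schedule_time=6)
print(len(collection.split_on_ports(read_ports=1, write_ports=1, total_ports=1)))  # expected: 2
print(len(collection.split_on_ports(read_ports=1, write_ports=1, total_ports=2)))  # expected: 1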
@@ -43,9 +43,9 @@ schedule.show(title='Original schedule')
# %%
# Reschedule to only require one adder and one multiplier
schedule.move_operation('add4', 3)
schedule.move_operation('cmul5', -5)
schedule.move_operation('cmul4', -4)
schedule.move_operation('add4', 2)
schedule.move_operation('cmul5', -4)
schedule.move_operation('cmul4', -5)
schedule.move_operation('cmul6', -2)
schedule.move_operation('cmul3', 1)
schedule.show(title='Improved schedule')
@@ -74,7 +74,7 @@ mem_vars.show(title="All memory variables")
direct, mem_vars = mem_vars.split_on_length()
direct.show(title="Direct interconnects")
mem_vars.show(title="Non-zero time memory variables")
mem_vars_set = mem_vars.split_on_ports(read_ports=1, write_ports=1, total_ports=1)
mem_vars_set = mem_vars.split_on_ports(read_ports=1, write_ports=1, total_ports=2)
memories = set()
for i, mem in enumerate(mem_vars_set):
File deleted
@@ -109,13 +109,8 @@ class TestProcessCollectionPlainMemoryVariable:
for i, register in enumerate(sorted(register_names)):
assert register == f'R{i}'
# Issue: #175
def test_interleaver_issue175(self):
with open('test/fixtures/interleaver-two-port-issue175.p', 'rb') as f:
interleaver_collection: ProcessCollection = pickle.load(f)
assert len(interleaver_collection.split_on_ports(total_ports=1)) == 2
def test_generate_random_interleaver(self):
return
for _ in range(10):
for size in range(5, 20, 5):
collection = generate_random_interleaver(size)