Skip to content
Snippets Groups Projects
Commit 69b10b1e authored by Oscar Gustafsson's avatar Oscar Gustafsson :bicyclist:
Browse files

More typing fixes and less sets

parent 662901f7
No related branches found
No related tags found
1 merge request!372More typing fixes and less sets
Pipeline #97062 passed
...@@ -3,7 +3,18 @@ B-ASIC architecture classes. ...@@ -3,7 +3,18 @@ B-ASIC architecture classes.
""" """
from collections import defaultdict from collections import defaultdict
from io import TextIOWrapper from io import TextIOWrapper
from typing import Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union, cast from typing import (
DefaultDict,
Dict,
Iterable,
Iterator,
List,
Optional,
Set,
Tuple,
Union,
cast,
)
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from graphviz import Digraph from graphviz import Digraph
...@@ -30,7 +41,7 @@ class HardwareBlock: ...@@ -30,7 +41,7 @@ class HardwareBlock:
""" """
def __init__(self, entity_name: Optional[str] = None): def __init__(self, entity_name: Optional[str] = None):
self._entity_name = None self._entity_name: Optional[str] = None
if entity_name is not None: if entity_name is not None:
self.set_entity_name(entity_name) self.set_entity_name(entity_name)
...@@ -56,7 +67,7 @@ class HardwareBlock: ...@@ -56,7 +67,7 @@ class HardwareBlock:
path : str path : str
Directory to write code in. Directory to write code in.
""" """
if not self.entity_name: if not self._entity_name:
raise ValueError("Entity name must be set") raise ValueError("Entity name must be set")
raise NotImplementedError raise NotImplementedError
...@@ -201,6 +212,10 @@ class Resource(HardwareBlock): ...@@ -201,6 +212,10 @@ class Resource(HardwareBlock):
self.plot_content(ax) self.plot_content(ax)
return fig return fig
@property
def collection(self) -> ProcessCollection:
return self._collection
class ProcessingElement(Resource): class ProcessingElement(Resource):
""" """
...@@ -233,10 +248,8 @@ class ProcessingElement(Resource): ...@@ -233,10 +248,8 @@ class ProcessingElement(Resource):
op_type = type(ops[0]) op_type = type(ops[0])
if not all(isinstance(op, op_type) for op in ops): if not all(isinstance(op, op_type) for op in ops):
raise TypeError("Different Operation types in ProcessCollection") raise TypeError("Different Operation types in ProcessCollection")
self._collection = process_collection
self._operation_type = op_type self._operation_type = op_type
self._type_name = op_type.type_name() self._type_name = op_type.type_name()
self._entity_name = entity_name
self._input_count = ops[0].input_count self._input_count = ops[0].input_count
self._output_count = ops[0].output_count self._output_count = ops[0].output_count
self._assignment = list( self._assignment = list(
...@@ -246,8 +259,8 @@ class ProcessingElement(Resource): ...@@ -246,8 +259,8 @@ class ProcessingElement(Resource):
raise ValueError("Cannot map ProcessCollection to single ProcessingElement") raise ValueError("Cannot map ProcessCollection to single ProcessingElement")
@property @property
def processes(self) -> Set[OperatorProcess]: def processes(self) -> List[OperatorProcess]:
return {cast(OperatorProcess, p) for p in self._collection} return [cast(OperatorProcess, p) for p in self._collection]
class Memory(Resource): class Memory(Resource):
...@@ -350,11 +363,11 @@ of :class:`~b_asic.architecture.ProcessingElement` ...@@ -350,11 +363,11 @@ of :class:`~b_asic.architecture.ProcessingElement`
): ):
super().__init__(entity_name) super().__init__(entity_name)
self._processing_elements = ( self._processing_elements = (
set(processing_elements) [processing_elements]
if isinstance(processing_elements, ProcessingElement) if isinstance(processing_elements, ProcessingElement)
else processing_elements else list(processing_elements)
) )
self._memories = set(memories) if isinstance(memories, Memory) else memories self._memories = [memories] if isinstance(memories, Memory) else list(memories)
self._direct_interconnects = direct_interconnects self._direct_interconnects = direct_interconnects
self._variable_inport_to_resource: Dict[InputPort, Tuple[Resource, int]] = {} self._variable_inport_to_resource: Dict[InputPort, Tuple[Resource, int]] = {}
self._variable_outport_to_resource: Dict[OutputPort, Tuple[Resource, int]] = {} self._variable_outport_to_resource: Dict[OutputPort, Tuple[Resource, int]] = {}
...@@ -457,9 +470,9 @@ of :class:`~b_asic.architecture.ProcessingElement` ...@@ -457,9 +470,9 @@ of :class:`~b_asic.architecture.ProcessingElement`
A dictionary with the ProcessingElements that are connected to the write and A dictionary with the ProcessingElements that are connected to the write and
read ports, respectively, with counts of the number of accesses. read ports, respectively, with counts of the number of accesses.
""" """
d_in = defaultdict(_interconnect_dict) d_in: DefaultDict[Resource, int] = defaultdict(_interconnect_dict)
d_out = defaultdict(_interconnect_dict) d_out: DefaultDict[Resource, int] = defaultdict(_interconnect_dict)
for var in mem._collection: for var in mem.collection:
var = cast(MemoryVariable, var) var = cast(MemoryVariable, var)
d_in[self._operation_outport_to_resource[var.write_port]] += 1 d_in[self._operation_outport_to_resource[var.write_port]] += 1
for read_port in var.read_ports: for read_port in var.read_ports:
...@@ -490,19 +503,22 @@ of :class:`~b_asic.architecture.ProcessingElement` ...@@ -490,19 +503,22 @@ of :class:`~b_asic.architecture.ProcessingElement`
frequency of accesses. frequency of accesses.
""" """
ops = cast(List[OperatorProcess], list(pe._collection)) d_in: List[DefaultDict[Tuple[Resource, int], int]] = [
d_in = [defaultdict(_interconnect_dict) for _ in ops[0].operation.inputs] defaultdict(_interconnect_dict) for _ in range(pe.input_count)
d_out = [defaultdict(_interconnect_dict) for _ in ops[0].operation.outputs] ]
for var in pe._collection: d_out: List[DefaultDict[Tuple[Resource, int], int]] = [
defaultdict(_interconnect_dict) for _ in range(pe.output_count)
]
for var in pe.collection:
var = cast(OperatorProcess, var) var = cast(OperatorProcess, var)
for i, input in enumerate(var.operation.inputs): for i, input_ in enumerate(var.operation.inputs):
d_in[i][self._variable_inport_to_resource[input]] += 1 d_in[i][self._variable_inport_to_resource[input_]] += 1
for i, output in enumerate(var.operation.outputs): for i, output in enumerate(var.operation.outputs):
d_out[i][self._variable_outport_to_resource[output]] += 1 d_out[i][self._variable_outport_to_resource[output]] += 1
return [dict(d) for d in d_in], [dict(d) for d in d_out] return [dict(d) for d in d_in], [dict(d) for d in d_out]
def _digraph(self) -> Digraph: def _digraph(self) -> Digraph:
edges = set() edges: Set[Tuple[str, str, str]] = set()
dg = Digraph(node_attr={'shape': 'record'}) dg = Digraph(node_attr={'shape': 'record'})
# dg.attr(rankdir="LR") # dg.attr(rankdir="LR")
for i, mem in enumerate(self._memories): for i, mem in enumerate(self._memories):
...@@ -529,16 +545,16 @@ of :class:`~b_asic.architecture.ProcessingElement` ...@@ -529,16 +545,16 @@ of :class:`~b_asic.architecture.ProcessingElement`
f"{cnt}", f"{cnt}",
) )
) )
for src, dest, cnt in edges: for src_str, dest_str, cnt_str in edges:
dg.edge(src, dest, label=cnt) dg.edge(src_str, dest_str, label=cnt_str)
return dg return dg
@property @property
def memories(self) -> Iterable[Memory]: def memories(self) -> List[Memory]:
return self._memories return self._memories
@property @property
def processing_elements(self) -> Iterable[ProcessingElement]: def processing_elements(self) -> List[ProcessingElement]:
return self._processing_elements return self._processing_elements
@property @property
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment