Skip to content
Snippets Groups Projects
Commit 69b10b1e authored by Oscar Gustafsson's avatar Oscar Gustafsson :bicyclist:
Browse files

More typing fixes and less sets

parent 662901f7
No related branches found
No related tags found
1 merge request!372More typing fixes and less sets
Pipeline #97062 passed
......@@ -3,7 +3,18 @@ B-ASIC architecture classes.
"""
from collections import defaultdict
from io import TextIOWrapper
from typing import Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union, cast
from typing import (
DefaultDict,
Dict,
Iterable,
Iterator,
List,
Optional,
Set,
Tuple,
Union,
cast,
)
import matplotlib.pyplot as plt
from graphviz import Digraph
......@@ -30,7 +41,7 @@ class HardwareBlock:
"""
def __init__(self, entity_name: Optional[str] = None):
self._entity_name = None
self._entity_name: Optional[str] = None
if entity_name is not None:
self.set_entity_name(entity_name)
......@@ -56,7 +67,7 @@ class HardwareBlock:
path : str
Directory to write code in.
"""
if not self.entity_name:
if not self._entity_name:
raise ValueError("Entity name must be set")
raise NotImplementedError
......@@ -201,6 +212,10 @@ class Resource(HardwareBlock):
self.plot_content(ax)
return fig
@property
def collection(self) -> ProcessCollection:
return self._collection
class ProcessingElement(Resource):
"""
......@@ -233,10 +248,8 @@ class ProcessingElement(Resource):
op_type = type(ops[0])
if not all(isinstance(op, op_type) for op in ops):
raise TypeError("Different Operation types in ProcessCollection")
self._collection = process_collection
self._operation_type = op_type
self._type_name = op_type.type_name()
self._entity_name = entity_name
self._input_count = ops[0].input_count
self._output_count = ops[0].output_count
self._assignment = list(
......@@ -246,8 +259,8 @@ class ProcessingElement(Resource):
raise ValueError("Cannot map ProcessCollection to single ProcessingElement")
@property
def processes(self) -> Set[OperatorProcess]:
return {cast(OperatorProcess, p) for p in self._collection}
def processes(self) -> List[OperatorProcess]:
return [cast(OperatorProcess, p) for p in self._collection]
class Memory(Resource):
......@@ -350,11 +363,11 @@ of :class:`~b_asic.architecture.ProcessingElement`
):
super().__init__(entity_name)
self._processing_elements = (
set(processing_elements)
[processing_elements]
if isinstance(processing_elements, ProcessingElement)
else processing_elements
else list(processing_elements)
)
self._memories = set(memories) if isinstance(memories, Memory) else memories
self._memories = [memories] if isinstance(memories, Memory) else list(memories)
self._direct_interconnects = direct_interconnects
self._variable_inport_to_resource: Dict[InputPort, Tuple[Resource, int]] = {}
self._variable_outport_to_resource: Dict[OutputPort, Tuple[Resource, int]] = {}
......@@ -457,9 +470,9 @@ of :class:`~b_asic.architecture.ProcessingElement`
A dictionary with the ProcessingElements that are connected to the write and
read ports, respectively, with counts of the number of accesses.
"""
d_in = defaultdict(_interconnect_dict)
d_out = defaultdict(_interconnect_dict)
for var in mem._collection:
d_in: DefaultDict[Resource, int] = defaultdict(_interconnect_dict)
d_out: DefaultDict[Resource, int] = defaultdict(_interconnect_dict)
for var in mem.collection:
var = cast(MemoryVariable, var)
d_in[self._operation_outport_to_resource[var.write_port]] += 1
for read_port in var.read_ports:
......@@ -490,19 +503,22 @@ of :class:`~b_asic.architecture.ProcessingElement`
frequency of accesses.
"""
ops = cast(List[OperatorProcess], list(pe._collection))
d_in = [defaultdict(_interconnect_dict) for _ in ops[0].operation.inputs]
d_out = [defaultdict(_interconnect_dict) for _ in ops[0].operation.outputs]
for var in pe._collection:
d_in: List[DefaultDict[Tuple[Resource, int], int]] = [
defaultdict(_interconnect_dict) for _ in range(pe.input_count)
]
d_out: List[DefaultDict[Tuple[Resource, int], int]] = [
defaultdict(_interconnect_dict) for _ in range(pe.output_count)
]
for var in pe.collection:
var = cast(OperatorProcess, var)
for i, input in enumerate(var.operation.inputs):
d_in[i][self._variable_inport_to_resource[input]] += 1
for i, input_ in enumerate(var.operation.inputs):
d_in[i][self._variable_inport_to_resource[input_]] += 1
for i, output in enumerate(var.operation.outputs):
d_out[i][self._variable_outport_to_resource[output]] += 1
return [dict(d) for d in d_in], [dict(d) for d in d_out]
def _digraph(self) -> Digraph:
edges = set()
edges: Set[Tuple[str, str, str]] = set()
dg = Digraph(node_attr={'shape': 'record'})
# dg.attr(rankdir="LR")
for i, mem in enumerate(self._memories):
......@@ -529,16 +545,16 @@ of :class:`~b_asic.architecture.ProcessingElement`
f"{cnt}",
)
)
for src, dest, cnt in edges:
dg.edge(src, dest, label=cnt)
for src_str, dest_str, cnt_str in edges:
dg.edge(src_str, dest_str, label=cnt_str)
return dg
@property
def memories(self) -> Iterable[Memory]:
def memories(self) -> List[Memory]:
return self._memories
@property
def processing_elements(self) -> Iterable[ProcessingElement]:
def processing_elements(self) -> List[ProcessingElement]:
return self._processing_elements
@property
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment