diff --git a/README.md b/README.md index fb3ee09e2f38db7fa7cb428f74dbd5dcccddb2e3..f282c0e078fb03db14c53f4780f5f97c71566118 100755 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ To generate the documentation, the following additional packages are required: - [numpydoc](https://numpydoc.readthedocs.io/) - [Sphinx-Gallery](https://sphinx-gallery.github.io/) - [mplsignal](https://mplsignal.readthedocs.io/) - - [jupyters-sphinx](https://jupyter-sphinx.readthedocs.io/) + - [sphinx-copybutton](https://sphinx-copybutton.readthedocs.io/) ### Using setuptools to create a package diff --git a/b_asic/_preferences.py b/b_asic/_preferences.py index 44ee23e138deea9be43ace82d24b2e04786bca29..004e498ccf997bec68298ef88ac73e7ee19d962a 100755 --- a/b_asic/_preferences.py +++ b/b_asic/_preferences.py @@ -10,3 +10,12 @@ OPERATION_GAP: float = 0.5 SCHEDULE_OFFSET: float = 0.2 SPLINE_OFFSET: float = 0.2 + +# Colors for architecture Digraph +PE_COLOR = (0, 185, 231) # LiuBlue +PE_CLUSTER_COLOR = (210, 238, 249) # LiuBlue5 +MEMORY_COLOR = (0, 207, 181) # LiuGreen +MEMORY_CLUSTER_COLOR = (213, 241, 235) # LiuGreen5 +IO_COLOR = (23, 199, 210) # LiuTurqoise +IO_CLUSTER_COLOR = (215, 239, 242) # LiuTurqoise5 +MUX_COLOR = (255, 100, 66) # LiuOrange diff --git a/b_asic/architecture.py b/b_asic/architecture.py index cc73a3898659751e165e7c200f2e82c5d673ed46..2edcd3525d483b3e67b356a9498db04734d1fdf6 100755 --- a/b_asic/architecture.py +++ b/b_asic/architecture.py @@ -3,13 +3,37 @@ B-ASIC architecture classes. """ from collections import defaultdict from io import TextIOWrapper -from typing import Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union, cast +from itertools import chain +from typing import ( + DefaultDict, + Dict, + Iterable, + Iterator, + List, + Optional, + Set, + Tuple, + Type, + Union, + cast, +) import matplotlib.pyplot as plt from graphviz import Digraph +from b_asic._preferences import ( + IO_CLUSTER_COLOR, + IO_COLOR, + MEMORY_CLUSTER_COLOR, + MEMORY_COLOR, + MUX_COLOR, + PE_CLUSTER_COLOR, + PE_COLOR, +) +from b_asic.codegen.vhdl.common import is_valid_vhdl_identifier +from b_asic.operation import Operation from b_asic.port import InputPort, OutputPort -from b_asic.process import MemoryVariable, OperatorProcess, PlainMemoryVariable +from b_asic.process import MemoryProcess, MemoryVariable, OperatorProcess, Process from b_asic.resources import ProcessCollection @@ -29,11 +53,11 @@ class HardwareBlock: """ def __init__(self, entity_name: Optional[str] = None): - self._entity_name = None + self._entity_name: Optional[str] = None if entity_name is not None: self.set_entity_name(entity_name) - def set_entity_name(self, entity_name: str): + def set_entity_name(self, entity_name: str) -> None: """ Set entity name of hardware block. @@ -42,10 +66,8 @@ class HardwareBlock: entity_name : str The entity name. """ - # Should be a better check. - # See https://stackoverflow.com/questions/7959587/regex-for-vhdl-identifier - if " " in entity_name: - raise ValueError("Cannot have space in entity name") + if not is_valid_vhdl_identifier(entity_name): + raise ValueError(f'{entity_name} is not a valid VHDL identifier') self._entity_name = entity_name def write_code(self, path: str) -> None: @@ -70,6 +92,13 @@ class HardwareBlock: def _repr_png_(self): return self._digraph()._repr_mimebundle_(include=["image/png"])["image/png"] + def _repr_svg_(self): + return self._digraph()._repr_mimebundle_(include=["image/svg+xml"])[ + "image/svg+xml" + ] + + _repr_html_ = _repr_svg_ + @property def entity_name(self) -> str: if self._entity_name is None: @@ -143,7 +172,9 @@ class Resource(HardwareBlock): def _digraph(self) -> Digraph: dg = Digraph(node_attr={'shape': 'record'}) - dg.node(self.entity_name, self._struct_def()) + dg.node( + self.entity_name, self._struct_def(), style='filled', fillcolor=self._color + ) return dg @property @@ -157,39 +188,76 @@ class Resource(HardwareBlock): return self._output_count def _struct_def(self) -> str: + # Create GraphViz struct inputs = [f"in{i}" for i in range(self._input_count)] outputs = [f"out{i}" for i in range(self._output_count)] ret = "" if inputs: - instrs = [f"<{instr}> {instr}" for instr in inputs] - ret += f"{{{'|'.join(instrs)}}}|" - ret += f"{self.entity_name}" + in_strs = [f"<{in_str}> {in_str}" for in_str in inputs] + ret += f"{{{'|'.join(in_strs)}}}|" + ret += f"<{self.entity_name}> {self.entity_name}{self._info()}" if outputs: - outstrs = [f"<{outstr}> {outstr}" for outstr in outputs] - ret += f"|{{{'|'.join(outstrs)}}}" + out_strs = [f"<{out_str}> {out_str}" for out_str in outputs] + ret += f"|{{{'|'.join(out_strs)}}}" return "{" + ret + "}" + def _info(self): + return "" + @property def schedule_time(self) -> int: # doc-string inherited return self._collection.schedule_time - def plot_content(self, ax: plt.Axes) -> None: + def plot_content(self, ax: plt.Axes, **kwargs) -> None: + """ + Plot the content of the resource. + + This plots the assigned processes executed on this resource. + + Parameters + ---------- + ax : Axes + Matplotlib Axes to plot in. + **kwargs + Passed to :meth:`b_asic.resources.ProcessCollection.plot` + """ if not self.is_assigned: - self._collection.plot(ax) + self._collection.plot(ax, **kwargs) else: for i, pc in enumerate(self._assignment): # type: ignore - pc.plot(ax=ax, row=i) + pc.plot(ax=ax, row=i, **kwargs) - def show_content(self): + def show_content(self, title=None, **kwargs) -> None: + """ + Display the content of the resource. + + This displays the assigned processes executed on this resource. + + Parameters + ---------- + title : str, optional + **kwargs + Passed to :meth:`b_asic.resources.ProcessCollection.plot` + """ fig, ax = plt.subplots() - self.plot_content(ax) + self.plot_content(ax, **kwargs) + if title: + fig.suptitle(title) fig.show() # type: ignore @property def is_assigned(self) -> bool: return self._assignment is not None + def assign(self, heuristic: str = 'left_edge'): + """ + Perform assignment of processes to resource. + + See the specific resource types for more information. + """ + raise NotImplementedError() + @property def content(self) -> plt.Figure: """ @@ -202,6 +270,62 @@ class Resource(HardwareBlock): self.plot_content(ax) return fig + @property + def collection(self) -> ProcessCollection: + return self._collection + + @property + def operation_type(self) -> Union[Type[MemoryProcess], Type[OperatorProcess]]: + raise NotImplementedError("ABC Resource does not implement operation_type") + + def add_process(self, proc: Process, assign=False): + """ + Add a :class:`~b_asic.process.Process` to this :class:`Resource`. + + Raises :class:`KeyError` if the process being added is not of the same type + as the other processes. + + Parameters + ---------- + proc : :class:`~b_asic.process.Process` + The process to add. + assign : bool, default=False + Whether to perform assignment of the resource after adding. + """ + if isinstance(proc, OperatorProcess): + # operation_type marks OperatorProcess associated operation. + if not isinstance(proc._operation, self.operation_type): + raise KeyError(f"{proc} not of type {self.operation_type}") + else: + # operation_type is MemoryVariable or PlainMemoryVariable + if not isinstance(proc, self.operation_type): + raise KeyError(f"{proc} not of type {self.operation_type}") + self.collection.add_process(proc) + if assign: + self.assign() + else: + self._assignment = None + + def remove_process(self, proc: Process, assign: bool = False): + """ + Remove a :class:`~b_asic.process.Process` from this :class:`Resource`. + + Raises :class:`KeyError` if the process being added is not of the same type + as the other processes. + + Parameters + ---------- + proc : :class:`~b_asic.process.Process` + The process to remove. + assign : bool, default=False + Whether to perform assignment of the resource after removal. + """ + self.collection.remove_process(proc) + if assign: + self.assign() + else: + self._assignment = None + class ProcessingElement(Resource): """ @@ -213,10 +337,17 @@ class ProcessingElement(Resource): Process collection containing operations to map to processing element. entity_name : str, optional Name of processing element entity. + assign : bool, default True + Perform assignment when creating the ProcessingElement. """ + _color = f"#{''.join(f'{v:0>2X}' for v in PE_COLOR)}" + def __init__( - self, process_collection: ProcessCollection, entity_name: Optional[str] = None + self, + process_collection: ProcessCollection, + entity_name: Optional[str] = None, + assign: bool = True, ): super().__init__(process_collection=process_collection, entity_name=entity_name) if not all( @@ -234,21 +365,39 @@ class ProcessingElement(Resource): op_type = type(ops[0]) if not all(isinstance(op, op_type) for op in ops): raise TypeError("Different Operation types in ProcessCollection") - self._collection = process_collection self._operation_type = op_type self._type_name = op_type.type_name() - self._entity_name = entity_name self._input_count = ops[0].input_count self._output_count = ops[0].output_count + if assign: + self.assign() + + @property + def processes(self) -> List[OperatorProcess]: + return [cast(OperatorProcess, p) for p in self._collection] + + def assign(self, heuristic: str = "left_edge") -> None: + """ + Perform assignment of the processes. + + Parameters + ---------- + heuristic : str, default: 'left_edge' + The assignment algorithm. + + * 'left_edge': Left-edge algorithm. + * 'graph_color': Graph-coloring based on exclusion graph. + """ self._assignment = list( - self._collection.split_on_execution_time(heuristic="left_edge") + self._collection.split_on_execution_time(heuristic=heuristic) ) if len(self._assignment) > 1: + self._assignment = None raise ValueError("Cannot map ProcessCollection to single ProcessingElement") @property - def processes(self) -> Set[OperatorProcess]: - return {cast(OperatorProcess, p) for p in self._collection} + def operation_type(self) -> Type[Operation]: + return self._operation_type class Memory(Resource): @@ -267,8 +416,13 @@ class Memory(Resource): Number of read ports for memory. write_ports : int, optional Number of write ports for memory. + assign : bool, default False + Perform assignment when creating the Memory (using the default properties). + """ + _color = f"#{''.join(f'{v:0>2X}' for v in MEMORY_COLOR)}" + def __init__( self, process_collection: ProcessCollection, @@ -276,15 +430,15 @@ class Memory(Resource): entity_name: Optional[str] = None, read_ports: Optional[int] = None, write_ports: Optional[int] = None, + assign: bool = False, ): super().__init__(process_collection=process_collection, entity_name=entity_name) if not all( - isinstance(operator, (MemoryVariable, PlainMemoryVariable)) + isinstance(operator, MemoryProcess) for operator in process_collection.collection ): raise TypeError( - "Can only have MemoryVariable or PlainMemoryVariable in" - " ProcessCollection when creating Memory" + "Can only have MemoryProcess in ProcessCollection when creating Memory" ) if memory_type not in ("RAM", "register"): raise ValueError( @@ -305,23 +459,54 @@ class Memory(Resource): raise ValueError(f"At least {write_ports_bound} write ports required") self._input_count = write_ports self._memory_type = memory_type + if assign: + self.assign() + + memory_processes = [ + cast(MemoryProcess, process) for process in process_collection + ] + mem_proc_type = type(memory_processes[0]) + if not all(isinstance(proc, mem_proc_type) for proc in memory_processes): + raise TypeError("Different MemoryProcess types in ProcessCollection") + self._operation_type = mem_proc_type def __iter__(self) -> Iterator[MemoryVariable]: # Add information about the iterator type return cast(Iterator[MemoryVariable], iter(self._collection)) - def _assign_ram(self, heuristic: str = "left_edge"): + def _info(self): + if self.is_assigned: + if self._memory_type == "RAM": + plural_s = 's' if len(self._assignment) >= 2 else '' + return f": (RAM, {len(self._assignment)} cell{plural_s})" + return "" + + def assign(self, heuristic: str = "left_edge") -> None: """ - Perform RAM-type assignment of MemoryVariables in this Memory. + Perform assignment of the memory variables. Parameters ---------- - heuristic : {'left_edge', 'graph_color'} - The underlying heuristic to use when performing RAM assignment. + heuristic : str, default: 'left_edge' + The assignment algorithm. Depending on memory type the following are + available: + + * 'RAM' + * 'left_edge': Left-edge algorithm. + * 'graph_color': Graph-coloring based on exclusion graph. + * 'register' + * ... """ - self._assignment = list( - self._collection.split_on_execution_time(heuristic=heuristic) - ) + if self._memory_type == "RAM": + self._assignment = self._collection.split_on_execution_time( + heuristic=heuristic + ) + else: # "register" + raise NotImplementedError() + + @property + def operation_type(self) -> Type[MemoryProcess]: + return self._operation_type class Architecture(HardwareBlock): @@ -351,14 +536,18 @@ of :class:`~b_asic.architecture.ProcessingElement` ): super().__init__(entity_name) self._processing_elements = ( - set(processing_elements) + [processing_elements] if isinstance(processing_elements, ProcessingElement) - else processing_elements + else list(processing_elements) ) - self._memories = set(memories) if isinstance(memories, Memory) else memories + self._memories = [memories] if isinstance(memories, Memory) else list(memories) self._direct_interconnects = direct_interconnects - self._variable_inport_to_resource: Dict[InputPort, Tuple[Resource, int]] = {} - self._variable_outport_to_resource: Dict[OutputPort, Tuple[Resource, int]] = {} + self._variable_inport_to_resource: DefaultDict[ + InputPort, Set[Tuple[Resource, int]] + ] = defaultdict(set) + self._variable_outport_to_resource: DefaultDict[ + OutputPort, Set[Tuple[Resource, int]] + ] = defaultdict(set) self._operation_inport_to_resource: Dict[InputPort, Resource] = {} self._operation_outport_to_resource: Dict[OutputPort, Resource] = {} @@ -381,7 +570,15 @@ of :class:`~b_asic.architecture.ProcessingElement` raise ValueError(f"Different schedule times: {schedule_times}") return schedule_times.pop() - def _build_dicts(self): + def _build_dicts(self) -> None: + self._variable_inport_to_resource: DefaultDict[ + InputPort, Set[Tuple[Resource, int]] + ] = defaultdict(set) + self._variable_outport_to_resource: DefaultDict[ + OutputPort, Set[Tuple[Resource, int]] + ] = defaultdict(set) + self._operation_inport_to_resource = {} + self._operation_outport_to_resource = {} for pe in self.processing_elements: for operator in pe.processes: for input_port in operator.operation.inputs: @@ -392,22 +589,28 @@ of :class:`~b_asic.architecture.ProcessingElement` for memory in self.memories: for mv in memory: for read_port in mv.read_ports: - self._variable_inport_to_resource[read_port] = (memory, 0) # Fix - self._variable_outport_to_resource[mv.write_port] = (memory, 0) # Fix + self._variable_inport_to_resource[read_port].add((memory, 0)) # Fix + self._variable_outport_to_resource[mv.write_port].add( + (memory, 0) + ) # Fix if self._direct_interconnects: for di in self._direct_interconnects: di = cast(MemoryVariable, di) for read_port in di.read_ports: - self._variable_inport_to_resource[read_port] = ( - self._operation_outport_to_resource[di.write_port], - di.write_port.index, + self._variable_inport_to_resource[read_port].add( + ( + self._operation_outport_to_resource[di.write_port], + di.write_port.index, + ) ) - self._variable_outport_to_resource[di.write_port] = ( - self._operation_inport_to_resource[read_port], - read_port.index, + self._variable_outport_to_resource[di.write_port].add( + ( + self._operation_inport_to_resource[read_port], + read_port.index, + ) ) - def validate_ports(self): + def validate_ports(self) -> None: # Validate inputs and outputs of memory variables in all the memories in this # architecture memory_read_ports = set() @@ -429,6 +632,7 @@ of :class:`~b_asic.architecture.ProcessingElement` pe_input_ports.update(operator.operation.inputs) pe_output_ports.update(operator.operation.outputs) + # Make sure all inputs and outputs in the architecture are in use read_port_diff = memory_read_ports.symmetric_difference(pe_input_ports) write_port_diff = memory_write_ports.symmetric_difference(pe_output_ports) if read_port_diff: @@ -441,16 +645,17 @@ of :class:`~b_asic.architecture.ProcessingElement` "Memory read port and PE output port difference:" f" {[port.name for port in write_port_diff]}" ) - # Make sure all inputs and outputs in the architecture are in use - def get_interconnects_for_memory(self, mem: Memory): + def get_interconnects_for_memory( + self, mem: Union[Memory, str] + ) -> Tuple[Dict[Resource, int], Dict[Resource, int]]: """ Return a dictionary with interconnect information for a Memory. Parameters ---------- - mem : :class:`Memory` - The memory to obtain information about. + mem : :class:`Memory` or str + The memory or entity name to obtain information about. Returns ------- @@ -458,9 +663,12 @@ of :class:`~b_asic.architecture.ProcessingElement` A dictionary with the ProcessingElements that are connected to the write and read ports, respectively, with counts of the number of accesses. """ - d_in = defaultdict(_interconnect_dict) - d_out = defaultdict(_interconnect_dict) - for var in mem._collection: + if isinstance(mem, str): + mem = cast(Memory, self.resource_from_name(mem)) + + d_in: DefaultDict[Resource, int] = defaultdict(_interconnect_dict) + d_out: DefaultDict[Resource, int] = defaultdict(_interconnect_dict) + for var in mem.collection: var = cast(MemoryVariable, var) d_in[self._operation_outport_to_resource[var.write_port]] += 1 for read_port in var.read_ports: @@ -468,7 +676,7 @@ of :class:`~b_asic.architecture.ProcessingElement` return dict(d_in), dict(d_out) def get_interconnects_for_pe( - self, pe: ProcessingElement + self, pe: Union[str, ProcessingElement] ) -> Tuple[ List[Dict[Tuple[Resource, int], int]], List[Dict[Tuple[Resource, int], int]] ]: @@ -478,8 +686,8 @@ of :class:`~b_asic.architecture.ProcessingElement` Parameters ---------- - pe : :class:`ProcessingElement` - The processing element to get information for. + pe : :class:`ProcessingElement` or str + The processing element or entity name to get information for. Returns ------- @@ -491,55 +699,260 @@ of :class:`~b_asic.architecture.ProcessingElement` frequency of accesses. """ - ops = cast(List[OperatorProcess], list(pe._collection)) - d_in = [defaultdict(_interconnect_dict) for _ in ops[0].operation.inputs] - d_out = [defaultdict(_interconnect_dict) for _ in ops[0].operation.outputs] - for var in pe._collection: + if isinstance(pe, str): + pe = cast(ProcessingElement, self.resource_from_name(pe)) + + d_in: List[DefaultDict[Tuple[Resource, int], int]] = [ + defaultdict(_interconnect_dict) for _ in range(pe.input_count) + ] + d_out: List[DefaultDict[Tuple[Resource, int], int]] = [ + defaultdict(_interconnect_dict) for _ in range(pe.output_count) + ] + for var in pe.collection: var = cast(OperatorProcess, var) - for i, input in enumerate(var.operation.inputs): - d_in[i][self._variable_inport_to_resource[input]] += 1 + for i, input_ in enumerate(var.operation.inputs): + for v in self._variable_inport_to_resource[input_]: + d_in[i][v] += 1 for i, output in enumerate(var.operation.outputs): - d_out[i][self._variable_outport_to_resource[output]] += 1 + for v in self._variable_outport_to_resource[output]: + d_out[i][v] += 1 return [dict(d) for d in d_in], [dict(d) for d in d_out] - def _digraph(self) -> Digraph: - edges = set() + def resource_from_name(self, name: str): + re = {p.entity_name: p for p in chain(self.memories, self.processing_elements)} + return re[name] + + def move_process( + self, + proc: Union[str, Process], + re_from: Union[str, Resource], + re_to: Union[str, Resource], + assign: bool = False, + ): + """ + Move a :class:`b_asic.process.Process` from one resource to another in the + architecture. + + Both the resource moved from and will become unassigned after a process has been + moved. + + Raises :class:`KeyError` if ``proc`` is not present in resource ``re_from``. + + Parameters + ---------- + proc : :class:`b_asic.process.Process` or string + The process (or its name) to move. + re_from : :class:`b_asic.architecture.Resource` or str + The resource (or its entity name) to move the process from. + re_to : :class:`b_asic.architecture.Resource` or str + The resource (or its entity name) to move the process to. + assign : bool, default=False + Whether to perform assignment of the resources after moving. + """ + # Extract resources from name + if isinstance(re_from, str): + re_from = self.resource_from_name(re_from) + if isinstance(re_to, str): + re_to = self.resource_from_name(re_to) + + # Extract process from name + if isinstance(proc, str): + proc = re_from.collection.from_name(proc) + + # Move the process + if proc in re_from: + re_to.add_process(proc, assign=assign) + re_from.remove_process(proc, assign=assign) + else: + raise KeyError(f"{proc} not in {re_from.entity_name}") + self._build_dicts() + + def _digraph( + self, + branch_node: bool = True, + cluster: bool = True, + splines: str = "spline", + io_cluster: bool = True, + multiplexers: bool = True, + colored: bool = True, + ) -> Digraph: + """ + Parameters + ---------- + branch_node : bool, default: True + Whether to create a branch node for outputs with fan-out of two or higher. + cluster : bool, default: True + Whether to draw memories and PEs in separate clusters. + splines : str, default: "spline" + The type of interconnect to use for graph drawing. + io_cluster : bool, default: True + Whether Inputs and Outputs are drawn inside an IO cluster. Only relevant + if *cluster* is True. + multiplexers : bool, default: True + Whether input multiplexers are included. + colored : bool, default: True + Whether to color the nodes. + """ dg = Digraph(node_attr={'shape': 'record'}) - # dg.attr(rankdir="LR") - for i, mem in enumerate(self._memories): - dg.node(mem.entity_name, mem._struct_def()) - for i, pe in enumerate(self._processing_elements): - dg.node(pe.entity_name, pe._struct_def()) + dg.attr(splines=splines) + # Setup colors + pe_color = ( + f"#{''.join(f'{v:0>2X}' for v in PE_COLOR)}" if colored else "transparent" + ) + pe_cluster_color = ( + f"#{''.join(f'{v:0>2X}' for v in PE_CLUSTER_COLOR)}" + if colored + else "transparent" + ) + memory_color = ( + f"#{''.join(f'{v:0>2X}' for v in MEMORY_COLOR)}" + if colored + else "transparent" + ) + memory_cluster_color = ( + f"#{''.join(f'{v:0>2X}' for v in MEMORY_CLUSTER_COLOR)}" + if colored + else "transparent" + ) + io_color = ( + f"#{''.join(f'{v:0>2X}' for v in IO_COLOR)}" if colored else "transparent" + ) + io_cluster_color = ( + f"#{''.join(f'{v:0>2X}' for v in IO_CLUSTER_COLOR)}" + if colored + else "transparent" + ) + mux_color = ( + f"#{''.join(f'{v:0>2X}' for v in MUX_COLOR)}" if colored else "transparent" + ) + + # Add nodes for memories and PEs to graph + if cluster: + # Add subgraphs + if len(self._memories): + with dg.subgraph(name='cluster_memories') as c: + for mem in self._memories: + c.node( + mem.entity_name, + mem._struct_def(), + style='filled', + fillcolor=memory_color, + ) + label = "Memory" if len(self._memories) <= 1 else "Memories" + c.attr(label=label, bgcolor=memory_cluster_color) + with dg.subgraph(name='cluster_pes') as c: + for pe in self._processing_elements: + if pe._type_name not in ('in', 'out'): + c.node( + pe.entity_name, + pe._struct_def(), + style='filled', + fillcolor=pe_color, + ) + label = ( + "Processing element" + if len(self._processing_elements) <= 1 + else "Processing elements" + ) + c.attr(label=label, bgcolor=pe_cluster_color) + if io_cluster: + with dg.subgraph(name='cluster_io') as c: + for pe in self._processing_elements: + if pe._type_name in ('in', 'out'): + c.node( + pe.entity_name, + pe._struct_def(), + style='filled', + fillcolor=io_color, + ) + c.attr(label="IO", bgcolor=io_cluster_color) + else: + for pe in self._processing_elements: + if pe._type_name in ('in', 'out'): + dg.node( + pe.entity_name, + pe._struct_def(), + style='filled', + fillcolor=io_color, + ) + else: + for i, mem in enumerate(self._memories): + dg.node( + mem.entity_name, + mem._struct_def(), + style='filled', + fillcolor=memory_color, + ) + for i, pe in enumerate(self._processing_elements): + dg.node( + pe.entity_name, pe._struct_def(), style='filled', fillcolor=pe_color + ) + + # Create list of interconnects + edges: DefaultDict[str, Set[Tuple[str, str]]] = defaultdict(set) + destination_edges: DefaultDict[str, Set[str]] = defaultdict(set) for pe in self._processing_elements: inputs, outputs = self.get_interconnects_for_pe(pe) for i, inp in enumerate(inputs): for (source, port), cnt in inp.items(): - edges.add( + source_str = f"{source.entity_name}:out{port}" + destination_str = f"{pe.entity_name}:in{i}" + edges[source_str].add( ( - f"{source.entity_name}:out{port}", - f"{pe.entity_name}:in{i}", + destination_str, f"{cnt}", ) ) - for o, outp in enumerate(outputs): - for (dest, port), cnt in outp.items(): - edges.add( + destination_edges[destination_str].add(source_str) + for o, output in enumerate(outputs): + for (destination, port), cnt in output.items(): + source_str = f"{pe.entity_name}:out{o}" + destination_str = f"{destination.entity_name}:in{port}" + edges[source_str].add( ( - f"{pe.entity_name}:out{o}", - f"{dest.entity_name}:in{port}", + destination_str, f"{cnt}", ) ) - for src, dest, cnt in edges: - dg.edge(src, dest, label=cnt) + destination_edges[destination_str].add(source_str) + + destination_list = {k: list(v) for k, v in destination_edges.items()} + if multiplexers: + for destination, source_list in destination_list.items(): + if len(source_list) > 1: + # Create GraphViz struct for multiplexer + inputs = [f"in{i}" for i in range(len(source_list))] + ret = "" + in_strs = [f"<{in_str}> {in_str}" for in_str in inputs] + ret += f"{{{'|'.join(in_strs)}}}|" + name = f"{destination.replace(':', '_')}_mux" + ret += f"<{name}> {name}" + ret += "|<out0> out0" + dg.node(name, "{" + ret + "}", style='filled', fillcolor=mux_color) + # Add edge from mux output to resource input + dg.edge(f"{name}:out0", destination) + + # Add edges to graph + for src_str, destination_counts in edges.items(): + original_src_str = src_str + if len(destination_counts) > 1 and branch_node: + branch = f"{src_str}_branch".replace(":", "") + dg.node(branch, shape='point') + dg.edge(src_str, branch, arrowhead='none') + src_str = branch + for destination_str, cnt_str in destination_counts: + if multiplexers and len(destination_list[destination_str]) > 1: + idx = destination_list[destination_str].index(original_src_str) + destination_str = f"{destination_str.replace(':', '_')}_mux:in{idx}" + dg.edge(src_str, destination_str, label=cnt_str) return dg @property - def memories(self) -> Iterable[Memory]: + def memories(self) -> List[Memory]: return self._memories @property - def processing_elements(self) -> Iterable[ProcessingElement]: + def processing_elements(self) -> List[ProcessingElement]: return self._processing_elements @property diff --git a/b_asic/codegen/__init__.py b/b_asic/codegen/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/b_asic/codegen/vhdl/__init__.py b/b_asic/codegen/vhdl/__init__.py index 5d6625b402076302506d63650de985476f301783..2cce156aac66e7e720f2113c99f4515507eb2217 100755 --- a/b_asic/codegen/vhdl/__init__.py +++ b/b_asic/codegen/vhdl/__init__.py @@ -2,15 +2,14 @@ Module for basic VHDL code generation. """ -from io import TextIOWrapper -from typing import List, Optional, Tuple, Union +from typing import List, Optional, TextIO, Tuple, Union # VHDL code generation tab length VHDL_TAB = r" " def write( - f: TextIOWrapper, + f: TextIO, indent_level: int, text: str, *, @@ -20,13 +19,13 @@ def write( """ Base VHDL code generation utility. - `f'{VHDL_TAB*indent_level}'` is first written to the :class:`io.TextIOWrapper` - object `f`. Immediatly after the indentation, `text` is written to `f`. Finally, - `text` is also written to `f`. + ``f'{VHDL_TAB*indent_level}'`` is first written to the TextIO + object *f*. Immediately after the indentation, *text* is written to *f*. Finally, + *text* is also written to *f*. Parameters ---------- - f : :class:`io.TextIOWrapper` + f : TextIO The file object to emit VHDL code to. indent_level : int Indentation level to use. Exactly ``f'{VHDL_TAB*indent_level}`` is written @@ -43,29 +42,27 @@ def write( f.write(f'{VHDL_TAB*indent_level}{text}{end}') -def write_lines( - f: TextIOWrapper, lines: List[Union[Tuple[int, str], Tuple[int, str, str]]] -): +def write_lines(f: TextIO, lines: List[Union[Tuple[int, str], Tuple[int, str, str]]]): """ Multiline VHDL code generation utility. - Each tuple (int, str, [int]) in the list `lines` is written to the - :class:`io.TextIOWrapper` object `f` using the :function:`vhdl.write` function. + Each tuple ``(int, str, [int])`` in the list *lines* is written to the + TextIO object *f* using the :function:`vhdl.write` function. Parameters ---------- - f : :class:`io.TextIOWrapper` + f : TextIO The file object to emit VHDL code to. lines : list of tuple (int,str) [1], or list of tuple (int,str,str) [2] - [1]: The first `int` of the tuple is used as indentation level for the line and - the second `str` of the tuple is the content of the line. - [2]: Same as [1], but the third `str` of the tuple is passed to parameter `end` - when calling :function:`vhdl.write`. + [1]: The first ``int`` of the tuple is used as indentation level for the line + and the second ``str`` of the tuple is the content of the line. + [2]: Same as [1], but the third ``str`` of the tuple is passed to parameter + *end* when calling :function:`vhdl.write`. """ for tpl in lines: if len(tpl) == 2: - write(f, indent_level=tpl[0], text=tpl[1]) + write(f, indent_level=tpl[0], text=str(tpl[1])) elif len(tpl) == 3: - write(f, indent_level=tpl[0], text=tpl[1], end=tpl[2]) + write(f, indent_level=tpl[0], text=str(tpl[1]), end=str(tpl[2])) else: raise ValueError('All tuples in list `lines` must have length 2 or 3') diff --git a/b_asic/codegen/vhdl/architecture.py b/b_asic/codegen/vhdl/architecture.py index 4f585f702c0d97f115052d0caafabffb7cae4b77..bc8d822780f1a1f8ea453769d3f4f37e0c336eca 100755 --- a/b_asic/codegen/vhdl/architecture.py +++ b/b_asic/codegen/vhdl/architecture.py @@ -1,10 +1,9 @@ """ Module for code generation of VHDL architectures. """ -from io import TextIOWrapper -from typing import TYPE_CHECKING, Dict, Set, Tuple, cast +from typing import TYPE_CHECKING, Dict, List, Set, TextIO, Tuple, cast -from b_asic.codegen import vhdl +from b_asic.codegen.vhdl import common, write, write_lines from b_asic.process import MemoryVariable, PlainMemoryVariable if TYPE_CHECKING: @@ -12,8 +11,8 @@ if TYPE_CHECKING: def memory_based_storage( - f: TextIOWrapper, - assignment: Set["ProcessCollection"], + f: TextIO, + assignment: List["ProcessCollection"], entity_name: str, word_length: int, read_ports: int, @@ -26,9 +25,9 @@ def memory_based_storage( Parameters ---------- - f : TextIOWrapper - File object (or other TextIOWrapper object) to write the architecture onto. - assignment : dict + f : TextIO + File object (or other TextIO object) to write the architecture onto. + assignment : list A possible cell assignment to use when generating the memory based storage. The cell assignment is a dictionary int to ProcessCollection where the integer corresponds to the cell to assign all MemoryVariables in corresponding process @@ -52,27 +51,25 @@ def memory_based_storage( # Code settings mem_depth = len(assignment) architecture_name = "rtl" - schedule_time = next(iter(assignment))._schedule_time + schedule_time = next(iter(assignment)).schedule_time # Write architecture header - vhdl.write( - f, 0, f'architecture {architecture_name} of {entity_name} is', end='\n\n' - ) + write(f, 0, f'architecture {architecture_name} of {entity_name} is', end='\n\n') # - # Architecture declerative region begin + # Architecture declarative region begin # - vhdl.write(f, 1, '-- HDL memory description') - vhdl.common.constant_declaration( + write(f, 1, '-- HDL memory description') + common.constant_declaration( f, name='MEM_WL', signal_type='integer', value=word_length, name_pad=12 ) - vhdl.common.constant_declaration( + common.constant_declaration( f, name='MEM_DEPTH', signal_type='integer', value=mem_depth, name_pad=12 ) - vhdl.common.type_declaration( + common.type_declaration( f, 'mem_type', 'array(0 to MEM_DEPTH-1) of std_logic_vector(MEM_WL-1 downto 0)' ) - vhdl.common.signal_decl( + common.signal_declaration( f, name='memory', signal_type='mem_type', @@ -80,25 +77,25 @@ def memory_based_storage( vivado_ram_style='distributed', ) for i in range(read_ports): - vhdl.common.signal_decl( + common.signal_declaration( f, f'read_port_{i}', 'std_logic_vector(MEM_WL-1 downto 0)', name_pad=14 ) - vhdl.common.signal_decl( + common.signal_declaration( f, f'read_adr_{i}', f'integer range 0 to {schedule_time}-1', name_pad=14 ) - vhdl.common.signal_decl(f, f'read_en_{i}', 'std_logic', name_pad=14) + common.signal_declaration(f, f'read_en_{i}', 'std_logic', name_pad=14) for i in range(write_ports): - vhdl.common.signal_decl( + common.signal_declaration( f, f'write_port_{i}', 'std_logic_vector(MEM_WL-1 downto 0)', name_pad=14 ) - vhdl.common.signal_decl( + common.signal_declaration( f, f'write_adr_{i}', f'integer range 0 to {schedule_time}-1', name_pad=14 ) - vhdl.common.signal_decl(f, f'write_en_{i}', 'std_logic', name_pad=14) + common.signal_declaration(f, f'write_en_{i}', 'std_logic', name_pad=14) # Schedule time counter - vhdl.write(f, 1, '-- Schedule counter', start='\n') - vhdl.common.signal_decl( + write(f, 1, '-- Schedule counter', start='\n') + common.signal_declaration( f, name='schedule_cnt', signal_type=f'integer range 0 to {schedule_time}-1', @@ -107,23 +104,23 @@ def memory_based_storage( # Input sync signals if input_sync: - vhdl.write(f, 1, '-- Input synchronization', start='\n') + write(f, 1, '-- Input synchronization', start='\n') for i in range(read_ports): - vhdl.common.signal_decl( + common.signal_declaration( f, f'p_{i}_in_sync', 'std_logic_vector(WL-1 downto 0)', name_pad=14 ) # # Architecture body begin # - vhdl.write(f, 0, 'begin', start='\n', end='\n\n') - vhdl.write(f, 1, '-- Schedule counter') - vhdl.common.synchronous_process_prologue( + write(f, 0, 'begin', start='\n', end='\n\n') + write(f, 1, '-- Schedule counter') + common.synchronous_process_prologue( f=f, name='schedule_cnt_proc', clk='clk', ) - vhdl.write_lines( + write_lines( f, [ (3, 'if rst = \'1\' then'), @@ -139,30 +136,30 @@ def memory_based_storage( (3, 'end if;'), ], ) - vhdl.common.synchronous_process_epilogue( + common.synchronous_process_epilogue( f=f, name='schedule_cnt_proc', clk='clk', ) if input_sync: - vhdl.write(f, 1, '-- Input synchronization', start='\n') - vhdl.common.synchronous_process_prologue( + write(f, 1, '-- Input synchronization', start='\n') + common.synchronous_process_prologue( f=f, name='input_sync_proc', clk='clk', ) for i in range(read_ports): - vhdl.write(f, 3, f'p_{i}_in_sync <= p_{i}_in;') - vhdl.common.synchronous_process_epilogue( + write(f, 3, f'p_{i}_in_sync <= p_{i}_in;') + common.synchronous_process_epilogue( f=f, name='input_sync_proc', clk='clk', ) # Infer memory - vhdl.write(f, 1, '-- Memory', start='\n') - vhdl.common.asynchronous_read_memory( + write(f, 1, '-- Memory', start='\n') + common.asynchronous_read_memory( f=f, clk='clk', name=f'mem_{0}_proc', @@ -177,30 +174,28 @@ def memory_based_storage( ) # Write address generation - vhdl.write(f, 1, '-- Memory write address generation', start='\n') + write(f, 1, '-- Memory write address generation', start='\n') if input_sync: - vhdl.common.synchronous_process_prologue( - f, clk="clk", name="mem_write_address_proc" - ) + common.synchronous_process_prologue(f, clk="clk", name="mem_write_address_proc") else: - vhdl.common.process_prologue( + common.process_prologue( f, sensitivity_list="schedule_cnt", name="mem_write_address_proc" ) - vhdl.write(f, 3, 'case schedule_cnt is') + write(f, 3, 'case schedule_cnt is') for i, collection in enumerate(assignment): for mv in collection: mv = cast(MemoryVariable, mv) if mv.execution_time: - vhdl.write_lines( + write_lines( f, [ (4, f'-- {mv!r}'), - (4, f'when {(mv.start_time) % schedule_time} =>'), + (4, f'when {mv.start_time % schedule_time} =>'), (5, f'write_adr_0 <= {i};'), (5, 'write_en_0 <= \'1\';'), ], ) - vhdl.write_lines( + write_lines( f, [ (4, 'when others =>'), @@ -210,38 +205,36 @@ def memory_based_storage( ], ) if input_sync: - vhdl.common.synchronous_process_epilogue( - f, clk="clk", name="mem_write_address_proc" - ) + common.synchronous_process_epilogue(f, clk="clk", name="mem_write_address_proc") else: - vhdl.common.process_epilogue( + common.process_epilogue( f, sensitivity_list="clk", name="mem_write_address_proc" ) # Read address generation - vhdl.write(f, 1, '-- Memory read address generation', start='\n') - vhdl.common.synchronous_process_prologue(f, clk="clk", name="mem_read_address_proc") - vhdl.write(f, 3, 'case schedule_cnt is') + write(f, 1, '-- Memory read address generation', start='\n') + common.synchronous_process_prologue(f, clk="clk", name="mem_read_address_proc") + write(f, 3, 'case schedule_cnt is') for i, collection in enumerate(assignment): for mv in collection: mv = cast(PlainMemoryVariable, mv) - vhdl.write(f, 4, f'-- {mv!r}') + write(f, 4, f'-- {mv!r}') for read_time in mv.reads.values(): - vhdl.write( + write( f, 4, 'when' f' {(mv.start_time+read_time-int(not(input_sync))) % schedule_time}' ' =>', ) - vhdl.write_lines( + write_lines( f, [ (5, f'read_adr_0 <= {i};'), (5, 'read_en_0 <= \'1\';'), ], ) - vhdl.write_lines( + write_lines( f, [ (4, 'when others =>'), @@ -250,46 +243,46 @@ def memory_based_storage( (3, 'end case;'), ], ) - vhdl.common.synchronous_process_epilogue(f, clk="clk", name="mem_read_address_proc") + common.synchronous_process_epilogue(f, clk="clk", name="mem_read_address_proc") - vhdl.write(f, 1, '-- Input and output assignmentn', start='\n') + write(f, 1, '-- Input and output assignments', start='\n') if input_sync: - vhdl.write(f, 1, 'write_port_0 <= p_0_in_sync;') + write(f, 1, 'write_port_0 <= p_0_in_sync;') else: - vhdl.write(f, 1, 'write_port_0 <= p_0_in;') + write(f, 1, 'write_port_0 <= p_0_in;') p_zero_exec = filter( lambda p: p.execution_time == 0, (p for pc in assignment for p in pc) ) - vhdl.common.synchronous_process_prologue( + common.synchronous_process_prologue( f, clk='clk', name='output_reg_proc', ) - vhdl.write(f, 3, 'case schedule_cnt is') + write(f, 3, 'case schedule_cnt is') for p in p_zero_exec: if input_sync: write_time = (p.start_time + 1) % schedule_time - vhdl.write(f, 4, f'when {write_time} => p_0_out <= p_0_in_sync;') + write(f, 4, f'when {write_time} => p_0_out <= p_0_in_sync;') else: write_time = (p.start_time) % schedule_time - vhdl.write(f, 4, f'when {write_time} => p_0_out <= p_0_in;') - vhdl.write_lines( + write(f, 4, f'when {write_time} => p_0_out <= p_0_in;') + write_lines( f, [ (4, 'when others => p_0_out <= read_port_0;'), (3, 'end case;'), ], ) - vhdl.common.synchronous_process_epilogue( + common.synchronous_process_epilogue( f, clk='clk', name='output_reg_proc', ) - vhdl.write(f, 0, f'end architecture {architecture_name};', start='\n') + write(f, 0, f'end architecture {architecture_name};', start='\n') def register_based_storage( - f: TextIOWrapper, + f: TextIO, forward_backward_table: "_ForwardBackwardTable", entity_name: str, word_length: int, @@ -325,16 +318,14 @@ def register_based_storage( } # - # Architecture declerative region begin + # Architecture declarative region begin # # Write architecture header - vhdl.write( - f, 0, f'architecture {architecture_name} of {entity_name} is', end='\n\n' - ) + write(f, 0, f'architecture {architecture_name} of {entity_name} is', end='\n\n') # Schedule time counter - vhdl.write(f, 1, '-- Schedule counter') - vhdl.common.signal_decl( + write(f, 1, '-- Schedule counter') + common.signal_declaration( f, name='schedule_cnt', signal_type=f'integer range 0 to {schedule_time}-1', @@ -343,13 +334,13 @@ def register_based_storage( ) # Shift register - vhdl.write(f, 1, '-- Shift register', start='\n') - vhdl.common.type_declaration( + write(f, 1, '-- Shift register', start='\n') + common.type_declaration( f, name='shift_reg_type', alias=f'array(0 to {reg_cnt}-1) of std_logic_vector(WL-1 downto 0)', ) - vhdl.common.signal_decl( + common.signal_declaration( f, name='shift_reg', signal_type='shift_reg_type', @@ -357,8 +348,8 @@ def register_based_storage( ) # Back edge mux decoder - vhdl.write(f, 1, '-- Back-edge mux select signal', start='\n') - vhdl.common.signal_decl( + write(f, 1, '-- Back-edge mux select signal', start='\n') + common.signal_declaration( f, name='back_edge_mux_sel', signal_type=f'integer range 0 to {len(back_edges)}', @@ -366,8 +357,8 @@ def register_based_storage( ) # Output mux selector - vhdl.write(f, 1, '-- Output mux select signal', start='\n') - vhdl.common.signal_decl( + write(f, 1, '-- Output mux select signal', start='\n') + common.signal_declaration( f, name='out_mux_sel', signal_type=f'integer range 0 to {len(output_regs) - 1}', @@ -377,14 +368,14 @@ def register_based_storage( # # Architecture body begin # - vhdl.write(f, 0, 'begin', start='\n', end='\n\n') - vhdl.write(f, 1, '-- Schedule counter') - vhdl.common.synchronous_process_prologue( + write(f, 0, 'begin', start='\n', end='\n\n') + write(f, 1, '-- Schedule counter') + common.synchronous_process_prologue( f=f, name='schedule_cnt_proc', clk='clk', ) - vhdl.write_lines( + write_lines( f, [ (4, 'if en = \'1\' then'), @@ -396,26 +387,26 @@ def register_based_storage( (4, 'end if;'), ], ) - vhdl.common.synchronous_process_epilogue( + common.synchronous_process_epilogue( f=f, name='schedule_cnt_proc', clk='clk', ) # Shift register back-edge decoding - vhdl.write(f, 1, '-- Shift register back-edge decoding', start='\n') - vhdl.common.synchronous_process_prologue( + write(f, 1, '-- Shift register back-edge decoding', start='\n') + common.synchronous_process_prologue( f, clk='clk', name='shift_reg_back_edge_decode_proc', ) - vhdl.write(f, 3, 'case schedule_cnt is') + write(f, 3, 'case schedule_cnt is') for time, entry in enumerate(forward_backward_table): if entry.back_edge_to: assert len(entry.back_edge_to) == 1 for src, dst in entry.back_edge_to.items(): mux_idx = back_edge_table[(src, dst)] - vhdl.write_lines( + write_lines( f, [ (4, f'when {(time-1)%schedule_time} =>'), @@ -423,7 +414,7 @@ def register_based_storage( (5, f'back_edge_mux_sel <= {mux_idx};'), ], ) - vhdl.write_lines( + write_lines( f, [ (4, 'when others =>'), @@ -431,26 +422,26 @@ def register_based_storage( (3, 'end case;'), ], ) - vhdl.common.synchronous_process_epilogue( + common.synchronous_process_epilogue( f, clk='clk', name='shift_reg_back_edge_decode_proc', ) # Shift register multiplexer logic - vhdl.write(f, 1, '-- Multiplexers for shift register', start='\n') - vhdl.common.synchronous_process_prologue( + write(f, 1, '-- Multiplexers for shift register', start='\n') + common.synchronous_process_prologue( f, clk='clk', name='shift_reg_proc', ) if sync_rst: - vhdl.write(f, 3, 'if rst = \'1\' then') + write(f, 3, 'if rst = \'1\' then') for reg_idx in range(reg_cnt): - vhdl.write(f, 4, f'shift_reg({reg_idx}) <= (others => \'0\');') - vhdl.write(f, 3, 'else') + write(f, 4, f'shift_reg({reg_idx}) <= (others => \'0\');') + write(f, 3, 'else') - vhdl.write_lines( + write_lines( f, [ (3, '-- Default case'), @@ -458,17 +449,17 @@ def register_based_storage( ], ) for reg_idx in range(1, reg_cnt): - vhdl.write(f, 3, f'shift_reg({reg_idx}) <= shift_reg({reg_idx-1});') - vhdl.write(f, 3, 'case back_edge_mux_sel is') + write(f, 3, f'shift_reg({reg_idx}) <= shift_reg({reg_idx-1});') + write(f, 3, 'case back_edge_mux_sel is') for edge, mux_sel in back_edge_table.items(): - vhdl.write_lines( + write_lines( f, [ (4, f'when {mux_sel} =>'), (5, f'shift_reg({edge[1]}) <= shift_reg({edge[0]});'), ], ) - vhdl.write_lines( + write_lines( f, [ (4, 'when others => null;'), @@ -477,45 +468,45 @@ def register_based_storage( ) if sync_rst: - vhdl.write(f, 3, 'end if;') + write(f, 3, 'end if;') - vhdl.common.synchronous_process_epilogue( + common.synchronous_process_epilogue( f, clk='clk', name='shift_reg_proc', ) # Output multiplexer decoding logic - vhdl.write(f, 1, '-- Output muliplexer decoding logic', start='\n') - vhdl.common.synchronous_process_prologue(f, clk='clk', name='out_mux_decode_proc') - vhdl.write(f, 3, 'case schedule_cnt is') + write(f, 1, '-- Output multiplexer decoding logic', start='\n') + common.synchronous_process_prologue(f, clk='clk', name='out_mux_decode_proc') + write(f, 3, 'case schedule_cnt is') for i, entry in enumerate(forward_backward_table): if entry.outputs_from is not None: sel = output_mux_table[entry.outputs_from] - vhdl.write(f, 4, f'when {(i-1)%schedule_time} =>') - vhdl.write(f, 5, f'out_mux_sel <= {sel};') - vhdl.write(f, 3, 'end case;') - vhdl.common.synchronous_process_epilogue(f, clk='clk', name='out_mux_decode_proc') + write(f, 4, f'when {(i-1)%schedule_time} =>') + write(f, 5, f'out_mux_sel <= {sel};') + write(f, 3, 'end case;') + common.synchronous_process_epilogue(f, clk='clk', name='out_mux_decode_proc') # Output multiplexer logic - vhdl.write(f, 1, '-- Output muliplexer', start='\n') - vhdl.common.synchronous_process_prologue( + write(f, 1, '-- Output multiplexer', start='\n') + common.synchronous_process_prologue( f, clk='clk', name='out_mux_proc', ) - vhdl.write(f, 3, 'case out_mux_sel is') + write(f, 3, 'case out_mux_sel is') for reg_i, mux_i in output_mux_table.items(): - vhdl.write(f, 4, f'when {mux_i} =>') + write(f, 4, f'when {mux_i} =>') if reg_i < 0: - vhdl.write(f, 5, f'p_0_out <= p_{-1-reg_i}_in;') + write(f, 5, f'p_0_out <= p_{-1-reg_i}_in;') else: - vhdl.write(f, 5, f'p_0_out <= shift_reg({reg_i});') - vhdl.write(f, 3, 'end case;') - vhdl.common.synchronous_process_epilogue( + write(f, 5, f'p_0_out <= shift_reg({reg_i});') + write(f, 3, 'end case;') + common.synchronous_process_epilogue( f, clk='clk', name='out_mux_proc', ) - vhdl.write(f, 0, f'end architecture {architecture_name};', start='\n') + write(f, 0, f'end architecture {architecture_name};', start='\n') diff --git a/b_asic/codegen/vhdl/common.py b/b_asic/codegen/vhdl/common.py index 5e51e20b9faa69379054ed4716cd7bfe63f6cca2..a1ac777c108c0ba2614ba8bb9846375a37bfa20c 100755 --- a/b_asic/codegen/vhdl/common.py +++ b/b_asic/codegen/vhdl/common.py @@ -2,21 +2,21 @@ Generation of common VHDL constructs """ +import re from datetime import datetime -from io import TextIOWrapper from subprocess import PIPE, Popen -from typing import Any, Optional, Set, Tuple +from typing import Any, Optional, Set, TextIO, Tuple -from b_asic.codegen import vhdl +from b_asic.codegen.vhdl import write, write_lines -def b_asic_preamble(f: TextIOWrapper): +def b_asic_preamble(f: TextIO): """ Write a standard BASIC VHDL preamble comment. Parameters ---------- - f : :class:`io.TextIOWrapper` + f : TextIO The file object to write the header to. """ # Try to acquire the current git commit hash @@ -26,7 +26,7 @@ def b_asic_preamble(f: TextIOWrapper): git_commit_id = process.communicate()[0].decode('utf-8').strip() except: # noqa: E722 pass - vhdl.write_lines( + write_lines( f, [ (0, '--'), @@ -35,8 +35,8 @@ def b_asic_preamble(f: TextIOWrapper): ], ) if git_commit_id: - vhdl.write(f, 0, f'-- B-ASIC short commit hash: {git_commit_id}') - vhdl.write_lines( + write(f, 0, f'-- B-ASIC short commit hash: {git_commit_id}') + write_lines( f, [ (0, '-- URL: https://gitlab.liu.se/da/B-ASIC'), @@ -46,7 +46,7 @@ def b_asic_preamble(f: TextIOWrapper): def ieee_header( - f: TextIOWrapper, + f: TextIO, std_logic_1164: bool = True, numeric_std: bool = True, ): @@ -56,23 +56,23 @@ def ieee_header( Parameters ---------- - f : :class:`io.TextIOWrapper` - The TextIOWrapper object to write the IEEE header to. + f : TextIO + The TextIO object to write the IEEE header to. std_logic_1164 : bool, default: True Include the std_logic_1164 header. numeric_std : bool, default: True Include the numeric_std header. """ - vhdl.write(f, 0, 'library ieee;') + write(f, 0, 'library ieee;') if std_logic_1164: - vhdl.write(f, 0, 'use ieee.std_logic_1164.all;') + write(f, 0, 'use ieee.std_logic_1164.all;') if numeric_std: - vhdl.write(f, 0, 'use ieee.numeric_std.all;') - vhdl.write(f, 0, '') + write(f, 0, 'use ieee.numeric_std.all;') + write(f, 0, '') -def signal_decl( - f: TextIOWrapper, +def signal_declaration( + f: TextIO, name: str, signal_type: str, default_value: Optional[str] = None, @@ -87,8 +87,8 @@ def signal_decl( Parameters ---------- - f : :class:`io.TextIOWrapper` - The TextIOWrapper object to write the IEEE header to. + f : TextIO + The TextIO object to write the IEEE header to. name : str Signal name. signal_type : str @@ -108,12 +108,12 @@ def signal_decl( """ # Spacing of VHDL signals declaration always with a single tab name_pad = name_pad or 0 - vhdl.write(f, 1, f'signal {name:<{name_pad}} : {signal_type}', end='') + write(f, 1, f'signal {name:<{name_pad}} : {signal_type}', end='') if default_value is not None: - vhdl.write(f, 0, f' := {default_value}', end='') - vhdl.write(f, 0, ';') + write(f, 0, f' := {default_value}', end='') + write(f, 0, ';') if vivado_ram_style is not None: - vhdl.write_lines( + write_lines( f, [ (1, 'attribute ram_style : string;'), @@ -121,7 +121,7 @@ def signal_decl( ], ) if quartus_ram_style is not None: - vhdl.write_lines( + write_lines( f, [ (1, 'attribute ramstyle : string;'), @@ -131,7 +131,7 @@ def signal_decl( def constant_declaration( - f: TextIOWrapper, + f: TextIO, name: str, signal_type: str, value: Any, @@ -143,8 +143,8 @@ def constant_declaration( Parameters ---------- - f : :class:`io.TextIOWrapper` - The TextIOWrapper object to write the constant declaration to. + f : TextIO + The TextIO object to write the constant declaration to. name : str Signal name. signal_type : str @@ -155,11 +155,11 @@ def constant_declaration( An optional left padding value applied to the name. """ name_pad = 0 if name_pad is None else name_pad - vhdl.write(f, 1, f'constant {name:<{name_pad}} : {signal_type} := {str(value)};') + write(f, 1, f'constant {name:<{name_pad}} : {signal_type} := {str(value)};') def type_declaration( - f: TextIOWrapper, + f: TextIO, name: str, alias: str, ): @@ -168,18 +168,18 @@ def type_declaration( Parameters ---------- - f : :class:`io.TextIOWrapper` - The TextIOWrapper object to write the type declaration to. + f : TextIO + The TextIO object to write the type declaration to. name : str Type name alias. alias : str The type to tie the new name to. """ - vhdl.write(f, 1, f'type {name} is {alias};') + write(f, 1, f'type {name} is {alias};') def process_prologue( - f: TextIOWrapper, + f: TextIO, sensitivity_list: str, indent: int = 1, name: Optional[str] = None, @@ -191,8 +191,8 @@ def process_prologue( Parameters ---------- - f : :class:`io.TextIOWrapper` - The TextIOWrapper object to write the type declaration to. + f : TextIO + The TextIO object to write the type declaration to. sensitivity_list : str Content of the process sensitivity list. indent : int, default: 1 @@ -201,14 +201,14 @@ def process_prologue( An optional name for the process. """ if name is not None: - vhdl.write(f, indent, f'{name}: process({sensitivity_list})') + write(f, indent, f'{name}: process({sensitivity_list})') else: - vhdl.write(f, indent, f'process({sensitivity_list})') - vhdl.write(f, indent, 'begin') + write(f, indent, f'process({sensitivity_list})') + write(f, indent, 'begin') def process_epilogue( - f: TextIOWrapper, + f: TextIO, sensitivity_list: Optional[str] = None, indent: int = 1, name: Optional[str] = None, @@ -216,8 +216,8 @@ def process_epilogue( """ Parameters ---------- - f : :class:`io.TextIOWrapper` - The TextIOWrapper object to write the type declaration to. + f : TextIO + The TextIO object to write the type declaration to. sensitivity_list : str Content of the process sensitivity list. Not needed when writing the epilogue. indent : int, default: 1 @@ -228,14 +228,14 @@ def process_epilogue( An optional name of the ending process. """ _ = sensitivity_list - vhdl.write(f, indent, 'end process', end="") + write(f, indent, 'end process', end="") if name is not None: - vhdl.write(f, 0, ' ' + name, end="") - vhdl.write(f, 0, ';') + write(f, 0, ' ' + name, end="") + write(f, 0, ';') def synchronous_process_prologue( - f: TextIOWrapper, + f: TextIO, clk: str, indent: int = 1, name: Optional[str] = None, @@ -250,8 +250,8 @@ def synchronous_process_prologue( Parameters ---------- - f : :class:`io.TextIOWrapper` - The TextIOWrapper to write the VHDL code onto. + f : TextIO + The TextIO to write the VHDL code onto. clk : str Name of the clock. indent : int, default: 1 @@ -260,11 +260,11 @@ def synchronous_process_prologue( An optional name for the process. """ process_prologue(f, sensitivity_list=clk, indent=indent, name=name) - vhdl.write(f, indent + 1, 'if rising_edge(clk) then') + write(f, indent + 1, 'if rising_edge(clk) then') def synchronous_process_epilogue( - f: TextIOWrapper, + f: TextIO, clk: Optional[str], indent: int = 1, name: Optional[str] = None, @@ -277,8 +277,8 @@ def synchronous_process_epilogue( Parameters ---------- - f : :class:`io.TextIOWrapper` - The TextIOWrapper to write the VHDL code onto. + f : TextIO + The TextIO to write the VHDL code onto. clk : str Name of the clock. indent : int, default: 1 @@ -287,12 +287,12 @@ def synchronous_process_epilogue( An optional name for the process """ _ = clk - vhdl.write(f, indent + 1, 'end if;') + write(f, indent + 1, 'end if;') process_epilogue(f, sensitivity_list=clk, indent=indent, name=name) def synchronous_process( - f: TextIOWrapper, + f: TextIO, clk: str, body: str, indent: int = 1, @@ -306,8 +306,8 @@ def synchronous_process( Parameters ---------- - f : :class:`io.TextIOWrapper` - The TextIOWrapper to write the VHDL code onto. + f : TextIO + The TextIO to write the VHDL code onto. clk : str Name of the clock. body : str @@ -320,12 +320,12 @@ def synchronous_process( synchronous_process_prologue(f, clk, indent, name) for line in body.split('\n'): if len(line): - vhdl.write(f, indent + 2, f'{line}') + write(f, indent + 2, f'{line}') synchronous_process_epilogue(f, clk, indent, name) def synchronous_memory( - f: TextIOWrapper, + f: TextIO, clk: str, read_ports: Set[Tuple[str, str, str]], write_ports: Set[Tuple[str, str, str]], @@ -336,8 +336,8 @@ def synchronous_memory( Parameters ---------- - f : :class:`io.TextIOWrapper` - The TextIOWrapper to write the VHDL code onto. + f : TextIO + The TextIO to write the VHDL code onto. clk : str Name of clock identifier to the synchronous memory. read_ports : Set[Tuple[str,str]] @@ -350,17 +350,17 @@ def synchronous_memory( assert len(read_ports) >= 1 assert len(write_ports) >= 1 synchronous_process_prologue(f, clk=clk, name=name) - for read_name, address, re in read_ports: - vhdl.write_lines( + for read_name, address, read_enable in read_ports: + write_lines( f, [ - (3, f'if {re} = \'1\' then'), + (3, f'if {read_enable} = \'1\' then'), (4, f'{read_name} <= memory({address});'), (3, 'end if;'), ], ) for write_name, address, we in write_ports: - vhdl.write_lines( + write_lines( f, [ (3, f'if {we} = \'1\' then'), @@ -372,7 +372,7 @@ def synchronous_memory( def asynchronous_read_memory( - f: TextIOWrapper, + f: TextIO, clk: str, read_ports: Set[Tuple[str, str, str]], write_ports: Set[Tuple[str, str, str]], @@ -383,8 +383,8 @@ def asynchronous_read_memory( Parameters ---------- - f : :class:`io.TextIOWrapper` - The TextIOWrapper to write the VHDL code onto. + f : TextIO + The TextIO to write the VHDL code onto. clk : str Name of clock identifier to the synchronous memory. read_ports : Set[Tuple[str,str]] @@ -398,7 +398,7 @@ def asynchronous_read_memory( assert len(write_ports) >= 1 synchronous_process_prologue(f, clk=clk, name=name) for write_name, address, we in write_ports: - vhdl.write_lines( + write_lines( f, [ (3, f'if {we} = \'1\' then'), @@ -408,4 +408,166 @@ def asynchronous_read_memory( ) synchronous_process_epilogue(f, clk=clk, name=name) for read_name, address, _ in read_ports: - vhdl.write(f, 1, f'{read_name} <= memory({address});') + write(f, 1, f'{read_name} <= memory({address});') + + +def is_valid_vhdl_identifier(identifier: str) -> bool: + """ + Test if identifier is a valid VHDL identifier, as specified by VHDL 2019. + + An identifier is a valid VHDL identifier if it is not a VHDL reserved keyword and + it is a valid basic identifier as specified by IEEE STD 1076-2019 (VHDL standard). + + Parameters + ---------- + identifier : str + The identifier to test. + + Returns + ------- + Returns True if identifier is a valid VHDL identifier, False otherwise. + """ + # IEEE STD 1076-2019: + # Sec. 15.4.2, Basic identifiers: + # * A basic identifier consists only of letters, digits, and underlines. + # * A basic identifier is not a reserved VHDL keyword + is_basic_identifier = ( + re.fullmatch(pattern=r'[a-zA-Z][0-9a-zA-Z_]*', string=identifier) is not None + ) + return is_basic_identifier and not is_vhdl_reserved_keyword(identifier) + + +def is_vhdl_reserved_keyword(identifier: str) -> bool: + """ + Test if identifier is a reserved VHDL keyword. + + Parameters + ---------- + identifier : str + The identifier to test. + + Returns + ------- + Returns True if identifier is reserved, False otherwise. + """ + # List of reserved keyword in IEEE STD 1076-2019. + # Sec. 15.10, Reserved words: + reserved_keywords = ( + "abs", + "access", + "after", + "alias", + "all", + "and", + "architecture", + "array", + "assert", + "assume", + "attribute", + "begin", + "block", + "body", + "buffer", + "bus", + "case", + "component", + "configuration", + "constant", + "context", + "cover", + "default", + "disconnect", + "downto", + "else", + "elsif", + "end", + "entity", + "exit", + "fairness", + "file", + "for", + "force", + "function", + "generate", + "generic", + "group", + "guarded", + "if", + "impure", + "in", + "inertial", + "inout", + "is", + "label", + "library", + "linkage", + "literal", + "loop", + "map", + "mod", + "nand", + "new", + "next", + "nor", + "not", + "null", + "of", + "on", + "open", + "or", + "others", + "out", + "package", + "parameter", + "port", + "postponed", + "procedure", + "process", + "property", + "protected", + "private", + "pure", + "range", + "record", + "register", + "reject", + "release", + "rem", + "report", + "restrict", + "return", + "rol", + "ror", + "select", + "sequence", + "severity", + "signal", + "shared", + "sla", + "sll", + "sra", + "srl", + "strong", + "subtype", + "then", + "to", + "transport", + "type", + "unaffected", + "units", + "until", + "use", + "variable", + "view", + "vpkg", + "vmode", + "vprop", + "vunit", + "wait", + "when", + "while", + "with", + "xnor", + "xor", + ) + return identifier.lower() in reserved_keywords diff --git a/b_asic/codegen/vhdl/entity.py b/b_asic/codegen/vhdl/entity.py index 9673baac6c4a88023b51ca1d97cd70de66dd5902..f13d1a1777b65ee70d11dd962146c4292603784a 100755 --- a/b_asic/codegen/vhdl/entity.py +++ b/b_asic/codegen/vhdl/entity.py @@ -1,18 +1,16 @@ """ Module for code generation of VHDL entity declarations """ -from io import TextIOWrapper -from typing import Set +from typing import Set, TextIO -from b_asic.codegen import vhdl -from b_asic.codegen.vhdl import VHDL_TAB +from b_asic.codegen.vhdl import VHDL_TAB, write_lines from b_asic.port import Port from b_asic.process import MemoryVariable, PlainMemoryVariable from b_asic.resources import ProcessCollection def memory_based_storage( - f: TextIOWrapper, entity_name: str, collection: ProcessCollection, word_length: int + f: TextIO, entity_name: str, collection: ProcessCollection, word_length: int ): # Check that this is a ProcessCollection of (Plain)MemoryVariables is_memory_variable = all( @@ -29,7 +27,7 @@ def memory_based_storage( entity_name = entity_name # Write the entity header - vhdl.write_lines( + write_lines( f, [ (0, f'entity {entity_name} is'), @@ -42,7 +40,7 @@ def memory_based_storage( ) # Write the clock and reset signal - vhdl.write_lines( + write_lines( f, [ (0, '-- Clock, synchronous reset and enable signals'), @@ -56,7 +54,7 @@ def memory_based_storage( # Write the input port specification f.write(f'{2*VHDL_TAB}-- Memory port I/O\n') read_ports: set[Port] = set( - sum((mv.read_ports for mv in collection), ()) + read_port for mv in collection for read_port in mv.read_ports ) # type: ignore for idx, read_port in enumerate(read_ports): port_name = read_port if isinstance(read_port, int) else read_port.name @@ -80,6 +78,6 @@ def memory_based_storage( def register_based_storage( - f: TextIOWrapper, entity_name: str, collection: ProcessCollection, word_length: int + f: TextIO, entity_name: str, collection: ProcessCollection, word_length: int ): memory_based_storage(f, entity_name, collection, word_length) diff --git a/b_asic/gui_utils/icons.py b/b_asic/gui_utils/icons.py index 8ba60f4641ea0c3c26d68f69c34364081c407fcb..85fa6c818eafe920ef5e771d37a38f3086d96dc2 100755 --- a/b_asic/gui_utils/icons.py +++ b/b_asic/gui_utils/icons.py @@ -31,8 +31,8 @@ ICONS = { 'reorder': ('msc.graph-left', {'rotated': -90}), 'full-screen': 'mdi6.fullscreen', 'full-screen-exit': 'mdi6.fullscreen-exit', - 'warning': 'fa.warning', - 'port-numbers': 'fa.hashtag', + 'warning': 'fa5s.exclamation-triangle', + 'port-numbers': 'fa5s.hashtag', } diff --git a/b_asic/process.py b/b_asic/process.py index 626d04d4f30fa56b9e414c8fefbb90dc5f6c5e47..bf0c2ef3c9c3df50919bdf989e8990d5ba4c2c7a 100755 --- a/b_asic/process.py +++ b/b_asic/process.py @@ -1,6 +1,6 @@ """B-ASIC classes representing resource usage.""" -from typing import Dict, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, cast from b_asic.operation import Operation from b_asic.port import InputPort, OutputPort @@ -105,7 +105,149 @@ class OperatorProcess(Process): return f"OperatorProcess({self.start_time}, {self.operation}, {self.name!r})" -class MemoryVariable(Process): +class MemoryProcess(Process): + """ + Intermediate class (abstract) for memory processes. + + Different from regular :class:`Processe` objects, :class:`MemoryProcess` objects + can contain multiple read accesses and can be split into two new + :class:`MemoryProcess` objects based on these read times. + + Parameters + ---------- + write_time : int + Start time of process. + life_times : list of int + List of ints representing times after ``start_time`` this process is accessed. + name : str, default="" + Name of the process. + """ + + def __init__( + self, + write_time: int, + life_times: List[int], + name: str = "", + ): + pass + self._life_times = life_times + super().__init__( + start_time=write_time, + execution_time=max(self._life_times), + name=name, + ) + + @property + def read_times(self) -> List[int]: + return list(self.start_time + read for read in self._life_times) + + @property + def life_times(self) -> List[int]: + return self._life_times + + @property + def reads(self) -> Dict[Any, int]: + raise NotImplementedError("MultiReadProcess should be derived from") + + @property + def read_ports(self) -> List[Any]: + raise NotImplementedError("MultiReadProcess should be derived from") + + @property + def write_port(self) -> Any: + raise NotImplementedError("MultiReadProcess should be derived from") + + def split_on_length( + self, + length: int = 0, + ) -> Tuple[Optional["MemoryProcess"], Optional["MemoryProcess"]]: + """ + Split this :class:`MemoryProcess` into two new :class:`MemoryProcess` objects, + based on lifetimes of the read accesses. + + Parameters + ---------- + length : int, default: 0 + The lifetime length to split on. Length is inclusive for the smaller + process. + + Returns + ------- + Two-tuple where the first element is a :class:`MemoryProcess` consisting + of reads with read times smaller than or equal to ``length`` (or None if no such + reads exists), and vice-versa for the other tuple element. + """ + reads = self.reads + short_reads = {k: v for k, v in filter(lambda t: t[1] <= length, reads.items())} + long_reads = {k: v for k, v in filter(lambda t: t[1] > length, reads.items())} + short_process = None + long_process = None + if short_reads: + # Create a new Process of type self (which is a derived variant of + # MultiReadProcess) by calling the self constructor + short_process = type(self)( + self.start_time, # type: ignore + self.write_port, # type: ignore + short_reads, # type: ignore + self.name, # type: ignore + ) + if long_reads: + # Create a new Process of type self (which is a derived variant of + # MultiReadProcess) by calling the self constructor + long_process = type(self)( + self.start_time, # type: ignore + self.write_port, # type: ignore + long_reads, # type: ignore + self.name, # type: ignore + ) + return short_process, long_process + + def _add_life_time(self, life_time: int): + """ + Add a lifetime to this :class:`~b_asic.process.MultiReadProcess` set of + lifetimes. + + If the lifetime specified by ``life_time`` is already in this + :class:`~b_asic.process.MultiReadProcess`, nothing happens + + After adding a lifetime from this :class:`~b_asic.process.MultiReadProcess`, + the execution time is re-evaluated. + + Parameters + ---------- + life_time : int + The lifetime to add to this :class:`~b_asic.process.MultiReadProcess`. + """ + if life_time not in self.life_times: + self._life_times.append(life_time) + self._execution_time = max(self.life_times) + + def _remove_life_time(self, life_time: int): + """ + Remove a lifetime from this :class:`~b_asic.process.MultiReadProcess` + set of lifetimes. + + After removing a lifetime from this :class:`~b_asic.process.MultiReadProcess`, + the execution time is re-evaluated. + + Raises :class:`KeyError` if the specified lifetime is not a lifetime of this + :class:`~b_asic.process.MultiReadProcess`. + + Parameters + ---------- + life_time : int + The lifetime to remove from this :class:`~b_asic.process.MultiReadProcess`. + """ + if life_time not in self.life_times: + raise KeyError( + f"Process {self.name}: {life_time} not in life_times: {self.life_times}" + ) + else: + self._life_times.remove(life_time) + self._execution_time = max(self.life_times) + + +class MemoryVariable(MemoryProcess): """ Object that corresponds to a memory variable. @@ -130,13 +272,12 @@ class MemoryVariable(Process): reads: Dict[InputPort, int], name: Optional[str] = None, ): - self._read_ports = tuple(reads.keys()) - self._life_times = tuple(reads.values()) + self._read_ports = list(reads.keys()) self._reads = reads self._write_port = write_port super().__init__( - start_time=write_time, - execution_time=max(self._life_times), + write_time=write_time, + life_times=list(reads.values()), name=name, ) @@ -145,11 +286,7 @@ class MemoryVariable(Process): return self._reads @property - def life_times(self) -> Tuple[int, ...]: - return self._life_times - - @property - def read_ports(self) -> Tuple[InputPort, ...]: + def read_ports(self) -> List[InputPort]: return self._read_ports @property @@ -163,12 +300,36 @@ class MemoryVariable(Process): f" {reads!r}, {self.name!r})" ) - @property - def read_times(self) -> Tuple[int, ...]: - return tuple(self.start_time + read for read in self._life_times) + def split_on_length( + self, + length: int = 0, + ) -> Tuple[Optional["MemoryVariable"], Optional["MemoryVariable"]]: + """ + Split this :class:`MemoryVariable` into two new :class:`MemoryVariable` objects, + based on lifetimes of read accesses. + + Parameters + ---------- + length : int, default: 0 + The lifetime length to split on. Length is inclusive for the smaller + process. + + Returns + ------- + Two-tuple where the first element is a :class:`MemoryVariable` consisting + of reads with read times smaller than or equal to ``length`` (or None if no such + reads exists), and vice-versa for the other tuple element. + """ + # This method exists only for documentation purposes and for generating correct + # type annotations when calling it. Just call super().split_on_length() in here. + short_process, long_process = super().split_on_length(length) + return ( + cast(Optional["MemoryVariable"], short_process), + cast(Optional["MemoryVariable"], long_process), + ) -class PlainMemoryVariable(Process): +class PlainMemoryVariable(MemoryProcess): """ Object that corresponds to a memory variable which only use numbers for ports. @@ -196,8 +357,7 @@ class PlainMemoryVariable(Process): reads: Dict[int, int], name: Optional[str] = None, ): - self._read_ports = tuple(reads.keys()) - self._life_times = tuple(reads.values()) + self._read_ports = list(reads.keys()) self._write_port = write_port self._reads = reads if name is None: @@ -205,8 +365,8 @@ class PlainMemoryVariable(Process): PlainMemoryVariable._name_cnt += 1 super().__init__( - start_time=write_time, - execution_time=max(self._life_times), + write_time=write_time, + life_times=list(reads.values()), name=name, ) @@ -215,11 +375,7 @@ class PlainMemoryVariable(Process): return self._reads @property - def life_times(self) -> Tuple[int, ...]: - return self._life_times - - @property - def read_ports(self) -> Tuple[int, ...]: + def read_ports(self) -> List[int]: return self._read_ports @property @@ -233,9 +389,33 @@ class PlainMemoryVariable(Process): f" {reads!r}, {self.name!r})" ) - @property - def read_times(self) -> Tuple[int, ...]: - return tuple(self.start_time + read for read in self._life_times) + def split_on_length( + self, + length: int = 0, + ) -> Tuple[Optional["PlainMemoryVariable"], Optional["PlainMemoryVariable"]]: + """ + Split this :class:`PlainMemoryVariable` into two new + :class:`PlainMemoryVariable` objects, based on lifetimes of read accesses. + + Parameters + ---------- + length : int, default: 0 + The lifetime length to split on. Length is inclusive for the smaller + process. + + Returns + ------- + Two-tuple where the first element is a :class:`PlainMemoryVariable` consisting + of reads with read times smaller than or equal to ``length`` (or None if no such + reads exists), and vice-versa for the other tuple element. + """ + # This method exists only for documentation purposes and for generating correct + # type annotations when calling it. Just call super().split_on_length() in here. + short_process, long_process = super().split_on_length(length) + return ( + cast(Optional["PlainMemoryVariable"], short_process), + cast(Optional["PlainMemoryVariable"], long_process), + ) # Static counter for default names _name_cnt = 0 diff --git a/b_asic/quantization.py b/b_asic/quantization.py index e491bca301d10dddff0be862571d7b7369c62683..5de9a51720da2340247b255aa5aadbab546bbe28 100755 --- a/b_asic/quantization.py +++ b/b_asic/quantization.py @@ -19,11 +19,14 @@ class Quantization(Enum): "Magnitude truncation, i.e., round towards zero." JAMMING = 4 - "Jamming/von Neumann rounding, i.e., set the LSB to one" + "Jamming/von Neumann rounding, i.e., set the LSB to one." UNBIASED_ROUNDING = 5 "Unbiased rounding, i.e., tie rounds towards even." + UNBIASED_JAMMING = 6 + "Unbiased jamming/von Neumann rounding." + class Overflow(Enum): """Overflow types.""" @@ -33,7 +36,7 @@ class Overflow(Enum): SATURATION = 2 """ - Two's complement saturation, i.e., overflow return the most postive/negative + Two's complement saturation, i.e., overflow return the most positive/negative number. """ @@ -125,8 +128,13 @@ def quantize( v = math.ceil(v) elif quantization is Quantization.JAMMING: v = math.floor(v) | 1 - else: # Quantization.UNBIASED_ROUNDING + elif quantization is Quantization.UNBIASED_ROUNDING: v = round(v) + elif quantization is Quantization.UNBIASED_JAMMING: + f = math.floor(v) + v = f if v - f == 0 else f | 1 + else: + raise TypeError("Unknown quantization method: {quantization!r}") v = v / b i = 2 ** (integer_bits - 1) diff --git a/b_asic/resources.py b/b_asic/resources.py index e6f27bb44dacbbc1aa87458a8d6cecc2ede4f419..ef1204ab6032a3819b772e09bcda052df8996f36 100755 --- a/b_asic/resources.py +++ b/b_asic/resources.py @@ -2,7 +2,7 @@ import io import re from collections import Counter from functools import reduce -from typing import Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union +from typing import Dict, Iterable, List, Optional, Tuple, TypeVar, Union import matplotlib.pyplot as plt import networkx as nx @@ -10,7 +10,14 @@ from matplotlib.axes import Axes from matplotlib.ticker import MaxNLocator from b_asic._preferences import LATENCY_COLOR, WARNING_COLOR -from b_asic.process import MemoryVariable, OperatorProcess, PlainMemoryVariable, Process +from b_asic.codegen.vhdl.common import is_valid_vhdl_identifier +from b_asic.process import ( + MemoryProcess, + MemoryVariable, + OperatorProcess, + PlainMemoryVariable, + Process, +) from b_asic.types import TypeName # Default latency coloring RGB tuple @@ -83,7 +90,7 @@ def _sanitize_port_option( raise ValueError( f'Total ports ({total_ports}) less then write ports ({write_ports})' ) - return (read_ports, write_ports, total_ports) + return read_ports, write_ports, total_ports def draw_exclusion_graph_coloring( @@ -93,21 +100,30 @@ def draw_exclusion_graph_coloring( color_list: Optional[Union[List[str], List[Tuple[float, float, float]]]] = None, ) -> None: """ - Draw a colored exclusion graph from the memory assignment. + Helper function for drawing a colored exclusion graphs. + + Example usage: .. code-block:: python - _, ax = plt.subplots(1, 1) + import networkx as nx + import matplotlib.pyplot as plt + + _, ax = plt.subplots() collection = ProcessCollection(...) - exclusion_graph = collection.create_exclusion_graph_from_overlap() - color_dict = nx.greedy_color(exclusion_graph) - draw_exclusion_graph_coloring(exclusion_graph, color_dict, ax=ax[0]) + exclusion_graph = collection.create_exclusion_graph_from_ports( + read_ports = 1, + write_ports = 1, + total_ports = 2, + ) + coloring = nx.greedy_color(exclusion_graph) + draw_exclusion_graph_coloring(exclusion_graph, coloring, ax=ax) plt.show() Parameters ---------- - exclusion_graph : nx.Graph - A nx.Graph exclusion graph object that is to be drawn. + exclusion_graph : :class:`networkx.Graph` + The :class:`networkx.Graph` exclusion graph object that is to be drawn. color_dict : dict A dict where keys are :class:`~b_asic.process.Process` objects and values are integers representing colors. These dictionaries are automatically generated by @@ -208,12 +224,12 @@ class _ForwardBackwardTable: ProcessCollection to apply forward-backward allocation on """ # Generate an alive variable list - self._collection = collection - self._live_variables: List[int] = [0] * collection._schedule_time + self._collection = set(collection.collection) + self._live_variables: List[int] = [0] * collection.schedule_time for mv in self._collection: stop_time = mv.start_time + mv.execution_time for alive_time in range(mv.start_time, stop_time): - self._live_variables[alive_time % collection._schedule_time] += 1 + self._live_variables[alive_time % collection.schedule_time] += 1 # First, create an empty forward-backward table with the right dimensions self.table: List[_ForwardBackwardEntry] = [] @@ -250,7 +266,7 @@ class _ForwardBackwardTable: def _forward_backward_is_complete(self) -> bool: s = {proc for e in self.table for proc in e.outputs} - return len(self._collection._collection - s) == 0 + return len(self._collection - s) == 0 def _do_forward_allocation(self): """ @@ -412,31 +428,35 @@ class _ForwardBackwardTable: class ProcessCollection: - """ - Collection of one or more processes + r""" + Collection of :class:`~b_asic.process.Process` objects. Parameters ---------- - collection : set of :class:`~b_asic.process.Process` objects - The Process objects forming this ProcessCollection. + collection : Iterable of :class:`~b_asic.process.Process` objects + The :class:`~b_asic.process.Process` objects forming this + :class:`~b_asic.resources.ProcessCollection`. schedule_time : int - Length of the time-axis in the generated graph. + The scheduling time associated with this + :class:`~b_asic.resources.ProcessCollection`. cyclic : bool, default: False - If the processes operates cyclically, i.e., if time 0 == time *schedule_time*. + Whether the processes operates cyclically, i.e., if time + + .. math:: t = t \bmod T_{\textrm{schedule}}. """ def __init__( self, - collection: Set[Process], + collection: Iterable[Process], schedule_time: int, cyclic: bool = False, ): - self._collection = collection + self._collection = list(collection) self._schedule_time = schedule_time self._cyclic = cyclic @property - def collection(self) -> Set[Process]: + def collection(self) -> List[Process]: return self._collection @property @@ -448,25 +468,30 @@ class ProcessCollection: def add_process(self, process: Process): """ - Add a new process to this process collection. + Add a new :class:`~b_asic.process.Process` to this + :class:`~b_asic.resources.ProcessCollection`. Parameters ---------- - process : Process - The process object to be added to the collection. + process : :class:`~b_asic.process.Process` + The :class:`~b_asic.process.Process` object to add. """ - self.collection.add(process) + if process in self.collection: + raise ValueError("Process already in ProcessCollection") + self.collection.append(process) def remove_process(self, process: Process): """ - Remove a processes from this process collection. + Remove a :class:`~b_asic.process.Process` from this + :class:`~b_asic.resources.ProcessCollection`. - Raises KeyError if the process is not in this collection. + Raises :class:`KeyError` if the specified :class:`~b_asic.process.Process` is + not in this collection. Parameters ---------- - process : Process - The processes object to remove from this collection + process : :class:`~b_asic.process.Process` + The :class:`~b_asic.process.Process` object to remove from this collection. """ if process not in self.collection: raise KeyError( @@ -489,7 +514,16 @@ class ProcessCollection: allow_excessive_lifetimes: bool = False, ): """ - Plot a process variable lifetime chart. + Plot all :class:`~b_asic.process.Process` objects of this + :class:`~b_asic.resources.ProcessCollection` in a lifetime diagram. + + If the ``ax`` parameter is not specified, a new Matplotlib figure is created. + + Raises :class:`KeyError` if any :class:`~b_asic.process.Process` lifetime + excedes this :class:`~b_asic.resources.ProcessCollection` schedule time, + unless ``allow_excessive_lifetimes`` is specified. In that case, + :class:`~b_asic.process.Process` objects whose lifetime exceed the scheudle + time are drawn using the B-ASIC warning color. Parameters ---------- @@ -617,8 +651,12 @@ class ProcessCollection: ) _ax.grid(True) # type: ignore - _ax.xaxis.set_major_locator(MaxNLocator(integer=True)) # type: ignore - _ax.yaxis.set_major_locator(MaxNLocator(integer=True)) # type: ignore + _ax.xaxis.set_major_locator( + MaxNLocator(integer=True, min_n_ticks=1) + ) # type: ignore + _ax.yaxis.set_major_locator( + MaxNLocator(integer=True, min_n_ticks=1) + ) # type: ignore _ax.set_xlim(0, self._schedule_time) # type: ignore if row is None: _ax.set_ylim(0.25, len(self._collection) + 0.75) # type: ignore @@ -639,7 +677,8 @@ class ProcessCollection: title: Optional[str] = None, ) -> None: """ - Show the process collection using the current Matplotlib backend. + Display this :class:`~b_asic.resources.ProcessCollection` using the current + Matplotlib backend. Equivalent to creating a Matplotlib figure, passing it and arguments to :meth:`plot` and invoking :py:meth:`matplotlib.figure.Figure.show`. @@ -686,7 +725,9 @@ class ProcessCollection: total_ports: Optional[int] = None, ) -> nx.Graph: """ - Create an exclusion graph based on a number of read/write ports. + Create an exclusion graph from a given number of read and write ports based on + concurrent read and write accesses to this + :class:`~b_asic.resources.ProcessCollection`. Parameters ---------- @@ -702,7 +743,7 @@ class ProcessCollection: Returns ------- - nx.Graph + A :class:`networkx.Graph` object. """ read_ports, write_ports, total_ports = _sanitize_port_option( @@ -725,9 +766,10 @@ class ProcessCollection: exclusion_graph = nx.Graph() exclusion_graph.add_nodes_from(self._collection) for node1 in exclusion_graph: - node1_stop_times = tuple( + node1_stop_times = set( read_time % self.schedule_time for read_time in node1.read_times ) + node1_start_time = node1.start_time % self.schedule_time if total_ports == 1 and node1.start_time in node1_stop_times: print(node1.start_time, node1_stop_times) raise ValueError("Cannot read and write in same cycle.") @@ -738,34 +780,27 @@ class ProcessCollection: node2_stop_times = tuple( read_time % self.schedule_time for read_time in node2.read_times ) - for node1_stop_time in node1_stop_times: - for node2_stop_time in node2_stop_times: - if total_ports == 1: - # Single-port assignment - if node1.start_time == node2.start_time: - exclusion_graph.add_edge(node1, node2) - elif node1_stop_time == node2_stop_time: - exclusion_graph.add_edge(node1, node2) - elif node1.start_time == node2_stop_time: - exclusion_graph.add_edge(node1, node2) - elif node1_stop_time == node2.start_time: - exclusion_graph.add_edge(node1, node2) - else: - # Dual-port assignment - if node1.start_time == node2.start_time: - exclusion_graph.add_edge(node1, node2) - elif node1_stop_time == node2_stop_time: - exclusion_graph.add_edge(node1, node2) + node2_start_time = node2.start_time % self.schedule_time + if write_ports == 1 and node1_start_time == node2_start_time: + exclusion_graph.add_edge(node1, node2) + if read_ports == 1 and node1_stop_times.intersection( + node2_stop_times + ): + exclusion_graph.add_edge(node1, node2) + if total_ports == 1 and ( + node1_start_time in node2_stop_times + or node2_start_time in node1_stop_times + ): + exclusion_graph.add_edge(node1, node2) return exclusion_graph def create_exclusion_graph_from_execution_time(self) -> nx.Graph: """ - Generate exclusion graph based on processes overlapping in time + Create an exclusion graph from processes overlapping in execution time. Returns ------- - An nx.Graph exclusion graph where nodes are processes and arcs - between two processes indicated overlap in time + A :class:`networkx.Graph` object. """ exclusion_graph = nx.Graph() exclusion_graph.add_nodes_from(self._collection) @@ -818,7 +853,7 @@ class ProcessCollection: self, heuristic: str = "graph_color", coloring_strategy: str = "saturation_largest_first", - ) -> Set["ProcessCollection"]: + ) -> List["ProcessCollection"]: """ Split a ProcessCollection based on overlapping execution time. @@ -843,14 +878,10 @@ class ProcessCollection: Returns ------- - A set of new ProcessCollection objects with the process splitting. + A list of new ProcessCollection objects with the process splitting. """ if heuristic == "graph_color": - exclusion_graph = self.create_exclusion_graph_from_execution_time() - coloring = nx.coloring.greedy_color( - exclusion_graph, strategy=coloring_strategy - ) - return self._split_from_graph_coloring(coloring) + return self._graph_color_assignment(coloring_strategy) elif heuristic == "left_edge": return self._left_edge_assignment() else: @@ -862,7 +893,7 @@ class ProcessCollection: read_ports: Optional[int] = None, write_ports: Optional[int] = None, total_ports: Optional[int] = None, - ) -> Set["ProcessCollection"]: + ) -> List["ProcessCollection"]: """ Split this process storage based on concurrent read/write times according. @@ -907,7 +938,7 @@ class ProcessCollection: write_ports: int, total_ports: int, coloring_strategy: str = "saturation_largest_first", - ) -> Set["ProcessCollection"]: + ) -> List["ProcessCollection"]: """ Parameters ---------- @@ -944,7 +975,7 @@ class ProcessCollection: def _split_from_graph_coloring( self, coloring: Dict[Process, int], - ) -> Set["ProcessCollection"]: + ) -> List["ProcessCollection"]: """ Split :class:`Process` objects into a set of :class:`ProcessesCollection` objects based on a provided graph coloring. @@ -961,13 +992,13 @@ class ProcessCollection: ------- A set of new ProcessCollections. """ - process_collection_set_list = [set() for _ in range(max(coloring.values()) + 1)] + process_collection_set_list = [[] for _ in range(max(coloring.values()) + 1)] for process, color in coloring.items(): - process_collection_set_list[color].add(process) - return { + process_collection_set_list[color].append(process) + return [ ProcessCollection(process_collection_set, self._schedule_time, self._cyclic) for process_collection_set in process_collection_set_list - } + ] def _repr_svg_(self) -> str: """ @@ -980,6 +1011,9 @@ class ProcessCollection: fig.savefig(f, format="svg") # type: ignore return f.getvalue() + # SVG is valid HTML. This is useful for e.g. sphinx-gallery + _repr_html_ = _repr_svg_ + def __repr__(self): return ( f"ProcessCollection({self._collection}, {self._schedule_time}," @@ -1016,6 +1050,13 @@ class ProcessCollection: List[ProcessCollection] """ + for process in self: + if process.execution_time > self.schedule_time: + # Can not assign process to any cell + raise ValueError( + f"{process} has execution time greater than the schedule time" + ) + cell_assignment: Dict[int, ProcessCollection] = dict() exclusion_graph = self.create_exclusion_graph_from_execution_time() if coloring is None: @@ -1024,7 +1065,7 @@ class ProcessCollection: ) for process, cell in coloring.items(): if cell not in cell_assignment: - cell_assignment[cell] = ProcessCollection(set(), self._schedule_time) + cell_assignment[cell] = ProcessCollection([], self._schedule_time) cell_assignment[cell].add_process(process) return list(cell_assignment.values()) @@ -1036,47 +1077,71 @@ class ProcessCollection: Two or more processes can share a single resource if, and only if, they have no overlaping execution time. + Raises :class:`ValueError` if any process in this collection has an execution + time which is greater than the collection schedule time. + Returns ------- List[ProcessCollection] """ - next_empty_cell = 0 - cell_assignment: Dict[int, ProcessCollection] = dict() + assignment: List[ProcessCollection] = [] for next_process in sorted(self): - insert_to_new_cell = True - for cell in cell_assignment: - insert_to_this_cell = True - for process in cell_assignment[cell]: - next_process_stop_time = ( - next_process.start_time + next_process.execution_time - ) % self._schedule_time - if ( - next_process.start_time - < process.start_time + process.execution_time - or next_process.start_time - > next_process_stop_time - > process.start_time - ): - insert_to_this_cell = False - break - if insert_to_this_cell: - cell_assignment[cell].add_process(next_process) - insert_to_new_cell = False - break - if insert_to_new_cell: - cell_assignment[next_empty_cell] = ProcessCollection( - collection=set(), schedule_time=self._schedule_time + if next_process.execution_time > self.schedule_time: + # Can not assign process to any cell + raise ValueError( + f"{next_process} has execution time greater than the schedule time" ) - cell_assignment[next_empty_cell].add_process(next_process) - next_empty_cell += 1 - return [pc for pc in cell_assignment.values()] + elif next_process.execution_time == self.schedule_time: + # Always assign maximum lifetime process to new cell + assignment.append( + ProcessCollection( + (next_process,), + schedule_time=self.schedule_time, + cyclic=self._cyclic, + ) + ) + continue # Continue assigning next process + else: + next_process_stop_time = ( + next_process.start_time + next_process.execution_time + ) % self._schedule_time + insert_to_new_cell = True + for cell_assignment in assignment: + insert_to_this_cell = True + for process in cell_assignment: + # The next_process start_time is always greater than or equal to + # the start time of all other assigned processes + process_end_time = process.start_time + process.execution_time + if next_process.start_time < process_end_time: + insert_to_this_cell = False + break + if ( + next_process.start_time + > next_process_stop_time + > process.start_time + ): + insert_to_this_cell = False + break + if insert_to_this_cell: + cell_assignment.add_process(next_process) + insert_to_new_cell = False + break + if insert_to_new_cell: + assignment.append( + ProcessCollection( + (next_process,), + schedule_time=self.schedule_time, + cyclic=self._cyclic, + ) + ) + return assignment def generate_memory_based_storage_vhdl( self, filename: str, entity_name: str, word_length: int, - assignment: Set['ProcessCollection'], + assignment: List['ProcessCollection'], read_ports: int = 1, write_ports: int = 1, total_ports: int = 2, @@ -1093,7 +1158,7 @@ class ProcessCollection: Name used for the VHDL entity. word_length : int Word length of the memory variable objects. - assignment : set + assignment : list A possible cell assignment to use when generating the memory based storage. The cell assignment is a dictionary int to ProcessCollection where the integer corresponds to the cell to assign all MemoryVariables in @@ -1116,6 +1181,10 @@ class ProcessCollection: (which is added automatically). For large interleavers, this can improve timing significantly. """ + # Check that entity name is a valid VHDL identifier + if not is_valid_vhdl_identifier(entity_name): + raise KeyError(f'{entity_name} is not a valid identifier') + # Check that this is a ProcessCollection of (Plain)MemoryVariables is_memory_variable = all( isinstance(process, MemoryVariable) for process in self._collection @@ -1134,7 +1203,7 @@ class ProcessCollection: read_ports, write_ports, total_ports ) - # Make sure the provided assignment (Set[ProcessCollection]) only + # Make sure the provided assignment (List[ProcessCollection]) only # contains memory variables from this (self). for collection in assignment: for mv in collection: @@ -1202,16 +1271,26 @@ class ProcessCollection: A tuple of two ProcessCollections, one with shorter than or equal execution times and one with longer execution times. """ - short = set() - long = set() + short = [] + long = [] for process in self.collection: if process.execution_time <= length: - short.add(process) + short.append(process) else: - long.add(process) - return ProcessCollection( - short, schedule_time=self.schedule_time - ), ProcessCollection(long, schedule_time=self.schedule_time) + if isinstance(process, MemoryProcess): + # Split this MultiReadProcess into two new processes + p_short, p_long = process.split_on_length(length) + if p_short is not None: + short.append(p_short) + if p_long is not None: + long.append(p_long) + else: + # Not a MultiReadProcess: has only a single read + long.append(process) + return ( + ProcessCollection(short, self.schedule_time, self._cyclic), + ProcessCollection(long, self.schedule_time, self._cyclic), + ) def generate_register_based_storage_vhdl( self, @@ -1227,7 +1306,7 @@ class ProcessCollection: Forward-Backward Register Allocation [1]. [1]: K. Parhi: VLSI Digital Signal Processing Systems: Design and - Implementation, Ch. 6.3.2 + Implementation, Ch. 6.3.2 Parameters ---------- @@ -1249,6 +1328,10 @@ class ProcessCollection: The total number of ports used when splitting process collection based on memory variable access. """ + # Check that entity name is a valid VHDL identifier + if not is_valid_vhdl_identifier(entity_name): + raise KeyError(f'{entity_name} is not a valid identifier') + # Check that this is a ProcessCollection of (Plain)MemoryVariables is_memory_variable = all( isinstance(process, MemoryVariable) for process in self._collection @@ -1271,14 +1354,14 @@ class ProcessCollection: forward_backward_table = _ForwardBackwardTable(self) with open(filename, 'w') as f: - from b_asic.codegen import vhdl + from b_asic.codegen.vhdl import architecture, common, entity - vhdl.common.b_asic_preamble(f) - vhdl.common.ieee_header(f) - vhdl.entity.register_based_storage( + common.b_asic_preamble(f) + common.ieee_header(f) + entity.register_based_storage( f, entity_name=entity_name, collection=self, word_length=word_length ) - vhdl.architecture.register_based_storage( + architecture.register_based_storage( f, forward_backward_table=forward_backward_table, entity_name=entity_name, @@ -1290,16 +1373,17 @@ class ProcessCollection: def get_by_type_name(self, type_name: TypeName) -> "ProcessCollection": """ - Return a ProcessCollection with only a given type of operation. + Return a new :class:`~b_asic.resources.ProcessCollection` with only a given + type of operation. Parameters ---------- type_name : TypeName - The type_name of the operation. + The TypeName of the operation to extract. Returns ------- - ProcessCollection + A new :class:`~b_asic.resources.ProcessCollection`. """ return ProcessCollection( @@ -1314,6 +1398,14 @@ class ProcessCollection: ) def read_ports_bound(self) -> int: + """ + Get the read port lower-bound (maximum number of concurrent reads) of this + :class:`~b_asic.resources.ProcessCollection`. + + Returns + ------- + int + """ reads = [] for process in self._collection: reads.extend( @@ -1323,6 +1415,32 @@ class ProcessCollection: return max(count.values()) def write_ports_bound(self) -> int: + """ + Get the write port lower-bound (maximum number of concurrent writes) of this + :class:`~b_asic.resources.ProcessCollection`. + + Returns + ------- + int + """ writes = [process.start_time for process in self._collection] count = Counter(writes) return max(count.values()) + + def from_name(self, name: str): + """ + Get a :class:`~b_asic.process.Process` from this collection from its name. + + Raises :class:`KeyError` if no processes with ``name`` is found in this + colleciton. + + Parameters + ---------- + name : str + The name of the process to retrieve. + """ + name_to_proc = {p.name: p for p in self.collection} + if name not in name_to_proc: + raise KeyError(f'{name} not in {self}') + else: + return name_to_proc[name] diff --git a/b_asic/schedule.py b/b_asic/schedule.py index 61b5bd847f0605c2921c063bb137a5fd3c537a1e..db668d10c300123e86666ed0f04181207afe4b12 100755 --- a/b_asic/schedule.py +++ b/b_asic/schedule.py @@ -7,7 +7,7 @@ Contains the schedule class for scheduling operations in an SFG. import io import sys from collections import defaultdict -from typing import Dict, List, Optional, Sequence, Tuple, cast +from typing import Dict, List, Optional, Sequence, Tuple, Union, cast import matplotlib.pyplot as plt import numpy as np @@ -35,11 +35,18 @@ from b_asic.process import MemoryVariable, OperatorProcess from b_asic.resources import ProcessCollection from b_asic.signal_flow_graph import SFG from b_asic.special_operations import Delay, Input, Output +from b_asic.types import TypeName # Need RGB from 0 to 1 -_EXECUTION_TIME_COLOR = tuple(c / 255 for c in EXECUTION_TIME_COLOR) -_LATENCY_COLOR = tuple(c / 255 for c in LATENCY_COLOR) -_SIGNAL_COLOR = tuple(c / 255 for c in SIGNAL_COLOR) +_EXECUTION_TIME_COLOR: Union[ + Tuple[float, float, float], Tuple[float, float, float, float] +] = tuple(float(c / 255) for c in EXECUTION_TIME_COLOR) +_LATENCY_COLOR: Union[ + Tuple[float, float, float], Tuple[float, float, float, float] +] = tuple(float(c / 255) for c in LATENCY_COLOR) +_SIGNAL_COLOR: Union[ + Tuple[float, float, float], Tuple[float, float, float, float] +] = tuple(float(c / 255) for c in SIGNAL_COLOR) def _laps_default(): @@ -337,10 +344,10 @@ class Schedule: -------- get_max_time """ - if time < self.get_max_end_time(): + max_end_time = self.get_max_end_time() + if time < max_end_time: raise ValueError( - f"New schedule time ({time}) too short, minimum:" - f" {self.get_max_end_time()}." + f"New schedule time ({time}) too short, minimum: {max_end_time}." ) self._schedule_time = time return self @@ -442,11 +449,11 @@ class Schedule: def get_possible_time_resolution_decrements(self) -> List[int]: """Return a list with possible factors to reduce time resolution.""" vals = self._get_all_times() - maxloop = min(val for val in vals if val) - if maxloop <= 1: + max_loop = min(val for val in vals if val) + if max_loop <= 1: return [1] ret = [1] - for candidate in range(2, maxloop + 1): + for candidate in range(2, max_loop + 1): if not any(val % candidate for val in vals): ret.append(candidate) return ret @@ -478,6 +485,23 @@ class Schedule: self._schedule_time = self._schedule_time // factor return self + def set_execution_time_of_type( + self, type_name: TypeName, execution_time: int + ) -> None: + """ + Set the execution time of all operations with the given type name. + + Parameters + ---------- + type_name : TypeName + The type name of the operation. For example, obtained as + ``Addition.type_name()``. + execution_time : int + The execution time of the operation. + """ + self._sfg.set_execution_time_of_type(type_name, execution_time) + self._original_sfg.set_execution_time_of_type(type_name, execution_time) + def move_y_location( self, graph_id: GraphID, new_y: int, insert: bool = False ) -> None: @@ -557,7 +581,6 @@ class Schedule: time : int The time to move. If positive move forward, if negative move backward. """ - print(f"schedule.move_operation({graph_id!r}, {time})") if graph_id not in self._start_times: raise ValueError(f"No operation with graph_id {graph_id} in schedule") @@ -683,7 +706,6 @@ class Schedule: ) source_port = inport.signals[0].source - source_end_time = None if source_port.operation.graph_id in non_schedulable_ops: source_end_time = 0 else: @@ -741,7 +763,8 @@ class Schedule: } ret.append( MemoryVariable( - start_time + cast(int, outport.latency_offset), + (start_time + cast(int, outport.latency_offset)) + % self.schedule_time, outport, reads, outport.name, @@ -775,13 +798,19 @@ class Schedule: """ return ProcessCollection( { - OperatorProcess(start_time, self._sfg.find_by_id(graph_id)) + OperatorProcess( + start_time, cast(Operation, self._sfg.find_by_id(graph_id)) + ) for graph_id, start_time in self._start_times.items() }, self.schedule_time, self.cyclic, ) + def get_used_type_names(self) -> List[TypeName]: + """Get a list of all TypeNames used in the Schedule.""" + return self._sfg.get_used_type_names() + def _get_y_position( self, graph_id, operation_height=1.0, operation_gap=None ) -> float: @@ -965,7 +994,7 @@ class Schedule: + (OPERATION_GAP if operation_gap is None else operation_gap) ) ax.axis([-1, self._schedule_time + 1, y_position_max, 0]) # Inverted y-axis - ax.xaxis.set_major_locator(MaxNLocator(integer=True)) + ax.xaxis.set_major_locator(MaxNLocator(integer=True, min_n_ticks=1)) ax.axvline( 0, linestyle="--", @@ -1043,3 +1072,6 @@ class Schedule: fig.savefig(buffer, format="svg") return buffer.getvalue() + + # SVG is valid HTML. This is useful for e.g. sphinx-gallery + _repr_html_ = _repr_svg_ diff --git a/b_asic/scheduler_gui/main_window.py b/b_asic/scheduler_gui/main_window.py index a182ad2ee6ee1a52f7292862442c52892ce340c4..db83d4256001d9f1f69e2ed0bd540e680fd6d543 100755 --- a/b_asic/scheduler_gui/main_window.py +++ b/b_asic/scheduler_gui/main_window.py @@ -52,6 +52,7 @@ from b_asic._version import __version__ from b_asic.graph_component import GraphComponent, GraphID from b_asic.gui_utils.about_window import AboutWindow from b_asic.gui_utils.icons import get_icon +from b_asic.gui_utils.mpl_window import MPLWindow from b_asic.schedule import Schedule from b_asic.scheduler_gui.axes_item import AxesItem from b_asic.scheduler_gui.operation_item import OperationItem @@ -122,7 +123,7 @@ class ScheduleMainWindow(QMainWindow, Ui_MainWindow): self._file_name = None self._show_incorrect_execution_time = True self._show_port_numbers = True - + self._execution_time_for_variables = None # Recent files self._max_recent_files = 4 self._recent_files_actions: List[QAction] = [] @@ -161,6 +162,9 @@ class ScheduleMainWindow(QMainWindow, Ui_MainWindow): self.action_show_port_numbers.triggered.connect(self._toggle_port_number) self.actionPlot_schedule.setIcon(get_icon('plot-schedule')) self.actionPlot_schedule.triggered.connect(self._plot_schedule) + self.action_view_variables.triggered.connect( + self._show_execution_times_for_variables + ) self.actionZoom_to_fit.setIcon(get_icon('zoom-to-fit')) self.actionZoom_to_fit.triggered.connect(self._zoom_to_fit) self.actionToggle_full_screen.setIcon(get_icon('full-screen')) @@ -408,6 +412,8 @@ class ScheduleMainWindow(QMainWindow, Ui_MainWindow): self.info_table_clear() self.update_statusbar("Closed schedule") self._toggle_file_loaded(False) + self.action_view_variables.setEnabled(False) + self.menu_view_execution_times.setEnabled(False) @Slot() def save(self) -> None: @@ -577,7 +583,7 @@ class ScheduleMainWindow(QMainWindow, Ui_MainWindow): 'hide_exit_dialog' in settings. """ settings = QSettings() - hide_dialog = settings.value("scheduler/hide_exit_dialog", False, bool) + hide_dialog = settings.value("scheduler/hide_exit_dialog", True, bool) ret = QMessageBox.StandardButton.Yes if not hide_dialog: @@ -643,6 +649,8 @@ class ScheduleMainWindow(QMainWindow, Ui_MainWindow): self._graph._signals.redraw_all.connect(self._redraw_all) self._graph._signals.reopen.connect(self._reopen_schedule) self.info_table_fill_schedule(self._schedule) + self._update_operation_types() + self.action_view_variables.setEnabled(True) self.update_statusbar(self.tr("Schedule loaded successfully")) def _redraw_all(self) -> None: @@ -802,6 +810,29 @@ class ScheduleMainWindow(QMainWindow, Ui_MainWindow): self._update_recent_file_list() + def _update_operation_types(self): + self.menu_view_execution_times.setEnabled(True) + for action in self.menu_view_execution_times.actions(): + self.menu_view_execution_times.removeAction(action) + for type_name in self._schedule.get_used_type_names(): + type_action = QAction(self.menu_view_execution_times) + type_action.setText(type_name) + type_action.triggered.connect( + lambda b=0, x=type_name: self._show_execution_times_for_type(x) + ) + self.menu_view_execution_times.addAction(type_action) + + def _show_execution_times_for_type(self, type_name): + self._graph._execution_time_plot(type_name) + + def _show_execution_times_for_variables(self): + print("Show") + self._execution_time_for_variables = MPLWindow("Execution times for variables") + self._schedule.get_memory_variables().plot( + self._execution_time_for_variables.axes, allow_excessive_lifetimes=True + ) + self._execution_time_for_variables.show() + def _update_recent_file_list(self): settings = QSettings() diff --git a/b_asic/scheduler_gui/main_window.ui b/b_asic/scheduler_gui/main_window.ui index 4bd2dd0c52ec4cacd2c4588b587ade4adae80b26..8f0dab650b80909045a73e1c788c7b7844469728 100755 --- a/b_asic/scheduler_gui/main_window.ui +++ b/b_asic/scheduler_gui/main_window.ui @@ -1,6 +1,6 @@ <?xml version="1.0" encoding="UTF-8"?> <ui version="4.0"> - <author>Andreas Bolin</author> + <author>Andreas Bolin and Oscar Gustafsson</author> <class>MainWindow</class> <widget class="QMainWindow" name="MainWindow"> <property name="geometry"> @@ -228,6 +228,14 @@ <property name="title"> <string>&View</string> </property> + <widget class="QMenu" name="menu_view_execution_times"> + <property name="title"> + <string>View execution times of type</string> + </property> + <property name="enabled"> + <bool>false</bool> + </property> + </widget> <addaction name="menu_node_info"/> <addaction name="actionToolbar"/> <addaction name="actionStatus_bar"/> @@ -235,6 +243,8 @@ <addaction name="action_show_port_numbers"/> <addaction name="separator"/> <addaction name="actionPlot_schedule"/> + <addaction name="action_view_variables"/> + <addaction name="menu_view_execution_times"/> <addaction name="separator"/> <addaction name="actionZoom_to_fit"/> <addaction name="actionToggle_full_screen"/> @@ -385,6 +395,9 @@ <property name="checkable"> <bool>true</bool> </property> + <property name="checked"> + <bool>true</bool> + </property> <property name="icon"> <iconset theme="view-close"> <normaloff>../../../.designer/backup</normaloff>../../../.designer/backup</iconset> @@ -440,6 +453,17 @@ <string>Plot schedule</string> </property> </action> + <action name="action_view_variables"> + <property name="text"> + <string>View execution times of variables</string> + </property> + <property name="enabled"> + <bool>false</bool> + </property> + <property name="toolTip"> + <string>View all variables</string> + </property> + </action> <action name="actionUndo"> <property name="enabled"> <bool>false</bool> diff --git a/b_asic/scheduler_gui/scheduler_event.py b/b_asic/scheduler_gui/scheduler_event.py index b7fce2575050268b12ca444fbf00eec0121e9879..7610d510a65358367c4d261849ff333cafd6dfef 100755 --- a/b_asic/scheduler_gui/scheduler_event.py +++ b/b_asic/scheduler_gui/scheduler_event.py @@ -223,6 +223,10 @@ class SchedulerEvent: # PyQt5 math.ceil(pos_y), (pos_y % 1) != 0, ) + print( + f"schedule.move_y_location({item.operation.graph_id!r}," + f" {math.ceil(pos_y)}, {(pos_y % 1) != 0})" + ) self._signals.redraw_all.emit() # Operation has been moved in x-direction if redraw: diff --git a/b_asic/scheduler_gui/scheduler_item.py b/b_asic/scheduler_gui/scheduler_item.py index 74051251c82df1ecdbc8d37e53f9a895a868975b..66986c0a4fcd7bbbc41ee0649d860eb2d74341e4 100755 --- a/b_asic/scheduler_gui/scheduler_item.py +++ b/b_asic/scheduler_gui/scheduler_item.py @@ -139,10 +139,7 @@ class SchedulerItem(SchedulerEvent, QGraphicsItemGroup): # PySide2 / PyQt5 def _redraw_all_lines(self) -> None: """Redraw all lines in schedule.""" - s = set() - for signals in self._signal_dict.values(): - s.update(signals) - for signal in s: + for signal in self._get_all_signals(): signal.update_path() def _redraw_lines(self, item: OperationItem) -> None: @@ -150,6 +147,12 @@ class SchedulerItem(SchedulerEvent, QGraphicsItemGroup): # PySide2 / PyQt5 for signal in self._signal_dict[item]: signal.update_path() + def _get_all_signals(self): + s = set() + for signals in self._signal_dict.values(): + s.update(signals) + return s + def set_warnings(self, warnings: bool = True): """ Set warnings for long execution times. @@ -162,10 +165,7 @@ class SchedulerItem(SchedulerEvent, QGraphicsItemGroup): # PySide2 / PyQt5 """ if warnings != self._warnings: self._warnings = warnings - s = set() - for signals in self._signal_dict.values(): - s.update(signals) - for signal in s: + for signal in self._get_all_signals(): signal.set_inactive() def set_port_numbers(self, port_numbers: bool = True): @@ -227,6 +227,7 @@ class SchedulerItem(SchedulerEvent, QGraphicsItemGroup): # PySide2 / PyQt5 move_time = new_start_time - op_start_time if move_time: self.schedule.move_operation(item.graph_id, move_time) + print(f"schedule.move_operation({item.graph_id!r}, {move_time})") def is_valid_delta_time(self, delta_time: int) -> bool: """ diff --git a/b_asic/scheduler_gui/ui_main_window.py b/b_asic/scheduler_gui/ui_main_window.py index 1409e7fc1f82cb7bce01c77b8282698eb9e3e088..15e768a0df4caf332a1d69aba524425bc66fd2d8 100755 --- a/b_asic/scheduler_gui/ui_main_window.py +++ b/b_asic/scheduler_gui/ui_main_window.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# Form implementation generated from reading ui file './main_window.ui' +# Form implementation generated from reading ui file '.\main_window.ui' # # Created by: PyQt5 UI code generator 5.15.7 # @@ -135,6 +135,9 @@ class Ui_MainWindow(object): self.menu_Recent_Schedule.setObjectName("menu_Recent_Schedule") self.menuView = QtWidgets.QMenu(self.menubar) self.menuView.setObjectName("menuView") + self.menu_view_execution_times = QtWidgets.QMenu(self.menuView) + self.menu_view_execution_times.setEnabled(False) + self.menu_view_execution_times.setObjectName("menu_view_execution_times") self.menu_Edit = QtWidgets.QMenu(self.menubar) self.menu_Edit.setObjectName("menu_Edit") self.menuWindow = QtWidgets.QMenu(self.menubar) @@ -186,6 +189,7 @@ class Ui_MainWindow(object): self.menu_save_as.setObjectName("menu_save_as") self.menu_exit_dialog = QtWidgets.QAction(MainWindow) self.menu_exit_dialog.setCheckable(True) + self.menu_exit_dialog.setChecked(True) icon = QtGui.QIcon.fromTheme("view-close") self.menu_exit_dialog.setIcon(icon) self.menu_exit_dialog.setObjectName("menu_exit_dialog") @@ -202,6 +206,9 @@ class Ui_MainWindow(object): self.actionReorder.setObjectName("actionReorder") self.actionPlot_schedule = QtWidgets.QAction(MainWindow) self.actionPlot_schedule.setObjectName("actionPlot_schedule") + self.action_view_variables = QtWidgets.QAction(MainWindow) + self.action_view_variables.setEnabled(False) + self.action_view_variables.setObjectName("action_view_variables") self.actionUndo = QtWidgets.QAction(MainWindow) self.actionUndo.setEnabled(False) self.actionUndo.setObjectName("actionUndo") @@ -259,6 +266,8 @@ class Ui_MainWindow(object): self.menuView.addAction(self.action_show_port_numbers) self.menuView.addSeparator() self.menuView.addAction(self.actionPlot_schedule) + self.menuView.addAction(self.action_view_variables) + self.menuView.addAction(self.menu_view_execution_times.menuAction()) self.menuView.addSeparator() self.menuView.addAction(self.actionZoom_to_fit) self.menuView.addAction(self.actionToggle_full_screen) @@ -312,6 +321,9 @@ class Ui_MainWindow(object): self.menuFile.setTitle(_translate("MainWindow", "&File")) self.menu_Recent_Schedule.setTitle(_translate("MainWindow", "Open &recent")) self.menuView.setTitle(_translate("MainWindow", "&View")) + self.menu_view_execution_times.setTitle( + _translate("MainWindow", "View execution times of type") + ) self.menu_Edit.setTitle(_translate("MainWindow", "&Edit")) self.menuWindow.setTitle(_translate("MainWindow", "&Window")) self.menuHelp.setTitle(_translate("MainWindow", "&Help")) @@ -352,6 +364,12 @@ class Ui_MainWindow(object): ) self.actionPlot_schedule.setText(_translate("MainWindow", "&Plot schedule")) self.actionPlot_schedule.setToolTip(_translate("MainWindow", "Plot schedule")) + self.action_view_variables.setText( + _translate("MainWindow", "View execution times of variables") + ) + self.action_view_variables.setToolTip( + _translate("MainWindow", "View all variables") + ) self.actionUndo.setText(_translate("MainWindow", "Undo")) self.actionUndo.setShortcut(_translate("MainWindow", "Ctrl+Z")) self.actionRedo.setText(_translate("MainWindow", "Redo")) diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 5869e38a41da772917866c43d23a44e2e6f7ea36..7bcb273ea369611eec98e3568c59b8f18efb7518 100755 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -1393,7 +1393,7 @@ class SFG(AbstractOperation): self, show_id: bool = False, engine: Optional[str] = None, - branch_node: bool = False, + branch_node: bool = True, port_numbering: bool = True, splines: str = "spline", ) -> Digraph: @@ -1409,7 +1409,7 @@ class SFG(AbstractOperation): engine : string, optional Graphviz layout engine to be used, see https://graphviz.org/documentation/. Most common are "dot" and "neato". Default is None leading to dot. - branch_node : bool, default: False + branch_node : bool, default: True Add a branch node in case the fan-out of a signal is two or more. port_numbering : bool, default: True Show the port number in case the number of ports (input or output) is two or @@ -1490,12 +1490,20 @@ class SFG(AbstractOperation): def _repr_png_(self): return self.sfg_digraph()._repr_mimebundle_(include=["image/png"])["image/png"] + def _repr_svg_(self): + return self.sfg_digraph()._repr_mimebundle_(include=["image/svg+xml"])[ + "image/svg+xml" + ] + + # SVG is valid HTML. This is useful for e.g. sphinx-gallery + _repr_html_ = _repr_svg_ + def show( self, fmt: Optional[str] = None, show_id: bool = False, engine: Optional[str] = None, - branch_node: bool = False, + branch_node: bool = True, port_numbering: bool = True, splines: str = "spline", ) -> None: @@ -1514,7 +1522,7 @@ class SFG(AbstractOperation): engine : string, optional Graphviz layout engine to be used, see https://graphviz.org/documentation/. Most common are "dot" and "neato". Default is None leading to dot. - branch_node : bool, default: False + branch_node : bool, default: True Add a branch node in case the fan-out of a signal is two or more. port_numbering : bool, default: True Show the port number in case the number of ports (input or output) is two or @@ -1716,3 +1724,9 @@ class SFG(AbstractOperation): @property def is_constant(self) -> bool: return all(output.is_constant for output in self._output_operations) + + def get_used_type_names(self) -> List[TypeName]: + """Get a list of all TypeNames used in the SFG.""" + ret = list({op.type_name() for op in self.operations}) + ret.sort() + return ret diff --git a/docs_sphinx/conf.py b/docs_sphinx/conf.py index aae9c64e8b3dc5f825b6c1b58cb396b9a8d3634d..74075c277ced34395c6dfe7dc8f39c29d750eadc 100755 --- a/docs_sphinx/conf.py +++ b/docs_sphinx/conf.py @@ -5,6 +5,7 @@ # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + import shutil project = 'B-ASIC' @@ -24,7 +25,7 @@ extensions = [ 'sphinx.ext.intersphinx', 'sphinx_gallery.gen_gallery', 'numpydoc', # Needs to be loaded *after* autodoc. - 'jupyter_sphinx', + 'sphinx_copybutton', ] templates_path = ['_templates'] @@ -66,4 +67,12 @@ sphinx_gallery_conf = { 'filename_pattern': '.', 'doc_module': ('b_asic',), 'reference_url': {'b_asic': None}, + 'image_scrapers': ( + # qtgallery.qtscraper, + 'matplotlib', + ), + 'reset_modules': ( + # qtgallery.reset_qapp, + 'matplotlib', + ), } diff --git a/examples/connectmultiplesfgs.py b/examples/connectmultiplesfgs.py index 2a0786010f3efa9b90592362a1320f8eb6c1bbf1..4ab897b43ab189b941002282720cd69f52a54568 100755 --- a/examples/connectmultiplesfgs.py +++ b/examples/connectmultiplesfgs.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- """ ======================== Connecting multiple SFGs @@ -14,46 +12,40 @@ SFGs but the operations of these. To do this, one will have to use the method :func:`~b_asic.signal_flow_graph.SFG.connect_external_signals_to_components`. This example illustrates how it can be done. +""" -.. jupyter-execute:: - - from b_asic.sfg_generators import wdf_allpass - from b_asic.signal_flow_graph import SFG - from b_asic.special_operations import Input, Output - - # Generate allpass branches for fifth-ordet LWDF filter - allpass1 = wdf_allpass([0.2, 0.5]) - allpass2 = wdf_allpass([-0.5, 0.2, 0.5]) - - in_lwdf = Input() - allpass1 << in_lwdf - allpass2 << in_lwdf - out_lwdf = Output((allpass1 + allpass2) * 0.5) - - # Create SFG of LWDF with two internal SFGs - sfg_with_sfgs = SFG( - [in_lwdf], [out_lwdf], name="LWDF with separate internals SFGs for allpass branches" - ) - -The resulting SFG looks like: +from b_asic.sfg_generators import wdf_allpass +from b_asic.signal_flow_graph import SFG +from b_asic.special_operations import Input, Output -.. jupyter-execute:: +# Generate allpass branches for fifth-ordet LWDF filter +allpass1 = wdf_allpass([0.2, 0.5]) +allpass2 = wdf_allpass([-0.5, 0.2, 0.5]) - sfg_with_sfgs +in_lwdf = Input() +allpass1 << in_lwdf +allpass2 << in_lwdf +out_lwdf = Output((allpass1 + allpass2) * 0.5) +# Create SFG of LWDF with two internal SFGs +sfg_with_sfgs = SFG( + [in_lwdf], [out_lwdf], name="LWDF with separate internals SFGs for allpass branches" +) -Now, to create a LWDF where the SFGs are flattened. Note that the original SFGs -``allpass1`` and ``allpass2`` currently cannot be printed etc after this operation. +# %% +# The resulting SFG looks like: -.. jupyter-execute:: +sfg_with_sfgs - allpass1.connect_external_signals_to_components() - allpass2.connect_external_signals_to_components() - flattened_sfg = SFG([in_lwdf], [out_lwdf], name="Flattened LWDF") +# %% +# Now, to create a LWDF where the SFGs are flattened. Note that the original SFGs +# ``allpass1`` and ``allpass2`` currently cannot be printed etc after this operation. -Resulting in: +allpass1.connect_external_signals_to_components() +allpass2.connect_external_signals_to_components() +flattened_sfg = SFG([in_lwdf], [out_lwdf], name="Flattened LWDF") -.. jupyter-execute:: +# %% +# Resulting in: - flattened_sfg -""" +flattened_sfg diff --git a/examples/firstorderiirfilter.py b/examples/firstorderiirfilter.py index 16baebebe0353cb4205e4ce3afb1b209f9cebe00..9f02c0910d71d85af677006e00c7543a2ae4c115 100755 --- a/examples/firstorderiirfilter.py +++ b/examples/firstorderiirfilter.py @@ -7,11 +7,12 @@ In this example, a direct form first-order IIR filter is designed. First, we need to import the operations that will be used in the example: """ -from b_asic.core_operations import Addition, ConstantMultiplication +from b_asic.core_operations import ConstantMultiplication from b_asic.special_operations import Delay, Input, Output # %% -# Then, we continue by defining the input and delay element, which we can optionally name. +# Then, we continue by defining the input and delay element, which we can optionally +# name. input = Input(name="My input") delay = Delay(name="The only delay") @@ -27,15 +28,17 @@ a1 = ConstantMultiplication(0.5, delay) first_addition = a1 + input # %% -# Or by creating them, but connecting the input later. Each operation has a function :func:`~b_asic.operation.Operation.input` -# that is used to access a specific input (or output, by using :func:`~b_asic.operation.Operation.output`). +# Or by creating them, but connecting the input later. Each operation has a function +# :func:`~b_asic.operation.Operation.input`that is used to access a specific input +# (or output, by using :func:`~b_asic.operation.Operation.output`). b1 = ConstantMultiplication(0.7) b1.input(0).connect(delay) # %% -# The latter is useful when there is not a single order to create the signal flow graph, e.g., for recursive algorithms. -# In this example, we could not connect the output of the delay as that was not yet available. +# The latter is useful when there is not a single order to create the signal flow +# graph, e.g., for recursive algorithms. In this example, we could not connect the +# output of the delay as that was not yet available. # # There is also a shorthand form to connect signals using the ``<<`` operator: @@ -47,47 +50,27 @@ delay << first_addition output = Output(b1 + first_addition) # %% -# Now, we should create a signal flow graph, but first it must be imported (normally, this should go at the top of the file). +# Now, we should create a signal flow graph, but first it must be imported (normally, +# this should go at the top of the file). -from b_asic.signal_flow_graph import SFG +from b_asic.signal_flow_graph import SFG # noqa: E402 # %% -# The signal flow graph is defined by its inputs and outputs, so these must be provided. As, in general, there can be -# multiple inputs and outputs, there should be provided as a list or a tuple. +# The signal flow graph is defined by its inputs and outputs, so these must be +# provided. As, in general, there can be multiple inputs and outputs, there should +# be provided as a list or a tuple. firstorderiir = SFG([input], [output]) # %% -# If this is executed in an enriched terminal, such as a Jupyter Notebook, Jupyter QtConsole, or Spyder, just typing -# the variable name will return a graphical representation of the signal flow graph. +# If this is executed in an enriched terminal, such as a Jupyter Notebook, Jupyter +# QtConsole, or Spyder, just typing the variable name will return a graphical +# representation of the signal flow graph. firstorderiir # %% -# This will look something like -# -# .. graphviz:: -# -# digraph { -# rankdir=LR -# in1 [shape=cds] -# in1 -> add1 -# out1 [shape=cds] -# add2 -> out1 -# add1 [shape=ellipse] -# cmul1 -> add1 -# cmul1 [shape=ellipse] -# add1 -> t1 -# t1 [shape=square] -# add1 -> add2 -# add2 [shape=ellipse] -# cmul2 -> add2 -# cmul2 [shape=ellipse] -# t1 -> cmul2 -# t1 -> cmul1 -# } -# -# For now, we can print the precendence relations of the SFG +# For now, we can print the precedence relations of the SFG firstorderiir.print_precedence_graph() # %% @@ -145,21 +128,23 @@ firstorderiir.print_precedence_graph() # add2 [label=add2 shape=ellipse] # } # -# As seen, each operation has an id, in addition to the optional name. This can be used to access the operation. -# For example, +# As seen, each operation has an id, in addition to the optional name. +# This can be used to access the operation. For example, firstorderiir.find_by_id('cmul1') # %% -# Note that this operation differs from ``a1`` defined above as the operations are copied and recreated once inserted -# into a signal flow graph. +# Note that this operation differs from ``a1`` defined above as the operations are +# copied and recreated once inserted into a signal flow graph. # -# The signal flow graph can also be simulated. For this, we must import :class:`.Simulation`. +# The signal flow graph can also be simulated. For this, we must import +# :class:`.Simulation`. -from b_asic.simulation import Simulation +from b_asic.simulation import Simulation # noqa: E402 # %% -# The :class:`.Simulation` class require that we provide inputs. These can either be arrays of values or we can use functions -# that provides the values when provided a time index. +# The :class:`.Simulation` class require that we provide inputs. These can either be +# arrays of values or we can use functions that provides the values when provided a +# time index. # # Let us create a simulation that simulates a short impulse response: @@ -171,18 +156,18 @@ sim = Simulation(firstorderiir, [[1, 0, 0, 0, 0]]) sim.run() # %% -# The returned value is the output after the final iteration. However, we may often be interested in the results from -# the whole simulation. -# The results from the simulation, which is a dictionary of all the nodes in the signal flow graph, -# can be obtained as +# The returned value is the output after the final iteration. However, we may often be +# interested in the results from the whole simulation. +# The results from the simulation, which is a dictionary of all the nodes in the signal +# flow graph, can be obtained as sim.results # %% -# Hence, we can obtain the results that we are interested in and, for example, plot the output and the value after the -# first addition: +# Hence, we can obtain the results that we are interested in and, for example, plot the +# output and the value after the first addition: -import matplotlib.pyplot as plt +import matplotlib.pyplot as plt # noqa: E402 plt.plot(sim.results['0'], label="Output") plt.plot(sim.results['add1'], label="After first addition") @@ -193,24 +178,27 @@ plt.show() # %% # To compute and plot the frequency response, it is possible to use mplsignal -from mplsignal.freq_plots import freqz_fir +from mplsignal.freq_plots import freqz_fir # noqa: E402 freqz_fir(sim.results["0"]) plt.show() # %% -# As seen, the output has not converged to zero, leading to that the frequency-response may not be correct, so we want -# to simulate for a longer time. -# Instead of just adding zeros to the input array, we can use a function that generates the impulse response instead. -# There are a number of those defined in B-ASIC for convenience, and the one for an impulse response is called :class:`.Impulse`. +# As seen, the output has not converged to zero, leading to that the frequency-response +# may not be correct, so we want to simulate for a longer time. +# Instead of just adding zeros to the input array, we can use a function that generates +# the impulse response instead. +# There are a number of those defined in B-ASIC for convenience, and the one for an +# impulse response is called :class:`.Impulse`. -from b_asic.signal_generator import Impulse +from b_asic.signal_generator import Impulse # noqa: E402 sim = Simulation(firstorderiir, [Impulse()]) # %% -# Now, as the functions will not have an end, we must run the simulation for a given number of cycles, say 30. +# Now, as the functions will not have an end, we must run the simulation for a given +# number of cycles, say 30. # This is done using :func:`~b_asic.simulation.Simulation.run_for` instead: sim.run_for(30) diff --git a/examples/fivepointwinograddft.py b/examples/fivepointwinograddft.py new file mode 100755 index 0000000000000000000000000000000000000000..2d26c4de2acef020683d28dc8304a03a8cab05ef --- /dev/null +++ b/examples/fivepointwinograddft.py @@ -0,0 +1,211 @@ +""" +======================= +Five-point Winograd DFT +======================= +""" + +from math import cos, pi, sin + +from b_asic.architecture import Architecture, Memory, ProcessingElement +from b_asic.core_operations import AddSub, Butterfly, ConstantMultiplication +from b_asic.schedule import Schedule +from b_asic.signal_flow_graph import SFG +from b_asic.special_operations import Input, Output + +u = -2 * pi / 5 +c50 = (cos(u) + cos(2 * u)) / 2 - 1 +c51 = (cos(u) - cos(2 * u)) / 2 +c52 = 1j * (sin(u) + sin(2 * u)) / 2 +c53 = 1j * (sin(2 * u)) +c54 = 1j * (sin(u) - sin(2 * u)) + + +in0 = Input("x0") +in1 = Input("x1") +in2 = Input("x2") +in3 = Input("x3") +in4 = Input("x4") +bf0 = Butterfly(in1, in3) +bf1 = Butterfly(in4, in2) +bf2 = Butterfly(bf0.output(0), bf1.output(0)) +a0 = AddSub(True, bf0.output(1), bf1.output(0)) +a1 = AddSub(True, bf2.output(0), in0) +# Should overload float*OutputPort as well +m0 = ConstantMultiplication(c50, bf2.output(0)) +m1 = ConstantMultiplication(c51, bf0.output(1)) +m2 = c52 * a0 +m3 = ConstantMultiplication(c53, bf2.output(1)) +m4 = ConstantMultiplication(c54, bf1.output(1)) +a2 = AddSub(True, m0, a1) +a3 = AddSub(False, m3, m2) +a4 = AddSub(True, m3, m4) +bf3 = Butterfly(a2, m1) +bf4 = Butterfly(bf3.output(0), a3) +bf5 = Butterfly(bf3.output(1), a4) + +out0 = Output(a1, "X0") +out1 = Output(bf4.output(0), "X1") +out2 = Output(bf4.output(1), "X2") +out4 = Output(bf5.output(0), "X4") +out3 = Output(bf5.output(1), "X3") + +sfg = SFG( + inputs=[in0, in1, in2, in3, in4], + outputs=[out0, out1, out2, out3, out4], + name="5-point Winograd DFT", +) + +# %% +# The SFG looks like +sfg + +# %% +# Set latencies and execution times +sfg.set_latency_of_type(ConstantMultiplication.type_name(), 2) +sfg.set_latency_of_type(AddSub.type_name(), 1) +sfg.set_latency_of_type(Butterfly.type_name(), 1) +sfg.set_execution_time_of_type(ConstantMultiplication.type_name(), 1) +sfg.set_execution_time_of_type(AddSub.type_name(), 1) +sfg.set_execution_time_of_type(Butterfly.type_name(), 1) + +# %% +# Generate schedule +schedule = Schedule(sfg, cyclic=True) +schedule.show() + +# Reschedule to only use one AddSub and one multiplier + +schedule.move_operation('out2', 4) +schedule.move_operation('out3', 4) +schedule.move_operation('out4', 3) +schedule.move_operation('out5', 6) +schedule.set_schedule_time(15) +schedule.move_operation('out5', 3) +schedule.move_operation('out4', 5) +schedule.move_operation('out3', 3) +schedule.move_operation('out2', 2) +schedule.move_operation('out1', 2) +schedule.move_operation('bfly4', 16) +schedule.move_operation('bfly3', 14) +schedule.move_operation('bfly2', 14) +schedule.move_operation('addsub3', 17) +schedule.move_operation('addsub5', 15) +schedule.move_operation('addsub2', 14) +schedule.move_operation('cmul5', 15) +schedule.move_operation('cmul3', 15) +schedule.move_operation('cmul1', 14) +schedule.move_operation('addsub1', 2) +schedule.move_operation('cmul2', 16) +schedule.move_operation('addsub4', 15) +schedule.move_operation('out1', 15) +schedule.move_operation('addsub1', 13) +schedule.move_operation('cmul4', 18) +schedule.move_operation('bfly1', 14) +schedule.move_operation('bfly6', 14) +schedule.move_operation('bfly5', 14) +schedule.move_operation('in5', 1) +schedule.move_operation('in3', 2) +schedule.move_operation('in2', 3) +schedule.move_operation('in4', 4) +schedule.move_operation('bfly6', -5) +schedule.move_operation('bfly5', -6) +schedule.move_operation('addsub1', -1) +schedule.move_operation('bfly1', -1) +schedule.move_operation('bfly1', -4) +schedule.move_operation('addsub1', -5) +schedule.move_operation('addsub4', -6) +schedule.move_operation('cmul4', -10) +schedule.move_operation('cmul2', -7) +schedule.move_operation('cmul1', -2) +schedule.move_operation('cmul3', -6) +schedule.move_operation('cmul5', -5) +schedule.move_operation('cmul1', -3) +schedule.move_operation('cmul5', -1) +schedule.set_schedule_time(13) +schedule.move_operation('bfly5', -6) +schedule.move_operation('bfly6', -1) +schedule.move_operation('cmul4', -6) +schedule.move_operation('addsub1', 4) +schedule.move_operation('cmul3', 4) +schedule.move_operation('cmul1', 3) +schedule.move_operation('bfly1', 3) +schedule.move_operation('cmul2', 5) +schedule.move_operation('cmul5', 4) +schedule.move_operation('addsub4', 4) +schedule.set_schedule_time(10) +schedule.move_operation('addsub1', -1) +schedule.move_operation('cmul4', 1) +schedule.move_operation('addsub4', -1) +schedule.move_operation('cmul5', -1) +schedule.move_operation('cmul2', -2) +schedule.move_operation('bfly6', -4) +schedule.move_operation('bfly1', -1) +schedule.move_operation('addsub1', -1) +schedule.move_operation('cmul1', -1) +schedule.move_operation('cmul2', -3) +schedule.move_operation('addsub2', -1) +schedule.move_operation('bfly2', -1) +schedule.move_operation('bfly1', -1) +schedule.move_operation('cmul1', -1) +schedule.move_operation('addsub2', -1) +schedule.move_operation('addsub4', -1) +schedule.move_operation('addsub4', -3) +schedule.move_operation('cmul4', -1) +schedule.move_operation('bfly1', -2) +schedule.move_operation('cmul2', -1) +schedule.move_operation('cmul1', -2) +schedule.move_operation('cmul5', -4) +schedule.move_operation('cmul1', 1) +schedule.move_operation('cmul3', -5) +schedule.move_operation('cmul5', 2) +schedule.move_operation('addsub3', -3) +schedule.move_operation('addsub1', -3) +schedule.move_operation('addsub2', -1) +schedule.move_operation('addsub3', -4) +schedule.move_operation('bfly2', -2) +schedule.move_operation('addsub5', -3) +schedule.move_operation('bfly3', -2) +schedule.show() + +# Extract memory variables and operation executions +operations = schedule.get_operations() +adders = operations.get_by_type_name(AddSub.type_name()) +adders.show(title="AddSub executions") +mults = operations.get_by_type_name('cmul') +mults.show(title="Multiplier executions") +butterflies = operations.get_by_type_name(Butterfly.type_name()) +butterflies.show(title="Butterfly executions") +inputs = operations.get_by_type_name('in') +inputs.show(title="Input executions") +outputs = operations.get_by_type_name('out') +outputs.show(title="Output executions") + +addsub = ProcessingElement(adders, entity_name="addsub") +butterfly = ProcessingElement(butterflies, entity_name="butterfly") +multiplier = ProcessingElement(mults, entity_name="multiplier") +pe_in = ProcessingElement(inputs, entity_name='input') +pe_out = ProcessingElement(outputs, entity_name='output') + +mem_vars = schedule.get_memory_variables() +mem_vars.show(title="All memory variables") +direct, mem_vars = mem_vars.split_on_length() +mem_vars.show(title="Non-zero time memory variables") +direct.show(title="Direct interconnects") +mem_vars_set = mem_vars.split_on_ports(read_ports=1, write_ports=1, total_ports=2) + +memories = [] +for i, mem in enumerate(mem_vars_set): + memory = Memory(mem, memory_type="RAM", entity_name=f"memory{i}") + memories.append(memory) + mem.show(title=f"{memory.entity_name}") + memory.assign("left_edge") + memory.show_content(title=f"Assigned {memory.entity_name}") + + +arch = Architecture( + {addsub, butterfly, multiplier, pe_in, pe_out}, + memories, + direct_interconnects=direct, +) + +arch diff --git a/examples/folding_example_with_architecture.py b/examples/folding_example_with_architecture.py new file mode 100755 index 0000000000000000000000000000000000000000..85a7fc65e47911beecd4ce592126fcf49af4ddc1 --- /dev/null +++ b/examples/folding_example_with_architecture.py @@ -0,0 +1,110 @@ +""" +======================= +Comparison with folding +======================= + +This is a common example when illustrating folding. + +In general, the main problem with folding is to determine a suitable folding order. This +corresponds to scheduling the operations. + +Here, the folding order is the same for the adders as in the standard solution to this +problem, but the order of the multipliers is different to keep each memory variable +shorter than the scheduling period. + +""" + +from b_asic.architecture import Architecture, Memory, ProcessingElement +from b_asic.core_operations import Addition, ConstantMultiplication +from b_asic.schedule import Schedule +from b_asic.signal_flow_graph import SFG +from b_asic.special_operations import Delay, Input, Output + +in1 = Input("IN") +T1 = Delay() +T2 = Delay(T1) +a = ConstantMultiplication(0.2, T1, "a") +b = ConstantMultiplication(0.3, T1, "b") +c = ConstantMultiplication(0.4, T2, "c") +d = ConstantMultiplication(0.6, T2, "d") +add2 = a + c +add1 = in1 + add2 +add3 = b + d +T1 << add1 +out1 = Output(add1 + add3, "OUT") + +sfg = SFG(inputs=[in1], outputs=[out1], name="Bi-quad folding example") + +# %% +# The SFG looks like: +sfg + +# %% +# Set latencies and execution times +sfg.set_latency_of_type(ConstantMultiplication.type_name(), 2) +sfg.set_latency_of_type(Addition.type_name(), 1) +sfg.set_execution_time_of_type(ConstantMultiplication.type_name(), 1) +sfg.set_execution_time_of_type(Addition.type_name(), 1) + +# %% +# Create schedule +schedule = Schedule(sfg, cyclic=True) +schedule.show(title='Original schedule') + +# %% +# Reschedule to only require one adder and one multiplier +schedule.move_operation('out1', 2) +schedule.move_operation('add3', 2) +schedule.move_operation('cmul3', -3) +schedule.move_operation('add4', 3) +schedule.move_operation('cmul2', -3) +schedule.set_schedule_time(4) +schedule.move_operation('cmul2', 1) +schedule.move_operation('cmul1', 1) +schedule.move_operation('in1', 3) +schedule.move_operation('cmul3', -1) +schedule.move_operation('cmul1', 1) +schedule.show(title='Improved schedule') + +# %% +# Extract operations and create processing elements +operations = schedule.get_operations() +adders = operations.get_by_type_name('add') +adders.show(title="Adder executions") +mults = operations.get_by_type_name('cmul') +mults.show(title="Multiplier executions") +inputs = operations.get_by_type_name('in') +inputs.show(title="Input executions") +outputs = operations.get_by_type_name('out') +outputs.show(title="Output executions") + +p1 = ProcessingElement(adders, entity_name="adder") +p2 = ProcessingElement(mults, entity_name="cmul") +p_in = ProcessingElement(inputs, entity_name='input') +p_out = ProcessingElement(outputs, entity_name='output') + +# %% +# Extract and assign memory variables +mem_vars = schedule.get_memory_variables() +mem_vars.show(title="All memory variables") +direct, mem_vars = mem_vars.split_on_length() +mem_vars.show(title="Non-zero time memory variables") +mem_vars_set = mem_vars.split_on_ports(read_ports=1, write_ports=1, total_ports=2) + +memories = [] +for i, mem in enumerate(mem_vars_set): + memory = Memory(mem, memory_type="RAM", entity_name=f"memory{i}") + memories.append(memory) + mem.show(title=f"{memory.entity_name}") + memory.assign("left_edge") + memory.show_content(title=f"Assigned {memory.entity_name}") + +direct.show(title="Direct interconnects") + +# %% +# Create architecture +arch = Architecture({p1, p2, p_in, p_out}, memories, direct_interconnects=direct) + +# %% +# The architecture can be rendered in enriched shells. +arch diff --git a/examples/introduction.py b/examples/introduction.py index c14b6a0315d414e4bc133caa2e8747eab1ec6fba..d4a6d469ed49071267acd33c1199ca9c7072cede 100755 --- a/examples/introduction.py +++ b/examples/introduction.py @@ -1,7 +1,7 @@ """ -=============================== -Introduction example for course -=============================== +========================================== +Introduction example for the TSTE87 course +========================================== """ from b_asic.core_operations import Addition, ConstantMultiplication from b_asic.signal_flow_graph import SFG @@ -23,15 +23,6 @@ d.input(0).connect(a) sfg = SFG([i], [o]) -sim = Simulation(sfg, [Impulse()]) - -sim.run_for(100) - -fig, ax = plt.subplots() # Create a figure with a single Axes (plotting area) -ax.stem(sim.results['0']) # Plot the output using stem in the ax object -fig.show() # Show the figure (if it is not already shown) - # %% - -fig = freqz_fir(sim.results['0']) # Plot the frequency response in the ax Axes -fig.show() # Show the figure (if it is not already shown) \ No newline at end of file +# The SFG looks like: +sfg diff --git a/examples/secondorderdirectformiir.py b/examples/secondorderdirectformiir.py index e78a48f32e25b53cfa18f301de2aa2873379b1f1..b4eee8255a65c8ad0b5ff8e268e1fa186863c4f3 100755 --- a/examples/secondorderdirectformiir.py +++ b/examples/secondorderdirectformiir.py @@ -29,6 +29,10 @@ out1 = Output(add4, "OUT1") sfg = SFG(inputs=[in1], outputs=[out1], name="Second-order direct form IIR filter") +# %% +# The SFG looks like +sfg + # %% # Set latencies and execution times sfg.set_latency_of_type(ConstantMultiplication.type_name(), 2) diff --git a/examples/secondorderdirectformiir_architecture.py b/examples/secondorderdirectformiir_architecture.py index 2cb2022484882ae45ae69f693fc2a608cc2ad4e7..72d23961bcc44c681af4688f545b690900f47945 100755 --- a/examples/secondorderdirectformiir_architecture.py +++ b/examples/secondorderdirectformiir_architecture.py @@ -30,19 +30,23 @@ out1 = Output(add4, "OUT1") sfg = SFG(inputs=[in1], outputs=[out1], name="Second-order direct form IIR filter") # %% -# Set latencies and execution times +# The SFG is +sfg + +# %% +# Set latencies and execution times. sfg.set_latency_of_type(ConstantMultiplication.type_name(), 2) sfg.set_latency_of_type(Addition.type_name(), 1) sfg.set_execution_time_of_type(ConstantMultiplication.type_name(), 1) sfg.set_execution_time_of_type(Addition.type_name(), 1) # %% -# Create schedule +# Create schedule. schedule = Schedule(sfg, cyclic=True) schedule.show(title='Original schedule') # %% -# Rescheudle to only require one adder and one multiplier +# Reschedule to only require one adder and one multiplier. schedule.move_operation('add4', 2) schedule.move_operation('cmul5', -4) schedule.move_operation('cmul4', -5) @@ -51,7 +55,7 @@ schedule.move_operation('cmul3', 1) schedule.show(title='Improved schedule') # %% -# Extract operations and create processing elements +# Extract operations and create processing elements. operations = schedule.get_operations() adders = operations.get_by_type_name('add') adders.show(title="Adder executions") @@ -62,53 +66,75 @@ inputs.show(title="Input executions") outputs = operations.get_by_type_name('out') outputs.show(title="Output executions") -p1 = ProcessingElement(adders, entity_name="adder") -p2 = ProcessingElement(mults, entity_name="cmul") -p_in = ProcessingElement(inputs, entity_name='in') -p_out = ProcessingElement(outputs, entity_name='out') +adder = ProcessingElement(adders, entity_name="adder") +multiplier = ProcessingElement(mults, entity_name="multiplier") +pe_in = ProcessingElement(inputs, entity_name='input') +pe_out = ProcessingElement(outputs, entity_name='output') # %% -# Extract memory variables +# Extract and assign memory variables. mem_vars = schedule.get_memory_variables() mem_vars.show(title="All memory variables") direct, mem_vars = mem_vars.split_on_length() -direct.show(title="Direct interconnects") mem_vars.show(title="Non-zero time memory variables") mem_vars_set = mem_vars.split_on_ports(read_ports=1, write_ports=1, total_ports=2) -memories = set() +memories = [] for i, mem in enumerate(mem_vars_set): - memories.add(Memory(mem, entity_name=f"memory{i}")) - mem.show(title=f"memory{i}") + memory = Memory(mem, memory_type="RAM", entity_name=f"memory{i}") + memories.append(memory) + mem.show(title=f"{memory.entity_name}") + memory.assign("left_edge") + memory.show_content(title=f"Assigned {memory.entity_name}") + +direct.show(title="Direct interconnects") # %% -# Create architecture -arch = Architecture({p1, p2, p_in, p_out}, memories, direct_interconnects=direct) +# Create architecture. +arch = Architecture( + {adder, multiplier, pe_in, pe_out}, memories, direct_interconnects=direct +) # %% # The architecture can be rendered in enriched shells. -# -# .. graphviz:: -# -# digraph { -# node [shape=record] -# memory1 [label="{{<in0> in0}|memory1|{<out0> out0}}"] -# memory0 [label="{{<in0> in0}|memory0|{<out0> out0}}"] -# memory2 [label="{{<in0> in0}|memory2|{<out0> out0}}"] -# in [label="{in|{<out0> out0}}"] -# out [label="{{<in0> in0}|out}"] -# cmul [label="{{<in0> in0}|cmul|{<out0> out0}}"] -# adder [label="{{<in0> in0|<in1> in1}|adder|{<out0> out0}}"] -# memory1:out0 -> adder:in1 [label=1] -# cmul:out0 -> adder:in0 [label=1] -# cmul:out0 -> memory0:in0 [label=3] -# memory0:out0 -> adder:in0 [label=1] -# adder:out0 -> adder:in1 [label=1] -# memory1:out0 -> cmul:in0 [label=5] -# memory0:out0 -> adder:in1 [label=2] -# adder:out0 -> memory1:in0 [label=2] -# adder:out0 -> out:in0 [label=1] -# memory2:out0 -> adder:in0 [label=2] -# cmul:out0 -> memory2:in0 [label=2] -# in:out0 -> cmul:in0 [label=1] -# } +arch + +# %% +# To reduce the amount of interconnect, the ``cuml3.0`` variable can be moved from +# ``memory0`` to ``memory2``. In this way, ``memory0`` only gets variables from the +# adder and an input multiplexer can be avoided. The memories must be assigned again as +# the contents have changed. +arch.move_process('cmul3.0', 'memory0', 'memory2') +memories[0].assign() +memories[2].assign() + +memories[0].show_content("New assigned memory0") +memories[2].show_content("New assigned memory2") + +# %% +# Looking at the architecture it is clear that there is now only one input to +# ``memory0``, so no input multiplexer is required. +arch + +# %% +# It is of course also possible to move ``add4.0`` to ``memory2`` to save one memory +# cell. It is possible to pass ``assign=True`` to perform assignment after moving. +arch.move_process('add4.0', 'memory0', 'memory2', assign=True) + +memories[0].show_content("New assigned memory0") +memories[2].show_content("New assigned memory2") + +# %% +# However, this comes at the expense of an additional input to ``memory2``. +arch + +# %% +# Finally, by noting that ``cmul1.0`` is the only variable from ``memory1`` going to +# ``in0`` of ``adder``, another multiplexer can be reduced by: +arch.move_process('cmul1.0', 'memory1', 'memory2', assign=True) +memories[1].show_content("New assigned memory1") +memories[2].show_content("New assigned memory2") + +# %% +# Leading to +arch diff --git a/examples/thirdorderblwdf.py b/examples/thirdorderblwdf.py index 117041fcecaf2d9e49df26c46363b15c26266119..df19a1ed59b09c45845223924c0cc691226d289d 100755 --- a/examples/thirdorderblwdf.py +++ b/examples/thirdorderblwdf.py @@ -5,9 +5,13 @@ Third-order Bireciprocal LWDF Small bireciprocal lattice wave digital filter. """ +import numpy as np +from mplsignal.freq_plots import freqz_fir + from b_asic.core_operations import Addition, SymmetricTwoportAdaptor from b_asic.schedule import Schedule from b_asic.signal_flow_graph import SFG +from b_asic.signal_generator import Impulse from b_asic.simulation import Simulation from b_asic.special_operations import Delay, Input, Output @@ -21,19 +25,28 @@ a = s.output(0) + D0 out0 = Output(a, "y") sfg = SFG(inputs=[in0], outputs=[out0], name="Third-order BLWDF") +# %% +# The SFG looks like +sfg -# Set latencies and exection times +# %% +# Set latencies and execution times sfg.set_latency_of_type(SymmetricTwoportAdaptor.type_name(), 4) sfg.set_latency_of_type(Addition.type_name(), 1) sfg.set_execution_time_of_type(SymmetricTwoportAdaptor.type_name(), 1) sfg.set_execution_time_of_type(Addition.type_name(), 1) -sim = Simulation(sfg, [lambda n: 0 if n else 1]) +# %% +# Simulate +sim = Simulation(sfg, [Impulse()]) sim.run_for(1000) -import numpy as np -from mplsignal.freq_plots import freqz_fir +# %% +# Display output freqz_fir(np.array(sim.results['0']) / 2) +# %% +# Create and display schedule schedule = Schedule(sfg, cyclic=True) +schedule.show() diff --git a/examples/threepointwinograddft.py b/examples/threepointwinograddft.py index 17ec5c42248e4635677dcca13381a6c0d47ec1de..ff5cfa5a1bbc8fa4586a4a2f28921003912633f0 100755 --- a/examples/threepointwinograddft.py +++ b/examples/threepointwinograddft.py @@ -6,7 +6,11 @@ Three-point Winograd DFT from math import cos, pi, sin -from b_asic.core_operations import Addition, ConstantMultiplication, Subtraction +import matplotlib.pyplot as plt +import networkx as nx + +from b_asic.architecture import Architecture, Memory, ProcessingElement +from b_asic.core_operations import AddSub, ConstantMultiplication from b_asic.schedule import Schedule from b_asic.signal_flow_graph import SFG from b_asic.special_operations import Input, Output @@ -19,14 +23,14 @@ c31 = sin(u) in0 = Input("x0") in1 = Input("x1") in2 = Input("x2") -a0 = in1 + in2 -a1 = in1 - in2 -a2 = a0 + in0 +a0 = AddSub(True, in1, in2) +a1 = AddSub(False, in1, in2) +a2 = AddSub(True, a0, in0) m0 = c30 * a0 m1 = c31 * a1 -a3 = a2 + m0 -a4 = a3 + m1 -a5 = a3 - m1 +a3 = AddSub(True, a2, m0) +a4 = AddSub(True, a3, m1) +a5 = AddSub(False, a3, m1) out0 = Output(a2, "X0") out1 = Output(a4, "X1") out2 = Output(a5, "X2") @@ -37,18 +41,109 @@ sfg = SFG( name="3-point Winograd DFT", ) +# %% +# The SFG looks like +sfg + # %% # Set latencies and execution times sfg.set_latency_of_type(ConstantMultiplication.type_name(), 2) -sfg.set_latency_of_type(Addition.type_name(), 1) -sfg.set_latency_of_type(Subtraction.type_name(), 1) +sfg.set_latency_of_type(AddSub.type_name(), 1) sfg.set_execution_time_of_type(ConstantMultiplication.type_name(), 1) -sfg.set_execution_time_of_type(Addition.type_name(), 1) -sfg.set_execution_time_of_type(Subtraction.type_name(), 1) +sfg.set_execution_time_of_type(AddSub.type_name(), 1) # %% # Generate schedule schedule = Schedule(sfg, cyclic=True) schedule.show() -pc = schedule.get_memory_variables() +# Reschedule to only use one AddSub and one multiplier +schedule.set_schedule_time(10) +schedule.move_operation('out2', 3) +schedule.move_operation('out3', 4) +schedule.move_operation('addsub5', 2) +schedule.move_operation('addsub4', 3) +schedule.move_operation('addsub3', 2) +schedule.move_operation('cmul2', 2) +schedule.move_operation('cmul1', 2) +schedule.move_operation('out1', 5) +schedule.move_operation('addsub1', 3) +schedule.move_operation('addsub6', 2) +schedule.move_operation('addsub2', 2) +schedule.move_operation('in2', 1) +schedule.move_operation('in3', 2) +schedule.move_operation('cmul2', 1) +schedule.move_operation('out3', 6) +schedule.move_operation('out2', 6) +schedule.move_operation('out1', 6) +schedule.move_operation('addsub6', 1) +schedule.move_operation('addsub4', 3) +schedule.move_operation('addsub5', 4) +schedule.move_operation('addsub4', 1) +schedule.move_operation('addsub5', 4) +schedule.move_operation('cmul2', 3) +schedule.move_operation('addsub4', 2) +schedule.move_operation('cmul2', 3) +schedule.move_operation('addsub3', 5) +schedule.set_schedule_time(6) +schedule.move_operation('addsub1', 1) +schedule.move_operation('addsub4', -1) +schedule.move_operation('cmul2', -2) +schedule.move_operation('addsub4', -1) +schedule.move_operation('addsub1', -1) +schedule.move_operation('addsub3', -1) +schedule.move_operation('addsub5', -4) +schedule.show() + +# Extract memory variables and operation executions +operations = schedule.get_operations() +adders = operations.get_by_type_name(AddSub.type_name()) +adders.show(title="AddSub executions") +mults = operations.get_by_type_name('cmul') +mults.show(title="Multiplier executions") +inputs = operations.get_by_type_name('in') +inputs.show(title="Input executions") +outputs = operations.get_by_type_name('out') +outputs.show(title="Output executions") + +addsub = ProcessingElement(adders, entity_name="addsub") +multiplier = ProcessingElement(mults, entity_name="multiplier") +pe_in = ProcessingElement(inputs, entity_name='input') +pe_out = ProcessingElement(outputs, entity_name='output') + +mem_vars = schedule.get_memory_variables() +mem_vars.show(title="All memory variables") +direct, mem_vars = mem_vars.split_on_length() +mem_vars.show(title="Non-zero time memory variables") +mem_vars_set = mem_vars.split_on_ports(read_ports=1, write_ports=1, total_ports=2) +direct.show(title="Direct interconnects") + +fig, ax = plt.subplots() +fig.suptitle('Exclusion graph based on ports') +nx.draw(mem_vars.create_exclusion_graph_from_ports(1, 1, 2), ax=ax) + +memories = [] +for i, mem in enumerate(mem_vars_set): + memory = Memory(mem, memory_type="RAM", entity_name=f"memory{i}") + memories.append(memory) + mem.show(title=f"{memory.entity_name}") + memory.assign("left_edge") + memory.show_content(title=f"Assigned {memory.entity_name}") + + +arch = Architecture( + {addsub, multiplier, pe_in, pe_out}, memories, direct_interconnects=direct +) + +arch + +# %% +# Move memory variables +arch.move_process('addsub2.0', memories[2], memories[1]) +arch.move_process('addsub4.0', memories[1], memories[2], assign=True) +memories[1].assign() + +memories[1].show_content(title="Assigned memory1") +memories[2].show_content(title="Assigned memory2") + +arch diff --git a/requirements_doc.txt b/requirements_doc.txt index 5112621bc53d459111ac351cf2d9648481a50e42..4474f4925f060ee814b5b14cb4af552c7d6b0f8c 100755 --- a/requirements_doc.txt +++ b/requirements_doc.txt @@ -3,4 +3,4 @@ furo numpydoc sphinx-gallery mplsignal -jupyter-sphinx +sphinx-copybutton diff --git a/test/test_architecture.py b/test/test_architecture.py index 8b52b0700e21097264185241d2d486200e6c4397..729fe0343adcb97d868e1a4de3627428407be460 100755 --- a/test/test_architecture.py +++ b/test/test_architecture.py @@ -1,11 +1,11 @@ from itertools import chain -from typing import List, cast +from typing import List import pytest from b_asic.architecture import Architecture, Memory, ProcessingElement from b_asic.core_operations import Addition, ConstantMultiplication -from b_asic.process import MemoryVariable, OperatorProcess +from b_asic.process import PlainMemoryVariable from b_asic.resources import ProcessCollection from b_asic.schedule import Schedule from b_asic.special_operations import Input, Output @@ -25,6 +25,24 @@ def test_processing_element_exceptions(schedule_direct_form_iir_lp_filter: Sched ProcessingElement(empty_collection) +def test_add_remove_process_from_resource(schedule_direct_form_iir_lp_filter: Schedule): + mvs = schedule_direct_form_iir_lp_filter.get_memory_variables() + operations = schedule_direct_form_iir_lp_filter.get_operations() + memory = Memory(mvs) + pe = ProcessingElement( + operations.get_by_type_name(ConstantMultiplication.type_name()) + ) + for process in operations: + with pytest.raises(KeyError, match=f"{process} not of type"): + memory.add_process(process) + for process in mvs: + with pytest.raises(KeyError, match=f"{process} not of type"): + pe.add_process(process) + + with pytest.raises(KeyError, match="PlainMV not of type"): + memory.add_process(PlainMemoryVariable(0, 0, {0: 2}, "PlainMV")) + + def test_extract_processing_elements(schedule_direct_form_iir_lp_filter: Schedule): # Extract operations from schedule operations = schedule_direct_form_iir_lp_filter.get_operations() @@ -53,9 +71,7 @@ def test_memory_exceptions(schedule_direct_form_iir_lp_filter: Schedule): ValueError, match="Do not create Resource with empty ProcessCollection" ): Memory(empty_collection) - with pytest.raises( - TypeError, match="Can only have MemoryVariable or PlainMemoryVariable" - ): + with pytest.raises(TypeError, match="Can only have MemoryProcess"): Memory(operations) # No exception Memory(mvs) @@ -79,21 +95,24 @@ def test_architecture(schedule_direct_form_iir_lp_filter: Schedule): assert len(outputs) == 1 # Create necessary processing elements + adder = ProcessingElement(adders[0], entity_name="adder") + multiplier = ProcessingElement(const_mults[0], entity_name="multiplier") + input_pe = ProcessingElement(inputs[0], entity_name="input") + output_pe = ProcessingElement(outputs[0], entity_name="output") processing_elements: List[ProcessingElement] = [ - ProcessingElement(operation) - for operation in chain(adders, const_mults, inputs, outputs) + adder, + multiplier, + input_pe, + output_pe, ] - for i, pe in enumerate(processing_elements): - pe.set_entity_name(f"{pe._type_name.upper()}{i}") - if pe._type_name == 'add': - s = ( - 'digraph {\n\tnode [shape=record]\n\t' - + pe._entity_name - + ' [label="{{<in0> in0|<in1> in1}|' - + pe._entity_name - + '|{<out0> out0}}"]\n}' - ) - assert pe._digraph().source in (s, s + '\n') + s = ( + 'digraph {\n\tnode [shape=record]\n\t' + + "adder" + + ' [label="{{<in0> in0|<in1> in1}|' + + '<adder> adder' + + '|{<out0> out0}}" fillcolor="#00B9E7" style=filled]\n}' + ) + assert adder._digraph().source in (s, s + '\n') # Extract zero-length memory variables direct_conn, mvs = mvs.split_on_length() @@ -107,33 +126,104 @@ def test_architecture(schedule_direct_form_iir_lp_filter: Schedule): for i, memory in enumerate(memories): memory.set_entity_name(f"MEM{i}") s = ( - 'digraph {\n\tnode [shape=record]\n\tMEM0 [label="{{<in0> in0}|MEM0|{<out0>' - ' out0}}"]\n}' + 'digraph {\n\tnode [shape=record]\n\tMEM0 [label="{{<in0> in0}|<MEM0>' + ' MEM0|{<out0> out0}}" fillcolor="#00CFB5" style=filled]\n}' ) assert memory.schedule_time == 18 assert memory._digraph().source in (s, s + '\n') + assert not memory.is_assigned + memory.assign() + assert memory.is_assigned + assert len(memory._assignment) == 4 + + # Set invalid name + with pytest.raises(ValueError, match='32 is not a valid VHDL identifier'): + adder.set_entity_name("32") + assert adder.entity_name == "adder" # Create architecture from architecture = Architecture( processing_elements, memories, direct_interconnects=direct_conn ) + assert architecture.direct_interconnects == direct_conn + + # Graph representation + # Parts are non-deterministic, but this first part seems OK + s = ( + 'digraph {\n\tnode [shape=record]\n\tsplines=spline\n\tsubgraph' + ' cluster_memories' + ) + assert architecture._digraph().source.startswith(s) + s = 'digraph {\n\tnode [shape=record]\n\tsplines=spline\n\tMEM0' + assert architecture._digraph(cluster=False).source.startswith(s) assert architecture.schedule_time == 18 - # assert architecture._digraph().source == "foo" for pe in processing_elements: - print(pe) assert pe.schedule_time == 18 - for operation in pe._collection: - operation = cast(OperatorProcess, operation) - print(f' {operation}') - print(architecture.get_interconnects_for_pe(pe)) - - print("") - print("") - for memory in memories: - print(memory) - for mv in memory._collection: - mv = cast(MemoryVariable, mv) - print(f' {mv.start_time} -> {mv.execution_time}: {mv.write_port.name}') - print(architecture.get_interconnects_for_memory(memory)) + + assert architecture.resource_from_name('adder') == adder + + +def test_move_process(schedule_direct_form_iir_lp_filter: Schedule): + # Resources + mvs = schedule_direct_form_iir_lp_filter.get_memory_variables() + operations = schedule_direct_form_iir_lp_filter.get_operations() + adders1, adders2 = operations.get_by_type_name(Addition.type_name()).split_on_ports( + total_ports=1 + ) + adders1 = [adders1] # Fake two PEs needed for the adders + adders2 = [adders2] # Fake two PEs needed for the adders + const_mults = operations.get_by_type_name( + ConstantMultiplication.type_name() + ).split_on_execution_time() + inputs = operations.get_by_type_name(Input.type_name()).split_on_execution_time() + outputs = operations.get_by_type_name(Output.type_name()).split_on_execution_time() + + # Create necessary processing elements + processing_elements: List[ProcessingElement] = [ + ProcessingElement(operation, entity_name=f'pe{i}') + for i, operation in enumerate(chain(adders1, adders2, const_mults)) + ] + for i, pc in enumerate(inputs): + processing_elements.append(ProcessingElement(pc, entity_name=f'input{i}')) + for i, pc in enumerate(outputs): + processing_elements.append(ProcessingElement(pc, entity_name=f'output{i}')) + + # Extract zero-length memory variables + direct_conn, mvs = mvs.split_on_length() + + # Create Memories from the memory variables (split on length to get two memories) + memories: List[Memory] = [Memory(pc) for pc in mvs.split_on_length(6)] + + # Create architecture + architecture = Architecture( + processing_elements, memories, direct_interconnects=direct_conn + ) + + # Some movement that must work + assert memories[1].collection.from_name('cmul4.0') + architecture.move_process('cmul4.0', memories[1], memories[0]) + assert memories[0].collection.from_name('cmul4.0') + + assert memories[1].collection.from_name('in1.0') + architecture.move_process('in1.0', memories[1], memories[0]) + assert memories[0].collection.from_name('in1.0') + + assert processing_elements[1].collection.from_name('add1') + architecture.move_process('add1', processing_elements[1], processing_elements[0]) + assert processing_elements[0].collection.from_name('add1') + + # Processes leave the resources they have moved from + with pytest.raises(KeyError): + memories[1].collection.from_name('cmul4.0') + with pytest.raises(KeyError): + memories[1].collection.from_name('in1.0') + with pytest.raises(KeyError): + processing_elements[1].collection.from_name('add1') + + # Processes can only be moved when the source and destination process-types match + with pytest.raises(KeyError, match="cmul4.0 not of type"): + architecture.move_process('cmul4.0', memories[0], processing_elements[0]) + with pytest.raises(KeyError, match="invalid_name not in"): + architecture.move_process('invalid_name', memories[0], processing_elements[1]) diff --git a/test/test_codegen.py b/test/test_codegen.py new file mode 100755 index 0000000000000000000000000000000000000000..b52be275b1cf42efff3798aab913fd1494706973 --- /dev/null +++ b/test/test_codegen.py @@ -0,0 +1,38 @@ +from b_asic.codegen.vhdl.common import is_valid_vhdl_identifier + + +def test_is_valid_vhdl_identifier(): + identifier_pass = { + "COUNT", + "X", + "c_out", + "FFT", + "Decoder", + "VHSIC", + "X1", + "PageCount", + "STORE_NEXT_ITEM", + "ValidIdentifier123", + "valid_identifier", + } + identifier_fail = { + "", + "architecture", + "Architecture", + "ArChItEctUrE", + "architectURE", + "entity", + "invalid+", + "invalid}", + "not-valid", + "(invalid)", + "invalid£", + "1nvalid", + "_abc", + } + + for identifier in identifier_pass: + assert is_valid_vhdl_identifier(identifier) + + for identifier in identifier_fail: + assert not is_valid_vhdl_identifier(identifier) diff --git a/test/test_process.py b/test/test_process.py index 213003afd72b920c64344582dc4f3365ad28c1e7..7ed1517957eda987811da881af8c84e2983dfc32 100755 --- a/test/test_process.py +++ b/test/test_process.py @@ -10,8 +10,8 @@ def test_PlainMemoryVariable(): assert mem.write_port == 0 assert mem.start_time == 3 assert mem.execution_time == 2 - assert mem.life_times == (1, 2) - assert mem.read_ports == (4, 5) + assert mem.life_times == [1, 2] + assert mem.read_ports == [4, 5] assert repr(mem) == "PlainMemoryVariable(3, 0, {4: 1, 5: 2}, 'Var. 0')" mem2 = PlainMemoryVariable(2, 0, {4: 2, 5: 3}, 'foo') @@ -39,3 +39,35 @@ def test_MemoryVariables(secondorder_iir_schedule): def test_OperatorProcess_error(secondorder_iir_schedule): with pytest.raises(ValueError, match="does not have an execution time specified"): _ = secondorder_iir_schedule.get_operations() + + +def test_MultiReadProcess(): + mv = PlainMemoryVariable(3, 0, {0: 1, 1: 2, 2: 5}, name="MV") + + with pytest.raises(KeyError, match=r'Process MV: 3 not in life_times: \[1, 2, 5\]'): + mv._remove_life_time(3) + + assert mv.life_times == [1, 2, 5] + assert mv.execution_time == 5 + mv._remove_life_time(5) + assert mv.life_times == [1, 2] + assert mv.execution_time == 2 + mv._add_life_time(4) + assert mv.execution_time == 4 + assert mv.life_times == [1, 2, 4] + mv._add_life_time(4) + assert mv.life_times == [1, 2, 4] + + +def test_split_on_length(): + mv = PlainMemoryVariable(3, 0, {0: 1, 1: 2, 2: 5}, name="MV") + short, long = mv.split_on_length(2) + assert short is not None and long is not None + assert short.start_time == 3 and long.start_time == 3 + assert short.execution_time == 2 and long.execution_time == 5 + assert short.reads == {0: 1, 1: 2} + assert long.reads == {2: 5} + + short, long = mv.split_on_length(0) + assert short is None + assert long is not None diff --git a/test/test_quantization.py b/test/test_quantization.py index 2923787bfc9e51ba95f0b6dba70e99c47625067f..b668a45d3936ec429b43b750adb71a1ea62fa287 100755 --- a/test/test_quantization.py +++ b/test/test_quantization.py @@ -8,10 +8,14 @@ def test_quantization(): assert quantize(a, 4, quantization=Quantization.ROUNDING) == 0.3125 assert quantize(a, 4, quantization=Quantization.MAGNITUDE_TRUNCATION) == 0.25 assert quantize(a, 4, quantization=Quantization.JAMMING) == 0.3125 + assert quantize(a, 4, quantization=Quantization.UNBIASED_ROUNDING) == 0.3125 + assert quantize(a, 4, quantization=Quantization.UNBIASED_JAMMING) == 0.3125 assert quantize(-a, 4, quantization=Quantization.TRUNCATION) == -0.3125 assert quantize(-a, 4, quantization=Quantization.ROUNDING) == -0.3125 assert quantize(-a, 4, quantization=Quantization.MAGNITUDE_TRUNCATION) == -0.25 assert quantize(-a, 4, quantization=Quantization.JAMMING) == -0.3125 + assert quantize(-a, 4, quantization=Quantization.UNBIASED_ROUNDING) == -0.3125 + assert quantize(-a, 4, quantization=Quantization.UNBIASED_JAMMING) == -0.3125 assert quantize(complex(a, -a), 4) == complex(0.25, -0.3125) assert quantize( complex(a, -a), 4, quantization=Quantization.MAGNITUDE_TRUNCATION @@ -26,3 +30,8 @@ def test_quantization(): ) == 0.9375 ) + + assert quantize(0.3125, 3, quantization=Quantization.ROUNDING) == 0.375 + assert quantize(0.3125, 3, quantization=Quantization.UNBIASED_ROUNDING) == 0.25 + assert quantize(0.25, 3, quantization=Quantization.JAMMING) == 0.375 + assert quantize(0.25, 3, quantization=Quantization.UNBIASED_JAMMING) == 0.25 diff --git a/test/test_resources.py b/test/test_resources.py index b5f5b13d089b3bb527ef6079fb49c397453f68a9..df034037f76af143c1dbabfbcb2002fc8bfb4144 100755 --- a/test/test_resources.py +++ b/test/test_resources.py @@ -210,3 +210,57 @@ class TestProcessCollectionPlainMemoryVariable: assert exclusion_graph.degree(p1) == 3 assert exclusion_graph.degree(p2) == 1 assert exclusion_graph.degree(p3) == 3 + + def test_left_edge_maximum_lifetime(self): + a = PlainMemoryVariable(2, 0, {0: 1}, "cmul1.0") + b = PlainMemoryVariable(4, 0, {0: 7}, "cmul4.0") + c = PlainMemoryVariable(5, 0, {0: 4}, "cmul5.0") + collection = ProcessCollection([a, b, c], schedule_time=7, cyclic=True) + for heuristic in ("graph_color", "left_edge"): + assignment = collection.split_on_execution_time(heuristic) + assert len(assignment) == 2 + a_idx = 0 if a in assignment[0] else 1 + assert b not in assignment[a_idx] + assert c in assignment[a_idx] + + def test_split_on_execution_lifetime_assert(self): + a = PlainMemoryVariable(3, 0, {0: 10}, "MV0") + collection = ProcessCollection([a], schedule_time=9, cyclic=True) + for heuristic in ("graph_color", "left_edge"): + with pytest.raises( + ValueError, + match="MV0 has execution time greater than the schedule time", + ): + collection.split_on_execution_time(heuristic) + + def test_split_on_length(self): + # Test 1: Exclude a zero-time access time + collection = ProcessCollection( + collection=[PlainMemoryVariable(0, 1, {0: 1, 1: 2, 2: 3})], + schedule_time=4, + ) + short, long = collection.split_on_length(0) + assert len(short) == 0 and len(long) == 1 + for split_time in [1, 2]: + short, long = collection.split_on_length(split_time) + assert len(short) == 1 and len(long) == 1 + short, long = collection.split_on_length(3) + assert len(short) == 1 and len(long) == 0 + + # Test 2: Include a zero-time access time + collection = ProcessCollection( + collection=[PlainMemoryVariable(0, 1, {0: 0, 1: 1, 2: 2, 3: 3})], + schedule_time=4, + ) + short, long = collection.split_on_length(0) + assert len(short) == 1 and len(long) == 1 + for split_time in [1, 2]: + short, long = collection.split_on_length(split_time) + assert len(short) == 1 and len(long) == 1 + + def test_from_name(self): + a = PlainMemoryVariable(0, 0, {0: 2}, name="cool name 1337") + collection = ProcessCollection([a], schedule_time=5, cyclic=True) + with pytest.raises(KeyError, match="epic_name not in ..."): + collection.from_name("epic_name") + assert a == collection.from_name("cool name 1337") diff --git a/test/test_schedule.py b/test/test_schedule.py index 3a7059b3bc33f8c987243550c46d1529e7b3cff0..e5fc64b375ee55da016372bbffc8a4261ee3b09c 100755 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -499,6 +499,11 @@ class TestProcesses: pc = secondorder_iir_schedule.get_memory_variables() assert len(pc) == 12 + def test_get_operations(self, secondorder_iir_schedule_with_execution_times): + pc = secondorder_iir_schedule_with_execution_times.get_operations() + assert len(pc) == 13 + assert all(isinstance(operand, OperatorProcess) for operand in pc.collection) + class TestFigureGeneration: @pytest.mark.mpl_image_compare(remove_text=True, style='mpl20') @@ -564,7 +569,12 @@ class TestErrors: ): Schedule(sfg_simple_filter, scheduling_algorithm="foo") - def test_get_operations(self, secondorder_iir_schedule_with_execution_times): - pc = secondorder_iir_schedule_with_execution_times.get_operations() - assert len(pc) == 13 - assert all(isinstance(operand, OperatorProcess) for operand in pc.collection) + +class TestGetUsedTypeNames: + def test_secondorder_iir_schedule(self, secondorder_iir_schedule): + assert secondorder_iir_schedule.get_used_type_names() == [ + 'add', + 'cmul', + 'in', + 'out', + ] diff --git a/test/test_sfg.py b/test/test_sfg.py index d88c62894b5e0a777c6522d25f5b00d43e805b4c..f6a42e7aaa39014b7dd2cb155c989ca11eafc572 100755 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -83,7 +83,8 @@ class TestPrintSfg: sfg.__str__() == "id: no_id, \tname: SFG1, \tinputs: {0: '-'}, \toutputs: {0: '-'}\n" + "Internal Operations:\n" - + "----------------------------------------------------------------------------------------------------\n" + + "--------------------------------------------------------------------" + + "--------------------------------\n" + str(sfg.find_by_name("INP1")[0]) + "\n" + str(sfg.find_by_name("INP2")[0]) @@ -92,7 +93,8 @@ class TestPrintSfg: + "\n" + str(sfg.find_by_name("OUT1")[0]) + "\n" - + "----------------------------------------------------------------------------------------------------\n" + + "--------------------------------------------------------------------" + + "--------------------------------\n" ) def test_add_mul(self): @@ -108,7 +110,8 @@ class TestPrintSfg: sfg.__str__() == "id: no_id, \tname: mac_sfg, \tinputs: {0: '-'}, \toutputs: {0: '-'}\n" + "Internal Operations:\n" - + "----------------------------------------------------------------------------------------------------\n" + + "--------------------------------------------------------------------" + + "--------------------------------\n" + str(sfg.find_by_name("INP1")[0]) + "\n" + str(sfg.find_by_name("INP2")[0]) @@ -121,7 +124,8 @@ class TestPrintSfg: + "\n" + str(sfg.find_by_name("OUT1")[0]) + "\n" - + "----------------------------------------------------------------------------------------------------\n" + + "--------------------------------------------------------------------" + + "--------------------------------\n" ) def test_constant(self): @@ -136,7 +140,8 @@ class TestPrintSfg: sfg.__str__() == "id: no_id, \tname: sfg, \tinputs: {0: '-'}, \toutputs: {0: '-'}\n" + "Internal Operations:\n" - + "----------------------------------------------------------------------------------------------------\n" + + "--------------------------------------------------------------------" + + "--------------------------------\n" + str(sfg.find_by_name("CONST")[0]) + "\n" + str(sfg.find_by_name("INP1")[0]) @@ -145,7 +150,8 @@ class TestPrintSfg: + "\n" + str(sfg.find_by_name("OUT1")[0]) + "\n" - + "----------------------------------------------------------------------------------------------------\n" + + "--------------------------------------------------------------------" + + "--------------------------------\n" ) def test_simple_filter(self, sfg_simple_filter): @@ -154,7 +160,8 @@ class TestPrintSfg: == "id: no_id, \tname: simple_filter, \tinputs: {0: '-'}," " \toutputs: {0: '-'}\n" + "Internal Operations:\n" - + "----------------------------------------------------------------------------------------------------\n" + + "--------------------------------------------------------------------" + + "--------------------------------\n" + str(sfg_simple_filter.find_by_name("IN1")[0]) + "\n" + str(sfg_simple_filter.find_by_name("ADD1")[0]) @@ -165,7 +172,8 @@ class TestPrintSfg: + "\n" + str(sfg_simple_filter.find_by_name("OUT1")[0]) + "\n" - + "----------------------------------------------------------------------------------------------------\n" + + "--------------------------------------------------------------------" + + "--------------------------------\n" ) @@ -818,7 +826,8 @@ class TestConnectExternalSignalsToComponentsSoloComp: def test_connect_external_signals_to_components_operation_tree( self, operation_tree ): - """Replaces an SFG with only a operation_tree component with its inner components + """ + Replaces an SFG with only a operation_tree component with its inner components """ sfg1 = SFG(outputs=[Output(operation_tree)]) out1 = Output(None, "OUT1") @@ -832,7 +841,9 @@ class TestConnectExternalSignalsToComponentsSoloComp: def test_connect_external_signals_to_components_large_operation_tree( self, large_operation_tree ): - """Replaces an SFG with only a large_operation_tree component with its inner components + """ + Replaces an SFG with only a large_operation_tree component with its inner + components """ sfg1 = SFG(outputs=[Output(large_operation_tree)]) out1 = Output(None, "OUT1") @@ -1239,7 +1250,10 @@ class TestSFGGraph: ' [shape=ellipse]\n\tcmul1 -> add1 [headlabel=1]\n\tcmul1' ' [shape=ellipse]\n\tadd1 -> t1\n\tt1 [shape=square]\n\tt1 -> cmul1\n}' ) - assert sfg_simple_filter.sfg_digraph().source in (res, res + "\n") + assert sfg_simple_filter.sfg_digraph(branch_node=False).source in ( + res, + res + "\n", + ) def test_sfg_show_id(self, sfg_simple_filter): res = ( @@ -1250,7 +1264,9 @@ class TestSFGGraph: ' [shape=square]\n\tt1 -> cmul1 [label=s5]\n}' ) - assert sfg_simple_filter.sfg_digraph(show_id=True).source in ( + assert sfg_simple_filter.sfg_digraph( + show_id=True, branch_node=False + ).source in ( res, res + "\n", ) @@ -1265,7 +1281,7 @@ class TestSFGGraph: ' cmul1\n}' ) - assert sfg_simple_filter.sfg_digraph(branch_node=True).source in ( + assert sfg_simple_filter.sfg_digraph().source in ( res, res + "\n", ) @@ -1278,7 +1294,9 @@ class TestSFGGraph: ' -> cmul1\n}' ) - assert sfg_simple_filter.sfg_digraph(port_numbering=False).source in ( + assert sfg_simple_filter.sfg_digraph( + port_numbering=False, branch_node=False + ).source in ( res, res + "\n", ) @@ -1474,8 +1492,8 @@ class TestUnfold: for k in count1.keys(): assert count1[k] * multiple == count2[k] - # This is horrifying, but I can't figure out a way to run the test on multiple fixtures, - # so this is an ugly hack until someone that knows pytest comes along + # This is horrifying, but I can't figure out a way to run the test on multiple + # fixtures, so this is an ugly hack until someone that knows pytest comes along def test_two_inputs_two_outputs(self, sfg_two_inputs_two_outputs: SFG): self.do_tests(sfg_two_inputs_two_outputs) @@ -1639,3 +1657,15 @@ class TestInsertComponentAfter: sfg = SFG(outputs=[Output(large_operation_tree_names)]) with pytest.raises(ValueError, match="Unknown component:"): sfg.insert_operation_after('foo', SquareRoot()) + + +class TestGetUsedTypeNames: + def test_single_accumulator(self, sfg_simple_accumulator: SFG): + assert sfg_simple_accumulator.get_used_type_names() == ['add', 'in', 'out', 't'] + + def test_sfg_nested(self, sfg_nested: SFG): + assert sfg_nested.get_used_type_names() == ['in', 'out', 'sfg'] + + def test_large_operation_tree(self, large_operation_tree): + sfg = SFG(outputs=[Output(large_operation_tree)]) + assert sfg.get_used_type_names() == ['add', 'c', 'out']