diff --git a/b_asic/architecture.py b/b_asic/architecture.py index 5622ef8c94f2247992f99923521f46b0780bc251..cfe9d9dedb634f1922637d3275d2a790c19b2fa7 100644 --- a/b_asic/architecture.py +++ b/b_asic/architecture.py @@ -1,7 +1,8 @@ """ B-ASIC architecture classes. """ -from typing import Set, cast +from collections import defaultdict +from typing import List, Optional, Set, cast from b_asic.process import MemoryVariable, OperatorProcess, PlainMemoryVariable from b_asic.resources import ProcessCollection @@ -16,8 +17,8 @@ class ProcessingElement: process_collection : :class:`~b_asic.resources.ProcessCollection` """ - def __init__(self, process_collection: ProcessCollection): - if not len(ProcessCollection): + def __init__(self, process_collection: ProcessCollection, name=""): + if not len(process_collection): raise ValueError( "Do not create ProcessingElement with empty ProcessCollection" ) @@ -30,7 +31,7 @@ class ProcessingElement: " ProcessingElement" ) ops = [ - cast(operand, OperatorProcess).operation + cast(OperatorProcess, operand).operation for operand in process_collection.collection ] op_type = type(ops[0]) @@ -39,6 +40,20 @@ class ProcessingElement: self._collection = process_collection self._operation_type = op_type self._type_name = op_type.type_name() + self._name = name + + @property + def processes(self) -> Set[OperatorProcess]: + return {cast(OperatorProcess, p) for p in self._collection} + + def __str__(self): + return self._name or self._type_name + + def __repr__(self): + return self._name or self._type_name + + def set_name(self, name: str): + self._name = name def write_code(self, path: str, entity_name: str) -> None: """ @@ -65,8 +80,10 @@ class Memory: The type of memory. """ - def __init__(self, process_collection: ProcessCollection, memory_type: str = "RAM"): - if not len(ProcessCollection): + def __init__( + self, process_collection: ProcessCollection, memory_type: str = "RAM", name="" + ): + if not len(process_collection): raise ValueError("Do not create Memory with empty ProcessCollection") if not all( isinstance(operator, (MemoryVariable, PlainMemoryVariable)) @@ -78,6 +95,19 @@ class Memory: ) self._collection = process_collection self._memory_type = memory_type + self._name = name + + def __iter__(self): + return iter(self._collection) + + def set_name(self, name: str): + self._name = name + + def __str__(self): + return self._name or self._memory_type + + def __repr__(self): + return self._name or self._memory_type def write_code(self, path: str, entity_name: str) -> None: """ @@ -110,6 +140,8 @@ class Architecture: name : str, default: "arch" Name for the top-level architecture. Used for the entity and as prefix for all building blocks. + direct_interconnects : ProcessCollection, optional + Process collection of zero-time memory variables used for direct interconnects. """ def __init__( @@ -117,10 +149,82 @@ class Architecture: processing_elements: Set[ProcessingElement], memories: Set[Memory], name: str = "arch", + direct_interconnects: Optional[ProcessCollection] = None, ): self._processing_elements = processing_elements self._memories = memories self._name = name + self._direct_interconnects = direct_interconnects + self._variable_inport_to_resource = {} + self._variable_outport_to_resource = {} + self._operation_inport_to_resource = {} + self._operation_outport_to_resource = {} + + self._build_dicts() + + # Validate input and output ports + self.validate_ports() + + def _build_dicts(self): + for pe in self.processing_elements: + for operator in pe.processes: + for input_port in operator.operation.inputs: + self._operation_inport_to_resource[input_port] = pe + for output_port in operator.operation.outputs: + self._operation_outport_to_resource[output_port] = pe + + for memory in self.memories: + for mv in memory: + mv = cast(MemoryVariable, mv) + for read_port in mv.read_ports: + self._variable_inport_to_resource[read_port] = memory + self._variable_outport_to_resource[mv.write_port] = memory + if self._direct_interconnects: + for di in self._direct_interconnects: + di = cast(MemoryVariable, di) + for read_port in di.read_ports: + self._variable_inport_to_resource[ + read_port + ] = self._operation_outport_to_resource[di.write_port] + self._variable_outport_to_resource[ + di.write_port + ] = self._operation_inport_to_resource[read_port] + + def validate_ports(self): + # Validate inputs and outputs of memory variables in all the memories in this architecture + memory_read_ports = set() + memory_write_ports = set() + for memory in self.memories: + for mv in memory: + mv = cast(MemoryVariable, mv) + memory_write_ports.add(mv.write_port) + memory_read_ports.update(mv.read_ports) + if self._direct_interconnects: + for mv in self._direct_interconnects: + mv = cast(MemoryVariable, mv) + memory_write_ports.add(mv.write_port) + memory_read_ports.update(mv.read_ports) + + pe_input_ports = set() + pe_output_ports = set() + for pe in self.processing_elements: + for operator in pe.processes: + pe_input_ports.update(operator.operation.inputs) + pe_output_ports.update(operator.operation.outputs) + + read_port_diff = memory_read_ports.symmetric_difference(pe_input_ports) + write_port_diff = memory_write_ports.symmetric_difference(pe_output_ports) + if read_port_diff: + raise ValueError( + "Memory read port and PE output port difference:" + f" {[port.name for port in read_port_diff]}" + ) + if write_port_diff: + raise ValueError( + "Memory read port and PE output port difference:" + f" {[port.name for port in write_port_diff]}" + ) + # Make sure all inputs and outputs in the architecture are in use def write_code(self, path: str) -> None: """ @@ -132,3 +236,33 @@ class Architecture: Directory to write code in. """ raise NotImplementedError + + def get_interconnects_for_memory(self, mem: Memory): + d_in = defaultdict(lambda: 0) + d_out = defaultdict(lambda: 0) + for var in mem._collection: + var = cast(MemoryVariable, var) + d_in[self._operation_outport_to_resource[var.write_port]] += 1 + for read_port in var.read_ports: + d_out[self._operation_inport_to_resource[read_port]] += 1 + return dict(d_in), dict(d_out) + + def get_interconnects_for_pe(self, pe: ProcessingElement): + ops = cast(List[OperatorProcess], list(pe._collection)) + d_in = [defaultdict(lambda: 0) for _ in ops[0].operation.inputs] + d_out = [defaultdict(lambda: 0) for _ in ops[0].operation.outputs] + for var in pe._collection: + var = cast(OperatorProcess, var) + for i, input in enumerate(var.operation.inputs): + d_in[i][self._variable_inport_to_resource[input]] += 1 + for i, output in enumerate(var.operation.outputs): + d_out[i][self._variable_outport_to_resource[output]] += 1 + return [dict(d) for d in d_in], [dict(d) for d in d_out] + + @property + def memories(self) -> Set[Memory]: + return self._memories + + @property + def processing_elements(self) -> Set[ProcessingElement]: + return self._processing_elements diff --git a/b_asic/process.py b/b_asic/process.py index c6a4b91cf7203fb765d5796fcb78508ba52be1be..99013ea4b579850791e0531c30be7fbf063aaf8b 100644 --- a/b_asic/process.py +++ b/b_asic/process.py @@ -4,6 +4,7 @@ from typing import Dict, Optional, Tuple from b_asic.operation import Operation from b_asic.port import InputPort, OutputPort +from b_asic.types import TypeName class Process: @@ -92,6 +93,10 @@ class OperatorProcess(Process): """The Operation that the OperatorProcess corresponds to.""" return self._operation + @property + def type_name(self) -> TypeName: + return self._operation.type_name() + def __repr__(self) -> str: return f"OperatorProcess({self.start_time}, {self.operation}, {self.name!r})" diff --git a/b_asic/resources.py b/b_asic/resources.py index 922a5d62fe51ed73c7e5c1d9781474ee375ad5d4..95495672aeffa07abc7c767d4c08ab3e117fd269 100644 --- a/b_asic/resources.py +++ b/b_asic/resources.py @@ -1112,6 +1112,30 @@ class ProcessCollection: input_sync=input_sync, ) + def split_on_length(self, length: int = 0): + """ + Split the current ProcessCollection into two new ProcessCollection based on exectuion time length. + + Parameters + ---------- + length : int, default: 0 + The execution time length to split on. Length is inclusive for the smaller collection. + + Returns + ------- + A tuple of two ProcessCollections, one with short than or equal execution times and one with greater execution times. + """ + short = set() + long = set() + for process in self.collection: + if process.execution_time <= length: + short.add(process) + else: + long.add(process) + return ProcessCollection( + short, schedule_time=self.schedule_time + ), ProcessCollection(long, schedule_time=self.schedule_time) + def generate_register_based_storage_vhdl( self, filename: str, diff --git a/test/fixtures/schedule.py b/test/fixtures/schedule.py index adabad2e71f8f3eae5b0bc1bdd8cca30078fcc1d..4091bd4047750996e84b87af3c7dcae0ab1ab51d 100644 --- a/test/fixtures/schedule.py +++ b/test/fixtures/schedule.py @@ -2,6 +2,7 @@ import pytest from b_asic.core_operations import Addition, ConstantMultiplication from b_asic.schedule import Schedule +from b_asic.signal_flow_graph import SFG @pytest.fixture @@ -24,3 +25,29 @@ def secondorder_iir_schedule_with_execution_times(precedence_sfg_delays): schedule = Schedule(precedence_sfg_delays, scheduling_algorithm="ASAP") return schedule + + +@pytest.fixture +def schedule_direct_form_iir_lp_filter(sfg_direct_form_iir_lp_filter: SFG): + sfg_direct_form_iir_lp_filter.set_latency_of_type(Addition.type_name(), 4) + sfg_direct_form_iir_lp_filter.set_latency_of_type( + ConstantMultiplication.type_name(), 3 + ) + sfg_direct_form_iir_lp_filter.set_execution_time_of_type(Addition.type_name(), 2) + sfg_direct_form_iir_lp_filter.set_execution_time_of_type( + ConstantMultiplication.type_name(), 1 + ) + schedule = Schedule( + sfg_direct_form_iir_lp_filter, scheduling_algorithm="ASAP", cyclic=True + ) + schedule.move_operation('cmul4', -1) + schedule.move_operation('cmul3', -1) + schedule.move_operation('cmul4', -10) + schedule.move_operation('cmul4', 1) + schedule.move_operation('cmul3', -8) + schedule.move_operation('add4', 1) + schedule.move_operation('add4', 1) + schedule.move_operation('cmul2', 1) + schedule.move_operation('cmul2', 1) + schedule.move_operation('cmul4', 2) + return schedule diff --git a/test/fixtures/signal_flow_graph.py b/test/fixtures/signal_flow_graph.py index c8da0eebac9c531f42ba5a2591ffdf74d6ede971..e8d4f5b4d485e9dedab1855b36a6755240d2e9e8 100644 --- a/test/fixtures/signal_flow_graph.py +++ b/test/fixtures/signal_flow_graph.py @@ -306,3 +306,29 @@ def sfg_two_tap_fir(): Signal(source=add1.output(0), destination=out1.input(0)) Signal(source=cmul2.output(0), destination=add1.input(1)) return SFG(inputs=[in1], outputs=[out1], name='twotapfir') + + +@pytest.fixture +def sfg_direct_form_iir_lp_filter(): + """ + Signal flow graph of the second-order direct form 2 IIR filter used in the + first lab in the TSTE87 lab series. + + IN1>---->ADD1>----------+--->a0>--->ADD4>---->OUT1 + ^ | ^ + | T1 | + | | | + ADD2<---<a1<---+--->a1>--->ADD3 + ^ | ^ + | T2 | + | | | + +-----<a2<---+--->a2>-----+ + """ + a0, a1, a2, b1, b2 = 57 / 256, 55 / 128, 57 / 256, 179 / 512, -171 / 512 + x, y = Input(name="x"), Output(name="y") + d0, d1 = Delay(), Delay() + top_node = d0 * b1 + d1 * b2 + x + d0.input(0).connect(top_node) + d1.input(0).connect(d0) + y << a1 * d0 + a2 * d1 + a0 * top_node + return SFG(inputs=[x], outputs=[y], name='Direct Form 2 IIR Lowpass filter') diff --git a/test/test_architecture.py b/test/test_architecture.py new file mode 100644 index 0000000000000000000000000000000000000000..425ebfd3c85502e8ded75c8c58b0fd78d682dade --- /dev/null +++ b/test/test_architecture.py @@ -0,0 +1,121 @@ +from itertools import chain +from typing import List, Set, cast + +import matplotlib.pyplot as plt +import pytest + +from b_asic.architecture import Architecture, Memory, ProcessingElement +from b_asic.core_operations import Addition, ConstantMultiplication +from b_asic.process import MemoryVariable, OperatorProcess +from b_asic.resources import ProcessCollection +from b_asic.schedule import Schedule +from b_asic.signal_flow_graph import SFG +from b_asic.special_operations import Input, Output + + +def test_processing_element_exceptions(schedule_direct_form_iir_lp_filter: Schedule): + mvs = schedule_direct_form_iir_lp_filter.get_memory_variables() + with pytest.raises( + TypeError, + match="Can only have OperatorProcesses in ProcessCollection when creating", + ): + ProcessingElement(mvs) + empty_collection = ProcessCollection(collection=set(), schedule_time=5) + with pytest.raises( + ValueError, match="Do not create ProcessingElement with empty ProcessCollection" + ): + ProcessingElement(empty_collection) + + +def test_extract_processing_elements(schedule_direct_form_iir_lp_filter: Schedule): + # Extract operations from schedule + operations = schedule_direct_form_iir_lp_filter.get_operations() + + # Split into new process collections on overlapping execution time + adders = operations.get_by_type_name(Addition.type_name()).split_execution_time() + const_mults = operations.get_by_type_name( + ConstantMultiplication.type_name() + ).split_execution_time() + + # List of ProcessingElements + processing_elements: List[ProcessingElement] = [] + for adder_collection in adders: + processing_elements.append(ProcessingElement(adder_collection)) + for const_mult_collection in const_mults: + processing_elements.append(ProcessingElement(const_mult_collection)) + + assert len(processing_elements) == len(adders) + len(const_mults) + + +def test_memory_exceptions(schedule_direct_form_iir_lp_filter: Schedule): + mvs = schedule_direct_form_iir_lp_filter.get_memory_variables() + operations = schedule_direct_form_iir_lp_filter.get_operations() + empty_collection = ProcessCollection(collection=set(), schedule_time=5) + with pytest.raises( + ValueError, match="Do not create Memory with empty ProcessCollection" + ): + Memory(empty_collection) + with pytest.raises( + TypeError, match="Can only have MemoryVariable or PlainMemoryVariable" + ): + Memory(operations) + # No exception + Memory(mvs) + + +def test_architecture(schedule_direct_form_iir_lp_filter: Schedule): + # Extract memory variables and operations + mvs = schedule_direct_form_iir_lp_filter.get_memory_variables() + operations = schedule_direct_form_iir_lp_filter.get_operations() + + # Split operations further into chunks + adders = operations.get_by_type_name(Addition.type_name()).split_execution_time() + assert len(adders) == 1 + const_mults = operations.get_by_type_name( + ConstantMultiplication.type_name() + ).split_execution_time() + assert len(const_mults) == 1 + inputs = operations.get_by_type_name(Input.type_name()).split_execution_time() + assert len(inputs) == 1 + outputs = operations.get_by_type_name(Output.type_name()).split_execution_time() + assert len(outputs) == 1 + + # Create necessary processing elements + processing_elements: List[ProcessingElement] = [ + ProcessingElement(operation) + for operation in chain(adders, const_mults, inputs, outputs) + ] + for i, pe in enumerate(processing_elements): + pe.set_name(f"{pe._type_name.upper()}-{i}") + + # Extract zero-length memory variables + direct_conn, mvs = mvs.split_on_length() + + # Create Memories from the memory variables + memories: List[Memory] = [ + Memory(pc) for pc in mvs.split_ports(read_ports=1, write_ports=1) + ] + assert len(memories) == 1 + for i, memory in enumerate(memories): + memory.set_name(f"mem-{i}") + + # Create architecture from + architecture = Architecture( + set(processing_elements), set(memories), direct_interconnects=direct_conn + ) + + for pe in processing_elements: + print(pe) + for operation in pe._collection: + operation = cast(OperatorProcess, operation) + print(f' {operation}') + print(architecture.get_interconnects_for_pe(pe)) + + print("") + print("") + for memory in memories: + print(memory) + for mv in memory._collection: + mv = cast(MemoryVariable, mv) + print(f' {mv.start_time} -> {mv.execution_time}: {mv.write_port.name}') + print(architecture.get_interconnects_for_memory(memory))