diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index a1a149d787f831405558b774993b1b0ef86fe0be..f696126c8a1ad7836c0df57da7e4ffc4f0ab10bb 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -254,6 +254,7 @@ class ConstantDivision(AbstractOperation): def evaluate(self, a): return a / self.param("value") + class Butterfly(AbstractOperation): """Butterfly operation that returns two outputs. The first output is a + b and the second output is a - b. diff --git a/b_asic/operation.py b/b_asic/operation.py index 45ce01a00f976f2ff06739fbd8287990f8ba5bb4..b4b9d243bb161a985bc7c9603e7750e23706bd1f 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -157,13 +157,14 @@ class AbstractOperation(Operation, AbstractGraphComponent): if input_sources is not None: source_count = len(input_sources) if source_count != input_count: - raise ValueError(f"Operation expected {input_count} input sources but only got {source_count}") + raise ValueError( + f"Operation expected {input_count} input sources but only got {source_count}") for i, src in enumerate(input_sources): if src is not None: self._input_ports[i].connect(src.source) @abstractmethod - def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ + def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ """Evaluate the operation and generate a list of output values given a list of input values. """ @@ -308,5 +309,6 @@ class AbstractOperation(Operation, AbstractGraphComponent): def source(self) -> OutputPort: if self.output_count != 1: diff = "more" if self.output_count > 1 else "less" - raise TypeError(f"{self.__class__.__name__} cannot be used as an input source because it has {diff} than 1 output") + raise TypeError( + f"{self.__class__.__name__} cannot be used as an input source because it has {diff} than 1 output") return self.output(0) diff --git a/b_asic/port.py b/b_asic/port.py index 4f249e3cf81d19943996e2056499a323d6c10a73..103d076af2702e7e565067f7568bb6035d24a2c8 100644 --- a/b_asic/port.py +++ b/b_asic/port.py @@ -8,6 +8,7 @@ from copy import copy from typing import NewType, Optional, List, Iterable, TYPE_CHECKING from b_asic.signal import Signal +from b_asic.graph_component import Name if TYPE_CHECKING: from b_asic.operation import Operation @@ -144,22 +145,24 @@ class InputPort(AbstractPort): """ return None if self._source_signal is None else self._source_signal.source - def connect(self, src: SignalSourceProvider) -> Signal: + def connect(self, src: SignalSourceProvider, name: Name = "") -> Signal: """Connect the provided signal source to this input port by creating a new signal. Returns the new signal. """ assert self._source_signal is None, "Attempted to connect already connected input port." - return Signal(src.source, self) # self._source_signal is set by the signal constructor. - + # self._source_signal is set by the signal constructor. + return Signal(source=src.source, destination=self, name=name) + @property def value_length(self) -> Optional[int]: """Get the number of bits that this port should truncate received values to.""" return self._value_length - + @value_length.setter def value_length(self, bits: Optional[int]) -> None: """Set the number of bits that this port should truncate received values to.""" - assert bits is None or (isinstance(bits, int) and bits >= 0), "Value length must be non-negative." + assert bits is None or (isinstance( + bits, int) and bits >= 0), "Value length must be non-negative." self._value_length = bits @@ -185,7 +188,7 @@ class OutputPort(AbstractPort, SignalSourceProvider): def add_signal(self, signal: Signal) -> None: assert signal not in self._destination_signals, "Attempted to add already connected signal." self._destination_signals.append(signal) - signal.set_source(self) + signal.set_source(self) def remove_signal(self, signal: Signal) -> None: assert signal in self._destination_signals, "Attempted to remove already removed signal." @@ -195,7 +198,7 @@ class OutputPort(AbstractPort, SignalSourceProvider): def clear(self) -> None: for signal in copy(self._destination_signals): self.remove_signal(signal) - + @property def source(self) -> "OutputPort": - return self \ No newline at end of file + return self diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index baa0c0040c7490c39bc1937b413590d37be25f98..8e3f7ab9b5e0afa339cbc7a251b1a11fe51a4d0c 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -5,7 +5,7 @@ TODO: More info. from typing import List, Iterable, Sequence, Dict, MutableMapping, Optional, DefaultDict, Set from numbers import Number -from collections import defaultdict +from collections import defaultdict, deque from b_asic.port import SignalSourceProvider, OutputPort from b_asic.operation import Operation, AbstractOperation, results_key @@ -23,75 +23,133 @@ class GraphIDGenerator: self._next_id_number = defaultdict(lambda: id_number_offset) def next_id(self, type_name: TypeName) -> GraphID: - """Return the next graph id for a certain graph id type.""" + """Get the next graph id for a certain graph id type.""" self._next_id_number[type_name] += 1 return type_name + str(self._next_id_number[type_name]) + @property + def id_number_offset(self) -> GraphIDNumber: + """Get the graph id number offset of this generator.""" + return self._next_id_number.default_factory() + class SFG(AbstractOperation): """Signal flow graph. TODO: More info. """ - + _components_by_id: Dict[GraphID, GraphComponent] _components_by_name: DefaultDict[Name, List[GraphComponent]] _graph_id_generator: GraphIDGenerator _input_operations: List[Input] _output_operations: List[Output] - _original_components_added: Set[GraphComponent] - _original_input_signals: Dict[Signal, int] - _original_output_signals: Dict[Signal, int] + _original_components_to_new: Set[GraphComponent] + _original_input_signals_to_indexes: Dict[Signal, int] + _original_output_signals_to_indexes: Dict[Signal, int] def __init__(self, input_signals: Sequence[Signal] = [], output_signals: Sequence[Signal] = [], \ inputs: Sequence[Input] = [], outputs: Sequence[Output] = [], \ id_number_offset: GraphIDNumber = 0, name: Name = "", \ input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None): super().__init__( - input_count = len(input_signals) + len(inputs), - output_count = len(output_signals) + len(outputs), - name = name, - input_sources = input_sources) + input_count=len(input_signals) + len(inputs), + output_count=len(output_signals) + len(outputs), + name=name, + input_sources=input_sources) self._components_by_id = dict() self._components_by_name = defaultdict(list) + self._components_in_dfs_order = [] self._graph_id_generator = GraphIDGenerator(id_number_offset) self._input_operations = [] self._output_operations = [] - self._original_components_added = set() - self._original_input_signals = {} - self._original_output_signals = {} - - # Setup input operations and signals. - for i, s in enumerate(input_signals): - self._input_operations.append(self._add_component_copy_unconnected(Input())) - self._original_input_signals[s] = i - for i, op in enumerate(inputs, len(input_signals)): - self._input_operations.append(self._add_component_copy_unconnected(op)) - for s in op.output(0).signals: - self._original_input_signals[s] = i - - # Setup output operations and signals. - for i, s in enumerate(output_signals): - self._output_operations.append(self._add_component_copy_unconnected(Output())) - self._original_output_signals[s] = i - for i, op in enumerate(outputs, len(output_signals)): - self._output_operations.append(self._add_component_copy_unconnected(op)) - for s in op.input(0).signals: - self._original_output_signals[s] = i - + self._original_components_to_new = {} + self._original_input_signals_to_indexes = {} + self._original_output_signals_to_indexes = {} + + # Setup input signals. + for input_index, sig in enumerate(input_signals): + assert sig not in self._original_components_to_new, "Duplicate input signals sent to SFG construcctor." + + new_input_op = self._add_component_copy_unconnected(Input()) + new_sig = self._add_component_copy_unconnected(sig) + new_sig.set_source(new_input_op.output(0)) + + self._input_operations.append(new_input_op) + self._original_input_signals_to_indexes[sig] = input_index + + # Setup input operations, starting from indexes ater input signals. + for input_index, input_op in enumerate(inputs, len(input_signals)): + assert input_op not in self._original_components_to_new, "Duplicate input operations sent to SFG constructor." + new_input_op = self._add_component_copy_unconnected(input_op) + + for sig in input_op.output(0).signals: + assert sig not in self._original_components_to_new, "Duplicate input signals connected to input ports sent to SFG construcctor." + new_sig = self._add_component_copy_unconnected(sig) + new_sig.set_source(new_input_op.output(0)) + + self._original_input_signals_to_indexes[sig] = input_index + + self._input_operations.append(new_input_op) + + # Setup output signals. + for output_ind, sig in enumerate(output_signals): + new_out = self._add_component_copy_unconnected(Output()) + if sig in self._original_components_to_new: + # Signal already added when setting up inputs + new_sig = self._original_components_to_new[sig] + new_sig.set_destination(new_out.input(0)) + else: + # New signal has to be created + new_sig = self._add_component_copy_unconnected(sig) + new_sig.set_destination(new_out.input(0)) + + self._output_operations.append(new_out) + self._original_output_signals_to_indexes[sig] = output_ind + + # Setup output operations, starting from indexes after output signals. + for output_ind, output_op in enumerate(outputs, len(output_signals)): + assert output_op not in self._original_components_to_new, "Duplicate output operations sent to SFG constructor." + + new_out = self._add_component_copy_unconnected(output_op) + for sig in output_op.input(0).signals: + if sig in self._original_components_to_new: + # Signal already added when setting up inputs + new_sig = self._original_components_to_new[sig] + new_sig.set_destination(new_out.input(0)) + else: + # New signal has to be created + new_sig = self._add_component_copy_unconnected(sig) + new_sig.set_destination(new_out.input(0)) + + self._original_output_signals_to_indexes[sig] = output_ind + + self._output_operations.append(new_out) + + output_operations_set = set(self._output_operations) + # Search the graph inwards from each input signal. - for s, i in self._original_input_signals.items(): - if s.destination is None: - raise ValueError(f"Input signal #{i} is missing destination in SFG") - if s.destination.operation not in self._original_components_added: - self._add_operation_copy_recursively(s.destination.operation) + for sig, input_index in self._original_input_signals_to_indexes.items(): + # Check if already added destination. + new_sig = self._original_components_to_new[sig] + if new_sig.destination is not None and new_sig.destination.operation in output_operations_set: + # Add directly connected input to output to dfs order list + self._components_in_dfs_order.extend([new_sig.source.operation, new_sig, new_sig.destination.operation]) + elif sig.destination is None: + raise ValueError(f"Input signal #{input_index} is missing destination in SFG") + elif sig.destination.operation not in self._original_components_to_new: + self._copy_structure_from_operation_dfs( + sig.destination.operation) # Search the graph inwards from each output signal. - for s, i in self._original_output_signals.items(): - if s.source is None: - raise ValueError(f"Output signal #{i} is missing source in SFG") - if s.source.operation not in self._original_components_added: - self._add_operation_copy_recursively(s.source.operation) + for sig, output_index in self._original_output_signals_to_indexes.items(): + # Check if already added source. + mew_sig = self._original_components_to_new[sig] + if new_sig.source is None: + if sig.source is None: + raise ValueError(f"Output signal #{output_index} is missing source in SFG") + if sig.source.operation not in self._original_components_to_new: + self._copy_structure_from_operation_dfs(sig.source.operation) @property def type_name(self) -> TypeName: @@ -129,8 +187,8 @@ class SFG(AbstractOperation): @property def components(self) -> Iterable[GraphComponent]: - """Get all components of this graph.""" - return self._components_by_id.values() + """Get all components of this graph in the dfs-traversal order.""" + return self._components_in_dfs_order def find_by_id(self, graph_id: GraphID) -> Optional[GraphComponent]: """Find a graph object based on the entered Graph ID and return it. If no graph @@ -151,57 +209,112 @@ class SFG(AbstractOperation): """ return self._components_by_name.get(name, []) - def _add_component_copy_unconnected(self, original_comp: GraphComponent) -> GraphComponent: - assert original_comp not in self._original_components_added, "Tried to add duplicate SFG component" - self._original_components_added.add(original_comp) + def deep_copy(self) -> "SFG": + """Returns a deep copy of self without any connections.""" + return SFG(inputs = self._input_operations, outputs = self._output_operations, + id_number_offset = self._graph_id_generator.id_number_offset, name = self.name) + def _add_component_copy_unconnected(self, original_comp: GraphComponent) -> GraphComponent: + assert original_comp not in self._original_components_to_new, "Tried to add duplicate SFG component" new_comp = original_comp.copy_component() + self._original_components_to_new[original_comp] = new_comp new_id = self._graph_id_generator.next_id(new_comp.type_name) new_comp.graph_id = new_id self._components_by_id[new_id] = new_comp self._components_by_name[new_comp.name].append(new_comp) return new_comp - def _add_operation_copy_recursively(self, original_op: Operation) -> Operation: - # Add a copy of the operation without any connections. - new_op = self._add_component_copy_unconnected(original_op) - - # Connect input ports. - for original_input_port, new_input_port in zip(original_op.inputs, new_op.inputs): - if original_input_port.signal_count < 1: - raise ValueError("Unconnected input port in SFG") - for original_signal in original_input_port.signals: - if original_signal in self._original_input_signals: # Check if the signal is one of the SFG's input signals. - new_signal = self._add_component_copy_unconnected(original_signal) - new_signal.set_destination(new_input_port) - new_signal.set_source(self._input_operations[self._original_input_signals[original_signal]].output(0)) - elif original_signal not in self._original_components_added: # Only add the signal if it wasn't already added. - new_signal = self._add_component_copy_unconnected(original_signal) - new_signal.set_destination(new_input_port) - if original_signal.source is None: - raise ValueError("Dangling signal without source in SFG") - # Recursively add the connected operation. - new_connected_op = self._add_operation_copy_recursively(original_signal.source.operation) - new_signal.set_source(new_connected_op.output(original_signal.source.index)) - - # Connect output ports. - for original_output_port, new_output_port in zip(original_op.outputs, new_op.outputs): - for original_signal in original_output_port.signals: - if original_signal in self._original_output_signals: # Check if the signal is one of the SFG's output signals. - new_signal = self._add_component_copy_unconnected(original_signal) - new_signal.set_source(new_output_port) - new_signal.set_destination(self._output_operations[self._original_output_signals[original_signal]].input(0)) - elif original_signal not in self._original_components_added: # Only add the signal if it wasn't already added. - new_signal = self._add_component_copy_unconnected(original_signal) - new_signal.set_source(new_output_port) - if original_signal.destination is None: - raise ValueError("Dangling signal without destination in SFG") - # Recursively add the connected operation. - new_connected_op = self._add_operation_copy_recursively(original_signal.destination.operation) - new_signal.set_destination(new_connected_op.input(original_signal.destination.index)) - - return new_op - + def _copy_structure_from_operation_dfs(self, start_op: Operation): + op_stack = deque([start_op]) + + while op_stack: + original_op = op_stack.pop() + # Add or get the new copy of the operation.. + new_op = None + if original_op not in self._original_components_to_new: + new_op = self._add_component_copy_unconnected(original_op) + self._components_in_dfs_order.append(new_op) + else: + new_op = self._original_components_to_new[original_op] + + # Connect input ports to new signals + for original_input_port in original_op.inputs: + if original_input_port.signal_count < 1: + raise ValueError("Unconnected input port in SFG") + + for original_signal in original_input_port.signals: + + # Check if the signal is one of the SFG's input signals + if original_signal in self._original_input_signals_to_indexes: + # New signal already created during first step of constructor + new_signal = self._original_components_to_new[original_signal] + new_signal.set_destination(new_op.input(original_input_port.index)) + + self._components_in_dfs_order.extend([new_signal, new_signal.source.operation]) + + # Check if the signal has not been added before + elif original_signal not in self._original_components_to_new: + if original_signal.source is None: + raise ValueError("Dangling signal without source in SFG") + + new_signal = self._add_component_copy_unconnected(original_signal) + new_signal.set_destination(new_op.input(original_input_port.index)) + + self._components_in_dfs_order.append(new_signal) + + original_connected_op = original_signal.source.operation + # Check if connected Operation has been added before + if original_connected_op in self._original_components_to_new: + # Set source to the already added operations port + new_signal.set_source(self._original_components_to_new[original_connected_op].output(original_signal.source.index)) + else: + # Create new operation, set signal source to it + new_connected_op = self._add_component_copy_unconnected(original_connected_op) + new_signal.set_source(new_connected_op.output(original_signal.source.index)) + + self._components_in_dfs_order.append(new_connected_op) + + # Add connected operation to queue of operations to visit + op_stack.append(original_connected_op) + + # Connect output ports + for original_output_port in original_op.outputs: + + for original_signal in original_output_port.signals: + # Check if the signal is one of the SFG's output signals. + if original_signal in self._original_output_signals_to_indexes: + + # New signal already created during first step of constructor. + new_signal = self._original_components_to_new[original_signal] + new_signal.set_source(new_op.output(original_output_port.index)) + + self._components_in_dfs_order.extend([new_signal, new_signal.destination.operation]) + + # Check if signal has not been added before. + elif original_signal not in self._original_components_to_new: + if original_signal.source is None: + raise ValueError("Dangling signal without source in SFG") + + new_signal = self._add_component_copy_unconnected(original_signal) + new_signal.set_source(new_op.output(original_output_port.index)) + + self._components_in_dfs_order.append(new_signal) + + original_connected_op = original_signal.destination.operation + # Check if connected operation has been added. + if original_connected_op in self._original_components_to_new: + # Set destination to the already connected operations port + new_signal.set_destination(self._original_components_to_new[original_connected_op].input(original_signal.destination.index)) + else: + # Create new operation, set destination to it. + new_connected_op = self._add_component_copy_unconnected(original_connected_op) + new_signal.set_destination(new_connected_op.input(original_signal.destination.index)) + + self._components_in_dfs_order.append(new_connected_op) + + # Add connected operation to the queue of operations to visist + op_stack.append(original_connected_op) + def _evaluate_source(self, src: OutputPort, results: MutableMapping[str, Number], registers: MutableMapping[str, Number], prefix: str) -> Number: op_prefix = prefix if op_prefix: @@ -210,4 +323,4 @@ class SFG(AbstractOperation): input_values = [self._evaluate_source(input_port.signals[0].source, results, registers, prefix) for input_port in src.operation.inputs] value = src.operation.evaluate_output(src.index, input_values, results, registers, op_prefix) results[results_key(src.operation.output_count, op_prefix, src.index)] = value - return value \ No newline at end of file + return value diff --git a/test/test_sfg.py b/test/test_sfg.py index 216e4fd90c75d8e181d7a69828c1e395d7868c2f..1501e6cff3f8adfb09a3a629aea69b68a59a476e 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -1,27 +1,46 @@ import pytest -from b_asic import SFG, Signal, Output +from b_asic import SFG, Signal, Input, Output, Addition, Multiplication class TestConstructor: + def test_direct_input_to_output_sfg_construction(self): + inp = Input("INP1") + out = Output(None, "OUT1") + out.input(0).connect(inp, "S1") + + sfg = SFG(inputs = [inp], outputs = [out]) + + assert len(list(sfg.components)) == 3 + assert sfg.input_count == 1 + assert sfg.output_count == 1 + + def test_same_signal_input_and_output_sfg_construction(self): + add1 = Addition(None, None, "ADD1") + add2 = Addition(None, None, "ADD2") + + sig1 = add2.input(0).connect(add1, "S1") + + sfg = SFG(input_signals = [sig1], output_signals = [sig1]) + + assert len(list(sfg.components)) == 3 + assert sfg.input_count == 1 + assert sfg.output_count == 1 + def test_outputs_construction(self, operation_tree): - sfg = SFG(outputs=[Output(operation_tree)]) + sfg = SFG(outputs = [Output(operation_tree)]) assert len(list(sfg.components)) == 7 assert sfg.input_count == 0 assert sfg.output_count == 1 def test_signals_construction(self, operation_tree): - sfg = SFG(output_signals=[Signal(source=operation_tree.output(0))]) + sfg = SFG(output_signals = [Signal(source = operation_tree.output(0))]) assert len(list(sfg.components)) == 7 assert sfg.input_count == 0 assert sfg.output_count == 1 - def test_cycle_construction(self, operation_graph_with_cycle): - with pytest.raises(Exception): - SFG(outputs=[Output(operation_graph_with_cycle)]) - class TestEvaluation: def test_evaluate_output(self, operation_tree): @@ -33,6 +52,72 @@ class TestEvaluation: assert sfg.evaluate_output(0, []) == 14 def test_evaluate_output_cycle(self, operation_graph_with_cycle): + sfg = SFG(outputs = [Output(operation_graph_with_cycle)]) with pytest.raises(Exception): - sfg = SFG(outputs = [Output(operation_graph_with_cycle)]) - sfg.evaluate_output(0, []) \ No newline at end of file + sfg.evaluate_output(0, []) + + +class TestDeepCopy: + def test_deep_copy_no_duplicates(self): + inp1 = Input("INP1") + inp2 = Input("INP2") + inp3 = Input("INP3") + add1 = Addition(inp1, inp2, "ADD1") + mul1 = Multiplication(add1, inp3, "MUL1") + out1 = Output(mul1, "OUT1") + + mac_sfg = SFG(inputs = [inp1, inp2], outputs = [out1], name = "mac_sfg") + + mac_sfg_deep_copy = mac_sfg.deep_copy() + + for g_id, component in mac_sfg._components_by_id.items(): + component_copy = mac_sfg_deep_copy.find_by_id(g_id) + assert component.name == component_copy.name + + def test_deep_copy(self): + inp1 = Input("INP1") + inp2 = Input("INP2") + inp3 = Input("INP3") + add1 = Addition(None, None, "ADD1") + add2 = Addition(None, None, "ADD2") + mul1 = Multiplication(None, None, "MUL1") + out1 = Output(None, "OUT1") + + add1.input(0).connect(inp1, "S1") + add1.input(1).connect(inp2, "S2") + add2.input(0).connect(add1, "S4") + add2.input(1).connect(inp3, "S3") + mul1.input(0).connect(add1, "S5") + mul1.input(1).connect(add2, "S6") + out1.input(0).connect(mul1, "S7") + + mac_sfg = SFG(inputs = [inp1, inp2], outputs = [out1], name = "mac_sfg") + + mac_sfg_deep_copy = mac_sfg.deep_copy() + + for g_id, component in mac_sfg._components_by_id.items(): + component_copy = mac_sfg_deep_copy.find_by_id(g_id) + assert component.name == component_copy.name + + +class TestComponents: + def test_advanced_components(self): + inp1 = Input("INP1") + inp2 = Input("INP2") + inp3 = Input("INP3") + add1 = Addition(None, None, "ADD1") + add2 = Addition(None, None, "ADD2") + mul1 = Multiplication(None, None, "MUL1") + out1 = Output(None, "OUT1") + + add1.input(0).connect(inp1, "S1") + add1.input(1).connect(inp2, "S2") + add2.input(0).connect(add1, "S4") + add2.input(1).connect(inp3, "S3") + mul1.input(0).connect(add1, "S5") + mul1.input(1).connect(add2, "S6") + out1.input(0).connect(mul1, "S7") + + mac_sfg = SFG(inputs = [inp1, inp2], outputs = [out1], name = "mac_sfg") + + assert set([comp.name for comp in mac_sfg.components]) == {"INP1", "INP2", "INP3", "ADD1", "ADD2", "MUL1", "OUT1", "S1", "S2", "S3", "S4", "S5", "S6", "S7"}