diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index f266afc6c2a43b4077a973fd2fa8e9790e2273e9..00f07a9592ea7a8242215c1859887d8600f3dbd8 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -109,6 +109,7 @@ class Addition(AbstractOperation): """ is_linear = True + is_swappable = True def __init__( self, @@ -274,7 +275,7 @@ class AddSub(AbstractOperation): return a + b if self.is_add else a - b @property - def is_add(self) -> Num: + def is_add(self) -> bool: """Get if operation is an addition.""" return self.param("is_add") @@ -283,6 +284,10 @@ class AddSub(AbstractOperation): """Set if operation is an addition.""" self.set_param("is_add", is_add) + @property + def is_swappable(self) -> bool: + return self.is_add + class Multiplication(AbstractOperation): r""" @@ -316,6 +321,7 @@ class Multiplication(AbstractOperation): ConstantMultiplication """ + is_swappable = True def __init__( self, @@ -410,6 +416,7 @@ class Min(AbstractOperation): ======== Max """ + is_swappable = True def __init__( self, @@ -455,6 +462,7 @@ class Max(AbstractOperation): ======== Min """ + is_swappable = True def __init__( self, @@ -695,6 +703,7 @@ class MAD(AbstractOperation): .. math:: y = x_0 \times x_1 + x_2 """ + is_swappable = True def __init__( self, @@ -731,6 +740,15 @@ class MAD(AbstractOperation): or self.input(1).connected_source.operation.is_constant ) + def swap_io(self) -> None: + self._input_ports = [ + self._input_ports[1], + self._input_ports[0], + self._input_ports[2], + ] + for i, p in enumerate(self._input_ports): + p._index = i + class SymmetricTwoportAdaptor(AbstractOperation): r""" @@ -743,6 +761,7 @@ class SymmetricTwoportAdaptor(AbstractOperation): \end{eqnarray} """ is_linear = True + is_swappable = True def __init__( self, @@ -784,6 +803,16 @@ class SymmetricTwoportAdaptor(AbstractOperation): """Set the constant value of this operation.""" self.set_param("value", value) + def swap_io(self) -> None: + # Swap inputs and outputs and change sign of coefficient + self._input_ports.reverse() + for i, p in enumerate(self._input_ports): + p._index = i + self._output_ports.reverse() + for i, p in enumerate(self._output_ports): + p._index = i + self.set_param("value", -self.value) + class Reciprocal(AbstractOperation): r""" diff --git a/b_asic/operation.py b/b_asic/operation.py index ccff1c5eb0551b321f83fff8da73f911a2e27d63..982a2b318fa75fc5a36efa329f8eea8c8a09d58a 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -481,6 +481,24 @@ class Operation(GraphComponent, SignalSourceProvider): """ raise NotImplementedError + @property + @abstractmethod + def is_swappable(self) -> bool: + """ + Return True if the inputs (and outputs) to the operation can be swapped and + retain the same function. + """ + raise NotImplementedError + + @abstractmethod + def swap_io(self) -> None: + """ + Swap inputs (and outputs) of operation. + + Errors if :meth:`is_swappable` is False. + """ + raise NotImplementedError + class AbstractOperation(Operation, AbstractGraphComponent): """ @@ -1150,12 +1168,28 @@ class AbstractOperation(Operation, AbstractGraphComponent): @property def is_linear(self) -> bool: + # doc-string inherited if self.is_constant: return True return False @property def is_constant(self) -> bool: + # doc-string inherited return all( input_.connected_source.operation.is_constant for input_ in self.inputs ) + + @property + def is_swappable(self) -> bool: + # doc-string inherited + return False + + def swap_io(self) -> None: + # doc-string inherited + if not self.is_swappable: + raise TypeError(f"operation io cannot be swapped for {type(self)}") + if self.input_count == 2 and self.output_count == 1: + self._input_ports.reverse() + for i, p in enumerate(self._input_ports): + p._index = i diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 5fe7ef97e61fb536ba9ff84090be53efdb2e75ec..af6dd758c26008859cbb2606acbda9a4363fd9d9 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -748,6 +748,20 @@ class SFG(AbstractOperation): Signal(output_port, new_operation) signal.set_source(new_operation) + def swap_io_of_operation(self, operation_id: GraphID) -> None: + """ + Swap the inputs (and outputs) of operation. + + Parameters + ---------- + operation_id : GraphID + The GraphID of the operation to swap. + + """ + operation = cast(Operation, self.find_by_id(operation_id)) + if operation is not None: + operation.swap_io() + def remove_operation(self, operation_id: GraphID) -> Union["SFG", None]: """ Returns a version of the SFG where the operation with the specified GraphID @@ -1399,7 +1413,8 @@ class SFG(AbstractOperation): branch_node : bool, default: False Add a branch node in case the fan-out of a signal is two or more. port_numbering : bool, default: True - Show the port number in case the number of ports (input or output) is two or more. + Show the port number in case the number of ports (input or output) is two or + more. splines : {"spline", "line", "ortho", "polyline", "curved"}, default: "spline" Spline style, see https://graphviz.org/docs/attrs/splines/ for more info. @@ -1503,7 +1518,8 @@ class SFG(AbstractOperation): branch_node : bool, default: False Add a branch node in case the fan-out of a signal is two or more. port_numbering : bool, default: True - Show the port number in case the number of ports (input or output) is two or more. + Show the port number in case the number of ports (input or output) is two or + more. splines : {"spline", "line", "ortho", "polyline", "curved"}, default: "spline" Spline style, see https://graphviz.org/docs/attrs/splines/ for more info. @@ -1599,8 +1615,8 @@ class SFG(AbstractOperation): new_source_op = new_ops[layer][source_op_idx] source_op_output = new_source_op.outputs[source_op_output_index] - # If this is the last layer, we need to create a new delay element and connect it instead - # of the copied port + # If this is the last layer, we need to create a new delay element + # and connect it instead of the copied port if layer == factor - 1: delay = Delay(name=op.name) delay.graph_id = op.graph_id @@ -1630,14 +1646,14 @@ class SFG(AbstractOperation): new_destination = new_dest_op.inputs[sink_op_output_index] new_destination.connect(new_source_port) else: - # Other opreations need to be re-targeted to the corresponding output in the - # current layer, as long as that output is not a delay, as that has been solved - # above. + # Other opreations need to be re-targeted to the corresponding + # output in the current layer, as long as that output is not a + # delay, as that has been solved above. # To avoid double connections, we'll only re-connect inputs for input_num, original_input in enumerate(op.inputs): original_source = original_input.connected_source - # We may not always have something connected to the input, if we don't - # we can abort + # We may not always have something connected to the input, if we + # don't we can abort if original_source is None: continue diff --git a/test/test_core_operations.py b/test/test_core_operations.py index fe9e770c17c366ce131cfea04ba00f5448852c4b..b53d588ee38f553d0181d7eb3b3c209769e01a80 100644 --- a/test/test_core_operations.py +++ b/test/test_core_operations.py @@ -100,6 +100,13 @@ class TestAddSub: test_operation = AddSub(is_add=False) assert test_operation.evaluate_output(0, [3 + 5j, 4 + 6j]) == -1 - 1j + def test_addsub_subtraction_is_swappable(self): + test_operation = AddSub(is_add=False) + assert not test_operation.is_swappable + + test_operation = AddSub(is_add=True) + assert test_operation.is_swappable + class TestMultiplication: """Tests for Multiplication class.""" @@ -130,10 +137,7 @@ class TestDivision: def test_division_complex(self): test_operation = Division() - assert ( - test_operation.evaluate_output(0, [60 + 40j, 10 + 20j]) - == 2.8 - 1.6j - ) + assert test_operation.evaluate_output(0, [60 + 40j, 10 + 20j]) == 2.8 - 1.6j class TestSquareRoot: @@ -254,12 +258,14 @@ class TestSymmetricTwoportAdaptor: def test_symmetrictwoportadaptor_complex(self): test_operation = SymmetricTwoportAdaptor(0.5) - assert ( - test_operation.evaluate_output(0, [2 + 1j, 3 - 2j]) == 3.5 - 3.5j - ) - assert ( - test_operation.evaluate_output(1, [2 + 1j, 3 - 2j]) == 2.5 - 0.5j - ) + assert test_operation.evaluate_output(0, [2 + 1j, 3 - 2j]) == 3.5 - 3.5j + assert test_operation.evaluate_output(1, [2 + 1j, 3 - 2j]) == 2.5 - 0.5j + + def test_symmetrictwoportadaptor_swap_io(self): + test_operation = SymmetricTwoportAdaptor(0.5) + assert test_operation.value == 0.5 + test_operation.swap_io() + assert test_operation.value == -0.5 class TestReciprocal: diff --git a/test/test_operation.py b/test/test_operation.py index 41cc64a34ec706077606d1c49d0ee6c686c36a5a..6b16a148405a3ffc32401b3e95c3184937926661 100644 --- a/test/test_operation.py +++ b/test/test_operation.py @@ -358,3 +358,11 @@ class TestLatencyOffset: ), ): bfly.set_latency_offsets({"foo": 3, "out2": 5}) + + +class TestIsSwappable: + def test_butterfly_is_swappable(self): + bfly = Butterfly() + assert not bfly.is_swappable + with pytest.raises(TypeError, match="operation io cannot be swapped"): + bfly.swap_io() diff --git a/test/test_sfg.py b/test/test_sfg.py index 3d402347e7174eedd67f87cfce299dbb44ab60d5..d88c62894b5e0a777c6522d25f5b00d43e805b4c 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -9,7 +9,7 @@ from typing import Counter, Dict, Type import pytest -from b_asic import SFG, Input, Output, Signal +from b_asic import Input, Output, Signal from b_asic.core_operations import ( Addition, Butterfly, @@ -22,6 +22,7 @@ from b_asic.core_operations import ( ) from b_asic.operation import ResultKey from b_asic.save_load_structure import python_to_sfg, sfg_to_python +from b_asic.signal_flow_graph import SFG, GraphID from b_asic.simulation import Simulation from b_asic.special_operations import Delay @@ -1564,6 +1565,30 @@ class TestIsConstant: assert not sfg_nested.is_constant +class TestSwapIOOfOperation: + def do_test(self, sfg: SFG, graph_id: GraphID): + NUM_TESTS = 5 + # Evaluate with some random values + # To avoid problems with missing inputs at the end of the sequence, + # we generate i*(some large enough) number + input_list = [ + [random.random() for _ in range(0, NUM_TESTS)] for _ in sfg.inputs + ] + sim_ref = Simulation(sfg, input_list) + sim_ref.run() + + sfg.swap_io_of_operation(graph_id) + sim_swap = Simulation(sfg, input_list) + sim_swap.run() + for n, _ in enumerate(sfg.outputs): + ref_values = list(sim_ref.results[ResultKey(f"{n}")]) + swap_values = list(sim_swap.results[ResultKey(f"{n}")]) + assert ref_values == swap_values + + def test_single_accumulator(self, sfg_simple_accumulator: SFG): + self.do_test(sfg_simple_accumulator, 'add1') + + class TestInsertComponentAfter: def test_insert_component_after_in_sfg(self, large_operation_tree_names): sfg = SFG(outputs=[Output(large_operation_tree_names)])