diff --git a/b_asic/operation.py b/b_asic/operation.py index 6548cb45bfa391ad89e6a1db2af11fdef4827c53..4ad5b7f947f97eb8e896e6220ce9ead71117ce23 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -25,12 +25,7 @@ from typing import ( overload, ) -from b_asic.graph_component import ( - AbstractGraphComponent, - GraphComponent, - GraphID, - Name, -) +from b_asic.graph_component import AbstractGraphComponent, GraphComponent, GraphID, Name from b_asic.port import InputPort, OutputPort, SignalSourceProvider from b_asic.signal import Signal from b_asic.types import Num @@ -403,9 +398,7 @@ class Operation(GraphComponent, SignalSourceProvider): @abstractmethod def get_plot_coordinates( self, - ) -> Tuple[ - Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...] - ]: + ) -> Tuple[Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]]: """ Return a tuple containing coordinates for the two polygons outlining the latency and execution time of the operation. @@ -413,24 +406,6 @@ class Operation(GraphComponent, SignalSourceProvider): """ raise NotImplementedError - @abstractmethod - def get_io_coordinates( - self, - ) -> Tuple[ - Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...] - ]: - """ - Return a tuple containing coordinates for inputs and outputs, respectively. - These maps to the polygons and are corresponding to a start time of 0 - and height 1. - - See also - ======== - get_input_coordinates - get_output_coordinates - """ - raise NotImplementedError - @abstractmethod def get_input_coordinates( self, @@ -442,7 +417,6 @@ class Operation(GraphComponent, SignalSourceProvider): See also ======== - get_io_coordinates get_output_coordinates """ raise NotImplementedError @@ -459,7 +433,6 @@ class Operation(GraphComponent, SignalSourceProvider): See also ======== get_input_coordinates - get_io_coordinates """ raise NotImplementedError @@ -512,9 +485,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): input_count: int, output_count: int, name: Name = Name(""), - input_sources: Optional[ - Sequence[Optional[SignalSourceProvider]] - ] = None, + input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None, latency: Optional[int] = None, latency_offsets: Optional[Dict[str, int]] = None, execution_time: Optional[int] = None, @@ -575,9 +546,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): @overload @abstractmethod - def evaluate( - self, *inputs: Num - ) -> List[Num]: # pylint: disable=arguments-differ + def evaluate(self, *inputs: Num) -> List[Num]: # pylint: disable=arguments-differ ... @abstractmethod @@ -601,34 +570,25 @@ class AbstractOperation(Operation, AbstractGraphComponent): # Import here to avoid circular imports. from b_asic.core_operations import Addition, Constant - return Addition( - Constant(src) if isinstance(src, Number) else src, self - ) + return Addition(Constant(src) if isinstance(src, Number) else src, self) def __sub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction": # Import here to avoid circular imports. from b_asic.core_operations import Constant, Subtraction - return Subtraction( - self, Constant(src) if isinstance(src, Number) else src - ) + return Subtraction(self, Constant(src) if isinstance(src, Number) else src) def __rsub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction": # Import here to avoid circular imports. from b_asic.core_operations import Constant, Subtraction - return Subtraction( - Constant(src) if isinstance(src, Number) else src, self - ) + return Subtraction(Constant(src) if isinstance(src, Number) else src, self) def __mul__( self, src: Union[SignalSourceProvider, Num] ) -> Union["Multiplication", "ConstantMultiplication"]: # Import here to avoid circular imports. - from b_asic.core_operations import ( - ConstantMultiplication, - Multiplication, - ) + from b_asic.core_operations import ConstantMultiplication, Multiplication return ( ConstantMultiplication(src, self) @@ -640,10 +600,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): self, src: Union[SignalSourceProvider, Num] ) -> Union["Multiplication", "ConstantMultiplication"]: # Import here to avoid circular imports. - from b_asic.core_operations import ( - ConstantMultiplication, - Multiplication, - ) + from b_asic.core_operations import ConstantMultiplication, Multiplication return ( ConstantMultiplication(src, self) @@ -655,9 +612,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): # Import here to avoid circular imports. from b_asic.core_operations import Constant, Division - return Division( - self, Constant(src) if isinstance(src, Number) else src - ) + return Division(self, Constant(src) if isinstance(src, Number) else src) def __rtruediv__( self, src: Union[SignalSourceProvider, Num] @@ -835,8 +790,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): self, delays: Optional[DelayMap] = None, prefix: str = "" ) -> Sequence[Optional[Num]]: return [ - self.current_output(i, delays, prefix) - for i in range(self.output_count) + self.current_output(i, delays, prefix) for i in range(self.output_count) ] def evaluate_outputs( @@ -927,9 +881,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): Operations input ports. """ return [ - signal.source.operation - for signal in self.input_signals - if signal.source + signal.source.operation for signal in self.input_signals if signal.source ] @property @@ -1008,10 +960,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): return max( ( - ( - cast(int, output.latency_offset) - - cast(int, input.latency_offset) - ) + (cast(int, output.latency_offset) - cast(int, input.latency_offset)) for output, input in it.product(self.outputs, self.inputs) ) ) @@ -1116,9 +1065,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): def get_plot_coordinates( self, - ) -> Tuple[ - Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...] - ]: + ) -> Tuple[Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]]: # Doc-string inherited return ( self._get_plot_coordinates_for_latency(), @@ -1169,28 +1116,22 @@ class AbstractOperation(Operation, AbstractGraphComponent): def get_input_coordinates(self) -> Tuple[Tuple[float, float], ...]: # doc-string inherited + num_in = self.input_count return tuple( ( self.input_latency_offsets()[k], - (1 + 2 * k) / (2 * len(self.inputs)), + (1 + 2 * k) / (2 * num_in), ) - for k in range(len(self.inputs)) + for k in range(num_in) ) def get_output_coordinates(self) -> Tuple[Tuple[float, float], ...]: # doc-string inherited + num_out = self.output_count return tuple( ( self.output_latency_offsets()[k], - (1 + 2 * k) / (2 * len(self.outputs)), + (1 + 2 * k) / (2 * num_out), ) - for k in range(len(self.outputs)) + for k in range(num_out) ) - - def get_io_coordinates( - self, - ) -> Tuple[ - Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...] - ]: - # Doc-string inherited - return self.get_input_coordinates(), self.get_output_coordinates() diff --git a/b_asic/scheduler_gui/operation_item.py b/b_asic/scheduler_gui/operation_item.py index 244723ba9ca0ee53bed48906db3e55547180d356..e6881b7adfe16bc89a27c7cdbbb8dface32ac13a 100644 --- a/b_asic/scheduler_gui/operation_item.py +++ b/b_asic/scheduler_gui/operation_item.py @@ -219,8 +219,6 @@ class OperationItem(QGraphicsItemGroup): # component item self._set_background(OPERATION_LATENCY_INACTIVE) # used by component filling - inputs, outputs = self._operation.get_io_coordinates() - def create_ports(io_coordinates, prefix): for i, (x, y) in enumerate(io_coordinates): pos = QPointF(x, y * self._height) @@ -235,8 +233,8 @@ class OperationItem(QGraphicsItemGroup): new_port.setPos(port_pos.x(), port_pos.y()) self._port_items.append(new_port) - create_ports(inputs, "in") - create_ports(outputs, "out") + create_ports(self._operation.get_input_coordinates(), "in") + create_ports(self._operation.get_output_coordinates(), "out") # op-id/label self._label_item = QGraphicsSimpleTextItem(self._operation.graph_id) diff --git a/test/test_operation.py b/test/test_operation.py index b73c678069f8b6e1633331af4a9edd8f11e57b19..0ce438b644618a8f1fe8ddc0141cb772a84bb3a2 100644 --- a/test/test_operation.py +++ b/test/test_operation.py @@ -21,8 +21,7 @@ from b_asic import ( class TestOperationOverloading: def test_addition_overload(self): - """Tests addition overloading for both operation and number argument. - """ + """Tests addition overloading for both operation and number argument.""" add1 = Addition(None, None, "add1") add2 = Addition(None, None, "add2") @@ -42,8 +41,7 @@ class TestOperationOverloading: assert add5.input(1).signals == add4.output(0).signals def test_subtraction_overload(self): - """Tests subtraction overloading for both operation and number argument. - """ + """Tests subtraction overloading for both operation and number argument.""" add1 = Addition(None, None, "add1") add2 = Addition(None, None, "add2") @@ -63,8 +61,7 @@ class TestOperationOverloading: assert sub3.input(1).signals == sub2.output(0).signals def test_multiplication_overload(self): - """Tests multiplication overloading for both operation and number argument. - """ + """Tests multiplication overloading for both operation and number argument.""" add1 = Addition(None, None, "add1") add2 = Addition(None, None, "add2") @@ -84,8 +81,7 @@ class TestOperationOverloading: assert mul3.value == 5 def test_division_overload(self): - """Tests division overloading for both operation and number argument. - """ + """Tests division overloading for both operation and number argument.""" add1 = Addition(None, None, "add1") add2 = Addition(None, None, "add2") @@ -125,18 +121,8 @@ class TestTraverse: def test_traverse_type(self, large_operation_tree): result = list(large_operation_tree.traverse()) - assert ( - len( - list(filter(lambda type_: isinstance(type_, Addition), result)) - ) - == 3 - ) - assert ( - len( - list(filter(lambda type_: isinstance(type_, Constant), result)) - ) - == 4 - ) + assert len(list(filter(lambda type_: isinstance(type_, Addition), result))) == 3 + assert len(list(filter(lambda type_: isinstance(type_, Constant), result))) == 4 def test_traverse_loop(self, operation_graph_with_cycle): assert len(list(operation_graph_with_cycle.traverse())) == 8 @@ -184,9 +170,7 @@ class TestLatency: } def test_latency_offsets_constructor(self): - bfly = Butterfly( - latency_offsets={"in0": 2, "in1": 3, "out0": 5, "out1": 10} - ) + bfly = Butterfly(latency_offsets={"in0": 2, "in1": 3, "out0": 5, "out1": 10}) assert bfly.latency == 8 assert bfly.latency_offsets == { @@ -233,17 +217,13 @@ class TestExecutionTime: def test_set_execution_time_negative(self): bfly = Butterfly() - with pytest.raises( - ValueError, match="Execution time cannot be negative" - ): + with pytest.raises(ValueError, match="Execution time cannot be negative"): bfly.execution_time = -1 class TestCopyOperation: def test_copy_butterfly_latency_offsets(self): - bfly = Butterfly( - latency_offsets={"in0": 4, "in1": 2, "out0": 10, "out1": 9} - ) + bfly = Butterfly(latency_offsets={"in0": 4, "in1": 2, "out0": 10, "out1": 9}) bfly_copy = bfly.copy_component() @@ -274,9 +254,7 @@ class TestPlotCoordinates: assert exe == ((0, 0), (0, 1), (1, 1), (1, 0), (0, 0)) def test_complicated_case(self): - bfly = Butterfly( - latency_offsets={"in0": 2, "in1": 3, "out0": 5, "out1": 10} - ) + bfly = Butterfly(latency_offsets={"in0": 2, "in1": 3, "out0": 5, "out1": 10}) bfly.execution_time = 7 lat, exe = bfly.get_plot_coordinates() @@ -300,28 +278,22 @@ class TestIOCoordinates: cmult.execution_time = 1 cmult.set_latency(3) - i_c, o_c = cmult.get_io_coordinates() - assert i_c == ((0, 0.5),) - assert o_c == ((3, 0.5),) + assert cmult.get_input_coordinates() == ((0, 0.5),) + assert cmult.get_output_coordinates() == ((3, 0.5),) def test_complicated_case(self): - bfly = Butterfly( - latency_offsets={"in0": 2, "in1": 3, "out0": 5, "out1": 10} - ) + bfly = Butterfly(latency_offsets={"in0": 2, "in1": 3, "out0": 5, "out1": 10}) bfly.execution_time = 7 - i_c, o_c = bfly.get_io_coordinates() - assert i_c == ((2, 0.25), (3, 0.75)) - assert o_c == ((5, 0.25), (10, 0.75)) + assert bfly.get_input_coordinates() == ((2, 0.25), (3, 0.75)) + assert bfly.get_output_coordinates() == ((5, 0.25), (10, 0.75)) def test_io_coordinates_error(self): bfly = Butterfly() bfly.set_latency_offsets({"in0": 3, "out1": 5}) - with pytest.raises( - ValueError, match="Missing latencies for inputs \\[1\\]" - ): - bfly.get_io_coordinates() + with pytest.raises(ValueError, match="Missing latencies for inputs \\[1\\]"): + bfly.get_input_coordinates() class TestSplit: