Skip to content
Snippets Groups Projects
Commit f929ba05 authored by Oscar Gustafsson's avatar Oscar Gustafsson :bicyclist:
Browse files

Remove get_io_coordinates

parent fa15c7a4
No related branches found
No related tags found
1 merge request!212Remove get_io_coordinates
Pipeline #90037 passed
...@@ -25,12 +25,7 @@ from typing import ( ...@@ -25,12 +25,7 @@ from typing import (
overload, overload,
) )
from b_asic.graph_component import ( from b_asic.graph_component import AbstractGraphComponent, GraphComponent, GraphID, Name
AbstractGraphComponent,
GraphComponent,
GraphID,
Name,
)
from b_asic.port import InputPort, OutputPort, SignalSourceProvider from b_asic.port import InputPort, OutputPort, SignalSourceProvider
from b_asic.signal import Signal from b_asic.signal import Signal
from b_asic.types import Num from b_asic.types import Num
...@@ -403,9 +398,7 @@ class Operation(GraphComponent, SignalSourceProvider): ...@@ -403,9 +398,7 @@ class Operation(GraphComponent, SignalSourceProvider):
@abstractmethod @abstractmethod
def get_plot_coordinates( def get_plot_coordinates(
self, self,
) -> Tuple[ ) -> Tuple[Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]]:
Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]
]:
""" """
Return a tuple containing coordinates for the two polygons outlining Return a tuple containing coordinates for the two polygons outlining
the latency and execution time of the operation. the latency and execution time of the operation.
...@@ -413,24 +406,6 @@ class Operation(GraphComponent, SignalSourceProvider): ...@@ -413,24 +406,6 @@ class Operation(GraphComponent, SignalSourceProvider):
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
def get_io_coordinates(
self,
) -> Tuple[
Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]
]:
"""
Return a tuple containing coordinates for inputs and outputs, respectively.
These maps to the polygons and are corresponding to a start time of 0
and height 1.
See also
========
get_input_coordinates
get_output_coordinates
"""
raise NotImplementedError
@abstractmethod @abstractmethod
def get_input_coordinates( def get_input_coordinates(
self, self,
...@@ -442,7 +417,6 @@ class Operation(GraphComponent, SignalSourceProvider): ...@@ -442,7 +417,6 @@ class Operation(GraphComponent, SignalSourceProvider):
See also See also
======== ========
get_io_coordinates
get_output_coordinates get_output_coordinates
""" """
raise NotImplementedError raise NotImplementedError
...@@ -459,7 +433,6 @@ class Operation(GraphComponent, SignalSourceProvider): ...@@ -459,7 +433,6 @@ class Operation(GraphComponent, SignalSourceProvider):
See also See also
======== ========
get_input_coordinates get_input_coordinates
get_io_coordinates
""" """
raise NotImplementedError raise NotImplementedError
...@@ -512,9 +485,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -512,9 +485,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
input_count: int, input_count: int,
output_count: int, output_count: int,
name: Name = Name(""), name: Name = Name(""),
input_sources: Optional[ input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None,
Sequence[Optional[SignalSourceProvider]]
] = None,
latency: Optional[int] = None, latency: Optional[int] = None,
latency_offsets: Optional[Dict[str, int]] = None, latency_offsets: Optional[Dict[str, int]] = None,
execution_time: Optional[int] = None, execution_time: Optional[int] = None,
...@@ -575,9 +546,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -575,9 +546,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
@overload @overload
@abstractmethod @abstractmethod
def evaluate( def evaluate(self, *inputs: Num) -> List[Num]: # pylint: disable=arguments-differ
self, *inputs: Num
) -> List[Num]: # pylint: disable=arguments-differ
... ...
@abstractmethod @abstractmethod
...@@ -601,34 +570,25 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -601,34 +570,25 @@ class AbstractOperation(Operation, AbstractGraphComponent):
# Import here to avoid circular imports. # Import here to avoid circular imports.
from b_asic.core_operations import Addition, Constant from b_asic.core_operations import Addition, Constant
return Addition( return Addition(Constant(src) if isinstance(src, Number) else src, self)
Constant(src) if isinstance(src, Number) else src, self
)
def __sub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction": def __sub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction":
# Import here to avoid circular imports. # Import here to avoid circular imports.
from b_asic.core_operations import Constant, Subtraction from b_asic.core_operations import Constant, Subtraction
return Subtraction( return Subtraction(self, Constant(src) if isinstance(src, Number) else src)
self, Constant(src) if isinstance(src, Number) else src
)
def __rsub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction": def __rsub__(self, src: Union[SignalSourceProvider, Num]) -> "Subtraction":
# Import here to avoid circular imports. # Import here to avoid circular imports.
from b_asic.core_operations import Constant, Subtraction from b_asic.core_operations import Constant, Subtraction
return Subtraction( return Subtraction(Constant(src) if isinstance(src, Number) else src, self)
Constant(src) if isinstance(src, Number) else src, self
)
def __mul__( def __mul__(
self, src: Union[SignalSourceProvider, Num] self, src: Union[SignalSourceProvider, Num]
) -> Union["Multiplication", "ConstantMultiplication"]: ) -> Union["Multiplication", "ConstantMultiplication"]:
# Import here to avoid circular imports. # Import here to avoid circular imports.
from b_asic.core_operations import ( from b_asic.core_operations import ConstantMultiplication, Multiplication
ConstantMultiplication,
Multiplication,
)
return ( return (
ConstantMultiplication(src, self) ConstantMultiplication(src, self)
...@@ -640,10 +600,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -640,10 +600,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
self, src: Union[SignalSourceProvider, Num] self, src: Union[SignalSourceProvider, Num]
) -> Union["Multiplication", "ConstantMultiplication"]: ) -> Union["Multiplication", "ConstantMultiplication"]:
# Import here to avoid circular imports. # Import here to avoid circular imports.
from b_asic.core_operations import ( from b_asic.core_operations import ConstantMultiplication, Multiplication
ConstantMultiplication,
Multiplication,
)
return ( return (
ConstantMultiplication(src, self) ConstantMultiplication(src, self)
...@@ -655,9 +612,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -655,9 +612,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
# Import here to avoid circular imports. # Import here to avoid circular imports.
from b_asic.core_operations import Constant, Division from b_asic.core_operations import Constant, Division
return Division( return Division(self, Constant(src) if isinstance(src, Number) else src)
self, Constant(src) if isinstance(src, Number) else src
)
def __rtruediv__( def __rtruediv__(
self, src: Union[SignalSourceProvider, Num] self, src: Union[SignalSourceProvider, Num]
...@@ -835,8 +790,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -835,8 +790,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
self, delays: Optional[DelayMap] = None, prefix: str = "" self, delays: Optional[DelayMap] = None, prefix: str = ""
) -> Sequence[Optional[Num]]: ) -> Sequence[Optional[Num]]:
return [ return [
self.current_output(i, delays, prefix) self.current_output(i, delays, prefix) for i in range(self.output_count)
for i in range(self.output_count)
] ]
def evaluate_outputs( def evaluate_outputs(
...@@ -927,9 +881,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -927,9 +881,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
Operations input ports. Operations input ports.
""" """
return [ return [
signal.source.operation signal.source.operation for signal in self.input_signals if signal.source
for signal in self.input_signals
if signal.source
] ]
@property @property
...@@ -1008,10 +960,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -1008,10 +960,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
return max( return max(
( (
( (cast(int, output.latency_offset) - cast(int, input.latency_offset))
cast(int, output.latency_offset)
- cast(int, input.latency_offset)
)
for output, input in it.product(self.outputs, self.inputs) for output, input in it.product(self.outputs, self.inputs)
) )
) )
...@@ -1116,9 +1065,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -1116,9 +1065,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
def get_plot_coordinates( def get_plot_coordinates(
self, self,
) -> Tuple[ ) -> Tuple[Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]]:
Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]
]:
# Doc-string inherited # Doc-string inherited
return ( return (
self._get_plot_coordinates_for_latency(), self._get_plot_coordinates_for_latency(),
...@@ -1169,28 +1116,22 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -1169,28 +1116,22 @@ class AbstractOperation(Operation, AbstractGraphComponent):
def get_input_coordinates(self) -> Tuple[Tuple[float, float], ...]: def get_input_coordinates(self) -> Tuple[Tuple[float, float], ...]:
# doc-string inherited # doc-string inherited
num_in = self.input_count
return tuple( return tuple(
( (
self.input_latency_offsets()[k], self.input_latency_offsets()[k],
(1 + 2 * k) / (2 * len(self.inputs)), (1 + 2 * k) / (2 * num_in),
) )
for k in range(len(self.inputs)) for k in range(num_in)
) )
def get_output_coordinates(self) -> Tuple[Tuple[float, float], ...]: def get_output_coordinates(self) -> Tuple[Tuple[float, float], ...]:
# doc-string inherited # doc-string inherited
num_out = self.output_count
return tuple( return tuple(
( (
self.output_latency_offsets()[k], self.output_latency_offsets()[k],
(1 + 2 * k) / (2 * len(self.outputs)), (1 + 2 * k) / (2 * num_out),
) )
for k in range(len(self.outputs)) for k in range(num_out)
) )
def get_io_coordinates(
self,
) -> Tuple[
Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]
]:
# Doc-string inherited
return self.get_input_coordinates(), self.get_output_coordinates()
...@@ -219,8 +219,6 @@ class OperationItem(QGraphicsItemGroup): ...@@ -219,8 +219,6 @@ class OperationItem(QGraphicsItemGroup):
# component item # component item
self._set_background(OPERATION_LATENCY_INACTIVE) # used by component filling self._set_background(OPERATION_LATENCY_INACTIVE) # used by component filling
inputs, outputs = self._operation.get_io_coordinates()
def create_ports(io_coordinates, prefix): def create_ports(io_coordinates, prefix):
for i, (x, y) in enumerate(io_coordinates): for i, (x, y) in enumerate(io_coordinates):
pos = QPointF(x, y * self._height) pos = QPointF(x, y * self._height)
...@@ -235,8 +233,8 @@ class OperationItem(QGraphicsItemGroup): ...@@ -235,8 +233,8 @@ class OperationItem(QGraphicsItemGroup):
new_port.setPos(port_pos.x(), port_pos.y()) new_port.setPos(port_pos.x(), port_pos.y())
self._port_items.append(new_port) self._port_items.append(new_port)
create_ports(inputs, "in") create_ports(self._operation.get_input_coordinates(), "in")
create_ports(outputs, "out") create_ports(self._operation.get_output_coordinates(), "out")
# op-id/label # op-id/label
self._label_item = QGraphicsSimpleTextItem(self._operation.graph_id) self._label_item = QGraphicsSimpleTextItem(self._operation.graph_id)
......
...@@ -21,8 +21,7 @@ from b_asic import ( ...@@ -21,8 +21,7 @@ from b_asic import (
class TestOperationOverloading: class TestOperationOverloading:
def test_addition_overload(self): def test_addition_overload(self):
"""Tests addition overloading for both operation and number argument. """Tests addition overloading for both operation and number argument."""
"""
add1 = Addition(None, None, "add1") add1 = Addition(None, None, "add1")
add2 = Addition(None, None, "add2") add2 = Addition(None, None, "add2")
...@@ -42,8 +41,7 @@ class TestOperationOverloading: ...@@ -42,8 +41,7 @@ class TestOperationOverloading:
assert add5.input(1).signals == add4.output(0).signals assert add5.input(1).signals == add4.output(0).signals
def test_subtraction_overload(self): def test_subtraction_overload(self):
"""Tests subtraction overloading for both operation and number argument. """Tests subtraction overloading for both operation and number argument."""
"""
add1 = Addition(None, None, "add1") add1 = Addition(None, None, "add1")
add2 = Addition(None, None, "add2") add2 = Addition(None, None, "add2")
...@@ -63,8 +61,7 @@ class TestOperationOverloading: ...@@ -63,8 +61,7 @@ class TestOperationOverloading:
assert sub3.input(1).signals == sub2.output(0).signals assert sub3.input(1).signals == sub2.output(0).signals
def test_multiplication_overload(self): def test_multiplication_overload(self):
"""Tests multiplication overloading for both operation and number argument. """Tests multiplication overloading for both operation and number argument."""
"""
add1 = Addition(None, None, "add1") add1 = Addition(None, None, "add1")
add2 = Addition(None, None, "add2") add2 = Addition(None, None, "add2")
...@@ -84,8 +81,7 @@ class TestOperationOverloading: ...@@ -84,8 +81,7 @@ class TestOperationOverloading:
assert mul3.value == 5 assert mul3.value == 5
def test_division_overload(self): def test_division_overload(self):
"""Tests division overloading for both operation and number argument. """Tests division overloading for both operation and number argument."""
"""
add1 = Addition(None, None, "add1") add1 = Addition(None, None, "add1")
add2 = Addition(None, None, "add2") add2 = Addition(None, None, "add2")
...@@ -125,18 +121,8 @@ class TestTraverse: ...@@ -125,18 +121,8 @@ class TestTraverse:
def test_traverse_type(self, large_operation_tree): def test_traverse_type(self, large_operation_tree):
result = list(large_operation_tree.traverse()) result = list(large_operation_tree.traverse())
assert ( assert len(list(filter(lambda type_: isinstance(type_, Addition), result))) == 3
len( assert len(list(filter(lambda type_: isinstance(type_, Constant), result))) == 4
list(filter(lambda type_: isinstance(type_, Addition), result))
)
== 3
)
assert (
len(
list(filter(lambda type_: isinstance(type_, Constant), result))
)
== 4
)
def test_traverse_loop(self, operation_graph_with_cycle): def test_traverse_loop(self, operation_graph_with_cycle):
assert len(list(operation_graph_with_cycle.traverse())) == 8 assert len(list(operation_graph_with_cycle.traverse())) == 8
...@@ -184,9 +170,7 @@ class TestLatency: ...@@ -184,9 +170,7 @@ class TestLatency:
} }
def test_latency_offsets_constructor(self): def test_latency_offsets_constructor(self):
bfly = Butterfly( bfly = Butterfly(latency_offsets={"in0": 2, "in1": 3, "out0": 5, "out1": 10})
latency_offsets={"in0": 2, "in1": 3, "out0": 5, "out1": 10}
)
assert bfly.latency == 8 assert bfly.latency == 8
assert bfly.latency_offsets == { assert bfly.latency_offsets == {
...@@ -233,17 +217,13 @@ class TestExecutionTime: ...@@ -233,17 +217,13 @@ class TestExecutionTime:
def test_set_execution_time_negative(self): def test_set_execution_time_negative(self):
bfly = Butterfly() bfly = Butterfly()
with pytest.raises( with pytest.raises(ValueError, match="Execution time cannot be negative"):
ValueError, match="Execution time cannot be negative"
):
bfly.execution_time = -1 bfly.execution_time = -1
class TestCopyOperation: class TestCopyOperation:
def test_copy_butterfly_latency_offsets(self): def test_copy_butterfly_latency_offsets(self):
bfly = Butterfly( bfly = Butterfly(latency_offsets={"in0": 4, "in1": 2, "out0": 10, "out1": 9})
latency_offsets={"in0": 4, "in1": 2, "out0": 10, "out1": 9}
)
bfly_copy = bfly.copy_component() bfly_copy = bfly.copy_component()
...@@ -274,9 +254,7 @@ class TestPlotCoordinates: ...@@ -274,9 +254,7 @@ class TestPlotCoordinates:
assert exe == ((0, 0), (0, 1), (1, 1), (1, 0), (0, 0)) assert exe == ((0, 0), (0, 1), (1, 1), (1, 0), (0, 0))
def test_complicated_case(self): def test_complicated_case(self):
bfly = Butterfly( bfly = Butterfly(latency_offsets={"in0": 2, "in1": 3, "out0": 5, "out1": 10})
latency_offsets={"in0": 2, "in1": 3, "out0": 5, "out1": 10}
)
bfly.execution_time = 7 bfly.execution_time = 7
lat, exe = bfly.get_plot_coordinates() lat, exe = bfly.get_plot_coordinates()
...@@ -300,28 +278,22 @@ class TestIOCoordinates: ...@@ -300,28 +278,22 @@ class TestIOCoordinates:
cmult.execution_time = 1 cmult.execution_time = 1
cmult.set_latency(3) cmult.set_latency(3)
i_c, o_c = cmult.get_io_coordinates() assert cmult.get_input_coordinates() == ((0, 0.5),)
assert i_c == ((0, 0.5),) assert cmult.get_output_coordinates() == ((3, 0.5),)
assert o_c == ((3, 0.5),)
def test_complicated_case(self): def test_complicated_case(self):
bfly = Butterfly( bfly = Butterfly(latency_offsets={"in0": 2, "in1": 3, "out0": 5, "out1": 10})
latency_offsets={"in0": 2, "in1": 3, "out0": 5, "out1": 10}
)
bfly.execution_time = 7 bfly.execution_time = 7
i_c, o_c = bfly.get_io_coordinates() assert bfly.get_input_coordinates() == ((2, 0.25), (3, 0.75))
assert i_c == ((2, 0.25), (3, 0.75)) assert bfly.get_output_coordinates() == ((5, 0.25), (10, 0.75))
assert o_c == ((5, 0.25), (10, 0.75))
def test_io_coordinates_error(self): def test_io_coordinates_error(self):
bfly = Butterfly() bfly = Butterfly()
bfly.set_latency_offsets({"in0": 3, "out1": 5}) bfly.set_latency_offsets({"in0": 3, "out1": 5})
with pytest.raises( with pytest.raises(ValueError, match="Missing latencies for inputs \\[1\\]"):
ValueError, match="Missing latencies for inputs \\[1\\]" bfly.get_input_coordinates()
):
bfly.get_io_coordinates()
class TestSplit: class TestSplit:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment