Skip to content
Snippets Groups Projects
Commit 6e9b10ad authored by Oscar Gustafsson's avatar Oscar Gustafsson :bicyclist:
Browse files

Add Reciprocal operation

parent a43ef562
No related branches found
No related tags found
1 merge request!163Add Reciprocal operation
Pipeline #88832 passed
This commit is part of merge request !163. Comments created here will be created in the context of that merge request.
...@@ -614,8 +614,8 @@ class Butterfly(AbstractOperation): ...@@ -614,8 +614,8 @@ class Butterfly(AbstractOperation):
.. math:: .. math::
\begin{eqnarray} \begin{eqnarray}
y_0 = x_0 + x_1\\ y_0 & = & x_0 + x_1\\
y_1 = x_0 - x_1 y_1 & = & x_0 - x_1
\end{eqnarray} \end{eqnarray}
""" """
...@@ -692,8 +692,8 @@ class SymmetricTwoportAdaptor(AbstractOperation): ...@@ -692,8 +692,8 @@ class SymmetricTwoportAdaptor(AbstractOperation):
.. math:: .. math::
\begin{eqnarray} \begin{eqnarray}
y_0 = x_1 + \text{value}\times\left(x_1 - x_0\right)\\ y_0 & = & x_1 + \text{value}\times\left(x_1 - x_0\right)\\
y_1 = x_0 + \text{value}\times\left(x_1 - x_0\right) y_1 & = & x_0 + \text{value}\times\left(x_1 - x_0\right)
\end{eqnarray} \end{eqnarray}
""" """
...@@ -736,3 +736,39 @@ class SymmetricTwoportAdaptor(AbstractOperation): ...@@ -736,3 +736,39 @@ class SymmetricTwoportAdaptor(AbstractOperation):
def value(self, value: Number) -> None: def value(self, value: Number) -> None:
"""Set the constant value of this operation.""" """Set the constant value of this operation."""
self.set_param("value", value) self.set_param("value", value)
class Reciprocal(AbstractOperation):
r"""
Reciprocal operation.
Gives the reciprocal of its input.
.. math:: y = \frac{1}{x}
"""
def __init__(
self,
src0: Optional[SignalSourceProvider] = None,
name: Name = Name(""),
latency: Optional[int] = None,
latency_offsets: Optional[Dict[str, int]] = None,
execution_time: Optional[int] = None,
):
"""Construct an Reciprocal operation."""
super().__init__(
input_count=1,
output_count=1,
name=Name(name),
input_sources=[src0],
latency=latency,
latency_offsets=latency_offsets,
execution_time=execution_time,
)
@classmethod
def type_name(cls) -> TypeName:
return TypeName("rec")
def evaluate(self, a):
return 1 / a
...@@ -40,6 +40,7 @@ if TYPE_CHECKING: ...@@ -40,6 +40,7 @@ if TYPE_CHECKING:
ConstantMultiplication, ConstantMultiplication,
Division, Division,
Multiplication, Multiplication,
Reciprocal,
Subtraction, Subtraction,
) )
from b_asic.signal_flow_graph import SFG from b_asic.signal_flow_graph import SFG
...@@ -135,7 +136,7 @@ class Operation(GraphComponent, SignalSourceProvider): ...@@ -135,7 +136,7 @@ class Operation(GraphComponent, SignalSourceProvider):
@abstractmethod @abstractmethod
def __rtruediv__( def __rtruediv__(
self, src: Union[SignalSourceProvider, Number] self, src: Union[SignalSourceProvider, Number]
) -> "Division": ) -> Union["Division", "Reciprocal"]:
""" """
Overloads the division operator to make it return a new Division operation Overloads the division operator to make it return a new Division operation
object that is connected to the self and other objects. object that is connected to the self and other objects.
...@@ -387,7 +388,7 @@ class Operation(GraphComponent, SignalSourceProvider): ...@@ -387,7 +388,7 @@ class Operation(GraphComponent, SignalSourceProvider):
self, self,
) -> Tuple[List[List[float]], List[List[float]]]: ) -> Tuple[List[List[float]], List[List[float]]]:
""" """
Get a tuple constaining coordinates for the two polygons outlining Return a tuple containing coordinates for the two polygons outlining
the latency and execution time of the operation. the latency and execution time of the operation.
The polygons are corresponding to a start time of 0 and are of height 1. The polygons are corresponding to a start time of 0 and are of height 1.
""" """
...@@ -398,7 +399,7 @@ class Operation(GraphComponent, SignalSourceProvider): ...@@ -398,7 +399,7 @@ class Operation(GraphComponent, SignalSourceProvider):
self, self,
) -> Tuple[List[List[float]], List[List[float]]]: ) -> Tuple[List[List[float]], List[List[float]]]:
""" """
Get a tuple constaining coordinates for inputs and outputs, respectively. Return a tuple containing coordinates for inputs and outputs, respectively.
These maps to the polygons and are corresponding to a start time of 0 These maps to the polygons and are corresponding to a start time of 0
and height 1. and height 1.
""" """
...@@ -500,9 +501,9 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -500,9 +501,9 @@ class AbstractOperation(Operation, AbstractGraphComponent):
for inp in self.inputs: for inp in self.inputs:
if inp.latency_offset is None: if inp.latency_offset is None:
inp.latency_offset = 0 inp.latency_offset = 0
for outp in self.outputs: for output in self.outputs:
if outp.latency_offset is None: if output.latency_offset is None:
outp.latency_offset = latency output.latency_offset = latency
self._execution_time = execution_time self._execution_time = execution_time
...@@ -592,13 +593,16 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -592,13 +593,16 @@ class AbstractOperation(Operation, AbstractGraphComponent):
def __rtruediv__( def __rtruediv__(
self, src: Union[SignalSourceProvider, Number] self, src: Union[SignalSourceProvider, Number]
) -> "Division": ) -> Union["Division", "Reciprocal"]:
# Import here to avoid circular imports. # Import here to avoid circular imports.
from b_asic.core_operations import Constant, Division from b_asic.core_operations import Constant, Division, Reciprocal
return Division( if isinstance(src, Number):
Constant(src) if isinstance(src, Number) else src, self if src == 1:
) return Reciprocal(self)
else:
return Division(Constant(src), self)
return Division(src, self)
def __lshift__(self, src: SignalSourceProvider) -> Signal: def __lshift__(self, src: SignalSourceProvider) -> Signal:
if self.input_count != 1: if self.input_count != 1:
...@@ -835,10 +839,10 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -835,10 +839,10 @@ class AbstractOperation(Operation, AbstractGraphComponent):
new_component: Operation = cast( new_component: Operation = cast(
Operation, super().copy_component(*args, **kwargs) Operation, super().copy_component(*args, **kwargs)
) )
for i, inp in enumerate(self.inputs): for i, input in enumerate(self.inputs):
new_component.input(i).latency_offset = inp.latency_offset new_component.input(i).latency_offset = input.latency_offset
for i, outp in enumerate(self.outputs): for i, output in enumerate(self.outputs):
new_component.output(i).latency_offset = outp.latency_offset new_component.output(i).latency_offset = output.latency_offset
new_component.execution_time = self._execution_time new_component.execution_time = self._execution_time
return new_component return new_component
...@@ -930,7 +934,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -930,7 +934,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
@property @property
def latency(self) -> int: def latency(self) -> int:
if None in [inp.latency_offset for inp in self.inputs] or None in [ if None in [inp.latency_offset for inp in self.inputs] or None in [
outp.latency_offset for outp in self.outputs output.latency_offset for output in self.outputs
]: ]:
raise ValueError( raise ValueError(
"All native offsets have to set to a non-negative value to" "All native offsets have to set to a non-negative value to"
...@@ -940,10 +944,10 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -940,10 +944,10 @@ class AbstractOperation(Operation, AbstractGraphComponent):
return max( return max(
( (
( (
cast(int, outp.latency_offset) cast(int, output.latency_offset)
- cast(int, inp.latency_offset) - cast(int, input.latency_offset)
) )
for outp, inp in it.product(self.outputs, self.inputs) for output, input in it.product(self.outputs, self.inputs)
) )
) )
...@@ -951,11 +955,11 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -951,11 +955,11 @@ class AbstractOperation(Operation, AbstractGraphComponent):
def latency_offsets(self) -> Dict[str, Optional[int]]: def latency_offsets(self) -> Dict[str, Optional[int]]:
latency_offsets = {} latency_offsets = {}
for i, inp in enumerate(self.inputs): for i, input in enumerate(self.inputs):
latency_offsets[f"in{i}"] = inp.latency_offset latency_offsets[f"in{i}"] = input.latency_offset
for i, outp in enumerate(self.outputs): for i, output in enumerate(self.outputs):
latency_offsets[f"out{i}"] = outp.latency_offset latency_offsets[f"out{i}"] = output.latency_offset
return latency_offsets return latency_offsets
...@@ -1072,18 +1076,18 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -1072,18 +1076,18 @@ class AbstractOperation(Operation, AbstractGraphComponent):
) -> Tuple[List[List[float]], List[List[float]]]: ) -> Tuple[List[List[float]], List[List[float]]]:
# Doc-string inherited # Doc-string inherited
self._check_all_latencies_set() self._check_all_latencies_set()
input_coords = [ input_coordinates = [
[ [
self.inputs[k].latency_offset, self.inputs[k].latency_offset,
(1 + 2 * k) / (2 * len(self.inputs)), (1 + 2 * k) / (2 * len(self.inputs)),
] ]
for k in range(len(self.inputs)) for k in range(len(self.inputs))
] ]
output_coords = [ output_coordinates = [
[ [
self.outputs[k].latency_offset, self.outputs[k].latency_offset,
(1 + 2 * k) / (2 * len(self.outputs)), (1 + 2 * k) / (2 * len(self.outputs)),
] ]
for k in range(len(self.outputs)) for k in range(len(self.outputs))
] ]
return input_coords, output_coords return input_coordinates, output_coordinates
...@@ -12,6 +12,7 @@ from b_asic import ( ...@@ -12,6 +12,7 @@ from b_asic import (
Max, Max,
Min, Min,
Multiplication, Multiplication,
Reciprocal,
SquareRoot, SquareRoot,
Subtraction, Subtraction,
SymmetricTwoportAdaptor, SymmetricTwoportAdaptor,
...@@ -261,6 +262,22 @@ class TestSymmetricTwoportAdaptor: ...@@ -261,6 +262,22 @@ class TestSymmetricTwoportAdaptor:
) )
class TestReciprocal:
"""Tests for Absolute class."""
def test_reciprocal_positive(self):
test_operation = Reciprocal()
assert test_operation.evaluate_output(0, [2]) == 0.5
def test_reciprocal_negative(self):
test_operation = Reciprocal()
assert test_operation.evaluate_output(0, [-5]) == -0.2
def test_reciprocal_complex(self):
test_operation = Reciprocal()
assert test_operation.evaluate_output(0, [1 + 1j]) == 0.5 - 0.5j
class TestDepends: class TestDepends:
def test_depends_addition(self): def test_depends_addition(self):
add1 = Addition() add1 = Addition()
......
...@@ -10,6 +10,7 @@ from b_asic import ( ...@@ -10,6 +10,7 @@ from b_asic import (
ConstantMultiplication, ConstantMultiplication,
Division, Division,
Multiplication, Multiplication,
Reciprocal,
SquareRoot, SquareRoot,
Subtraction, Subtraction,
) )
...@@ -100,6 +101,10 @@ class TestOperationOverloading: ...@@ -100,6 +101,10 @@ class TestOperationOverloading:
assert div3.input(0).signals[0].source.operation.value == 5 assert div3.input(0).signals[0].source.operation.value == 5
assert div3.input(1).signals == div2.output(0).signals assert div3.input(1).signals == div2.output(0).signals
div4 = 1 / div3
assert isinstance(div4, Reciprocal)
assert div4.input(0).signals == div3.output(0).signals
class TestTraverse: class TestTraverse:
def test_traverse_single_tree(self, operation): def test_traverse_single_tree(self, operation):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment