Skip to content
Snippets Groups Projects
Commit 37d0425a authored by Angus Lothian's avatar Angus Lothian :dark_sunglasses: Committed by Ivar Härnqvist
Browse files

Change test of multiple outputs of evaluate output and Butterfly to not depend...

Change test of multiple outputs of evaluate output and Butterfly to not depend on implementation returing list or tuple
parent 7e2d5182
Branches
No related tags found
2 merge requests!67WIP: B-ASIC version 1.0.0 hotfix,!65B-ASIC version 1.0.0
...@@ -4,10 +4,8 @@ TODO: More info. ...@@ -4,10 +4,8 @@ TODO: More info.
""" """
from numbers import Number from numbers import Number
from typing import Any
from numpy import conjugate, sqrt, abs as np_abs from numpy import conjugate, sqrt, abs as np_abs
from b_asic.port import InputPort, OutputPort from b_asic.port import InputPort, OutputPort
from b_asic.graph_id import GraphIDType
from b_asic.operation import AbstractOperation from b_asic.operation import AbstractOperation
from b_asic.graph_component import Name, TypeName from b_asic.graph_component import Name, TypeName
...@@ -335,3 +333,28 @@ class ConstantDivision(AbstractOperation): ...@@ -335,3 +333,28 @@ class ConstantDivision(AbstractOperation):
@property @property
def type_name(self) -> TypeName: def type_name(self) -> TypeName:
return "cdiv" return "cdiv"
class Butterfly(AbstractOperation):
"""Butterfly operation that returns two outputs.
The first output is a + b and the second output is a - b.
TODO: More info.
"""
def __init__(self, source1: OutputPort = None, source2: OutputPort = None, name: Name = ""):
super().__init__(name)
self._input_ports = [InputPort(0, self), InputPort(1, self)]
self._output_ports = [OutputPort(0, self), OutputPort(1, self)]
if source1 is not None:
self._input_ports[0].connect(source1)
if source2 is not None:
self._input_ports[1].connect(source2)
def evaluate(self, a, b):
return a + b, a - b
@property
def type_name(self) -> TypeName:
return "bfly"
...@@ -5,12 +5,10 @@ TODO: More info. ...@@ -5,12 +5,10 @@ TODO: More info.
from abc import abstractmethod from abc import abstractmethod
from numbers import Number from numbers import Number
from typing import List, Dict, Optional, Any, Set, TYPE_CHECKING from typing import List, Dict, Optional, Any, Set, Sequence, TYPE_CHECKING
from collections import deque from collections import deque
from b_asic.graph_component import GraphComponent, AbstractGraphComponent, Name from b_asic.graph_component import GraphComponent, AbstractGraphComponent, Name
from b_asic.simulation import SimulationState, OperationState
from b_asic.signal import Signal
if TYPE_CHECKING: if TYPE_CHECKING:
from b_asic.port import InputPort, OutputPort from b_asic.port import InputPort, OutputPort
...@@ -51,6 +49,12 @@ class Operation(GraphComponent): ...@@ -51,6 +49,12 @@ class Operation(GraphComponent):
"""Get the output port at index i.""" """Get the output port at index i."""
raise NotImplementedError raise NotImplementedError
@abstractmethod
def evaluate_output(self, i: int, inputs: Sequence[Number]) -> Sequence[Optional[Number]]:
"""Evaluate the output port at the entered index with the entered input values and
returns all output values that are calulated during the evaluation in a list."""
raise NotImplementedError
@abstractmethod @abstractmethod
def params(self) -> Dict[str, Optional[Any]]: def params(self) -> Dict[str, Optional[Any]]:
"""Get a dictionary of all parameter values.""" """Get a dictionary of all parameter values."""
...@@ -70,13 +74,6 @@ class Operation(GraphComponent): ...@@ -70,13 +74,6 @@ class Operation(GraphComponent):
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
def evaluate_outputs(self, state: "SimulationState") -> List[Number]:
"""Simulate the circuit until its iteration count matches that of the simulation state,
then return the resulting output vector.
"""
raise NotImplementedError
@abstractmethod @abstractmethod
def split(self) -> "List[Operation]": def split(self) -> "List[Operation]":
"""Split the operation into multiple operations. """Split the operation into multiple operations.
...@@ -115,6 +112,15 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -115,6 +112,15 @@ class AbstractOperation(Operation, AbstractGraphComponent):
""" """
raise NotImplementedError raise NotImplementedError
def evaluate_output(self, i: int, inputs: Sequence[Number]) -> Sequence[Optional[Number]]:
eval_return = self.evaluate(*inputs)
if isinstance(eval_return, Number):
return [eval_return]
elif isinstance(eval_return, (list, tuple)):
return eval_return
else:
raise TypeError("Incorrect returned type from evaluate function.")
def inputs(self) -> List["InputPort"]: def inputs(self) -> List["InputPort"]:
return self._input_ports.copy() return self._input_ports.copy()
...@@ -143,33 +149,6 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -143,33 +149,6 @@ class AbstractOperation(Operation, AbstractGraphComponent):
assert name in self._parameters # TODO: Error message. assert name in self._parameters # TODO: Error message.
self._parameters[name] = value self._parameters[name] = value
def evaluate_outputs(self, state: SimulationState) -> List[Number]:
# TODO: Check implementation.
input_count: int = self.input_count()
output_count: int = self.output_count()
assert input_count == len(self._input_ports) # TODO: Error message.
assert output_count == len(self._output_ports) # TODO: Error message.
self_state: OperationState = state.operation_states[self]
while self_state.iteration < state.iteration:
input_values: List[Number] = [0] * input_count
for i in range(input_count):
source: Signal = self._input_ports[i].signal
input_values[i] = source.operation.evaluate_outputs(state)[
source.port_index]
self_state.output_values = self.evaluate(input_values)
# TODO: Error message.
assert len(self_state.output_values) == output_count
self_state.iteration += 1
for i in range(output_count):
for signal in self._output_ports[i].signals():
destination: Signal = signal.destination
destination.evaluate_outputs(state)
return self_state.output_values
def split(self) -> List[Operation]: def split(self) -> List[Operation]:
# TODO: Check implementation. # TODO: Check implementation.
results = self.evaluate(self._input_ports) results = self.evaluate(self._input_ports)
...@@ -265,4 +244,3 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -265,4 +244,3 @@ class AbstractOperation(Operation, AbstractGraphComponent):
return ConstantDivision(other, self.output(0)) return ConstantDivision(other, self.output(0))
else: else:
raise TypeError("Other type is not an Operation or a Number.") raise TypeError("Other type is not an Operation or a Number.")
...@@ -11,6 +11,7 @@ from b_asic.signal import Signal ...@@ -11,6 +11,7 @@ from b_asic.signal import Signal
PortIndex = NewType("PortIndex", int) PortIndex = NewType("PortIndex", int)
class Port(ABC): class Port(ABC):
"""Port Interface. """Port Interface.
...@@ -126,6 +127,7 @@ class InputPort(AbstractPort): ...@@ -126,6 +127,7 @@ class InputPort(AbstractPort):
@property @property
def value_length(self) -> Optional[int]: def value_length(self) -> Optional[int]:
"""Return the InputPorts value length."""
return self._value_length return self._value_length
@property @property
...@@ -144,7 +146,8 @@ class InputPort(AbstractPort): ...@@ -144,7 +146,8 @@ class InputPort(AbstractPort):
def connect(self, port: "OutputPort") -> Signal: def connect(self, port: "OutputPort") -> Signal:
assert self._source_signal is None, "Connecting new port to already connected input port." assert self._source_signal is None, "Connecting new port to already connected input port."
return Signal(port, self) # self._source_signal is set by the signal constructor. # self._source_signal is set by the signal constructor.
return Signal(port, self)
def add_signal(self, signal: Signal) -> None: def add_signal(self, signal: Signal) -> None:
assert self._source_signal is None, "Connecting new port to already connected input port." assert self._source_signal is None, "Connecting new port to already connected input port."
...@@ -183,24 +186,21 @@ class OutputPort(AbstractPort): ...@@ -183,24 +186,21 @@ class OutputPort(AbstractPort):
def signals(self) -> List[Signal]: def signals(self) -> List[Signal]:
return self._destination_signals.copy() return self._destination_signals.copy()
def signal(self, i: int = 0) -> Signal:
assert 0 <= i < self.signal_count(), "Signal index out of bounds."
return self._destination_signals[i]
@property @property
def connected_ports(self) -> List[Port]: def connected_ports(self) -> List[Port]:
return [signal.destination for signal in self._destination_signals \ return [signal.destination for signal in self._destination_signals
if signal.destination is not None] if signal.destination is not None]
def signal_count(self) -> int: def signal_count(self) -> int:
return len(self._destination_signals) return len(self._destination_signals)
def connect(self, port: InputPort) -> Signal: def connect(self, port: InputPort) -> Signal:
return Signal(self, port) # Signal is added to self._destination_signals in signal constructor. # Signal is added to self._destination_signals in signal constructor.
return Signal(self, port)
def add_signal(self, signal: Signal) -> None: def add_signal(self, signal: Signal) -> None:
assert signal not in self.signals, \ assert signal not in self.signals, \
"Attempting to connect to Signal already connected." "Attempting to connect to Signal already connected."
self._destination_signals.append(signal) self._destination_signals.append(signal)
if self is not signal.source: if self is not signal.source:
# Connect this outputport to the signal if it isn't already. # Connect this outputport to the signal if it isn't already.
......
...@@ -15,8 +15,8 @@ class Signal(AbstractGraphComponent): ...@@ -15,8 +15,8 @@ class Signal(AbstractGraphComponent):
_source: "OutputPort" _source: "OutputPort"
_destination: "InputPort" _destination: "InputPort"
def __init__(self, source: Optional["OutputPort"] = None, \ def __init__(self, source: Optional["OutputPort"] = None,
destination: Optional["InputPort"] = None, name: Name = ""): destination: Optional["InputPort"] = None, name: Name = ""):
super().__init__(name) super().__init__(name)
......
...@@ -4,7 +4,7 @@ TODO: More info. ...@@ -4,7 +4,7 @@ TODO: More info.
""" """
from numbers import Number from numbers import Number
from typing import List from typing import List, Dict
class OperationState: class OperationState:
...@@ -25,11 +25,19 @@ class SimulationState: ...@@ -25,11 +25,19 @@ class SimulationState:
TODO: More info. TODO: More info.
""" """
# operation_states: Dict[OperationId, OperationState] operation_states: Dict[int, OperationState]
iteration: int iteration: int
def __init__(self): def __init__(self):
self.operation_states = {} op_state = OperationState()
self.operation_states = {1: op_state}
self.iteration = 0 self.iteration = 0
# TODO: More stuff. # @property
# #def iteration(self):
# return self.iteration
# @iteration.setter
# def iteration(self, new_iteration: int):
# self.iteration = new_iteration
#
# TODO: More stuff
...@@ -2,226 +2,313 @@ ...@@ -2,226 +2,313 @@
B-ASIC test suite for the core operations. B-ASIC test suite for the core operations.
""" """
from b_asic.core_operations import Constant, Addition, Subtraction, Multiplication, Division, SquareRoot, ComplexConjugate, Max, Min, Absolute, ConstantMultiplication, ConstantAddition, ConstantSubtraction, ConstantDivision from b_asic.core_operations import Constant, Addition, Subtraction, \
Multiplication, Division, SquareRoot, ComplexConjugate, Max, Min, \
Absolute, ConstantMultiplication, ConstantAddition, ConstantSubtraction, \
ConstantDivision, Butterfly
# Constant tests. # Constant tests.
def test_constant(): def test_constant():
constant_operation = Constant(3) constant_operation = Constant(3)
assert constant_operation.evaluate() == 3 assert constant_operation.evaluate() == 3
def test_constant_negative(): def test_constant_negative():
constant_operation = Constant(-3) constant_operation = Constant(-3)
assert constant_operation.evaluate() == -3 assert constant_operation.evaluate() == -3
def test_constant_complex(): def test_constant_complex():
constant_operation = Constant(3+4j) constant_operation = Constant(3+4j)
assert constant_operation.evaluate() == 3+4j assert constant_operation.evaluate() == 3+4j
# Addition tests. # Addition tests.
def test_addition(): def test_addition():
test_operation = Addition() test_operation = Addition()
constant_operation = Constant(3) constant_operation = Constant(3)
constant_operation_2 = Constant(5) constant_operation_2 = Constant(5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 8 assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 8
def test_addition_negative(): def test_addition_negative():
test_operation = Addition() test_operation = Addition()
constant_operation = Constant(-3) constant_operation = Constant(-3)
constant_operation_2 = Constant(-5) constant_operation_2 = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -8 assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == -8
def test_addition_complex(): def test_addition_complex():
test_operation = Addition() test_operation = Addition()
constant_operation = Constant((3+5j)) constant_operation = Constant((3+5j))
constant_operation_2 = Constant((4+6j)) constant_operation_2 = Constant((4+6j))
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (7+11j) assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == (7+11j)
# Subtraction tests. # Subtraction tests.
def test_subtraction(): def test_subtraction():
test_operation = Subtraction() test_operation = Subtraction()
constant_operation = Constant(5) constant_operation = Constant(5)
constant_operation_2 = Constant(3) constant_operation_2 = Constant(3)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 2 assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 2
def test_subtraction_negative(): def test_subtraction_negative():
test_operation = Subtraction() test_operation = Subtraction()
constant_operation = Constant(-5) constant_operation = Constant(-5)
constant_operation_2 = Constant(-3) constant_operation_2 = Constant(-3)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -2 assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == -2
def test_subtraction_complex(): def test_subtraction_complex():
test_operation = Subtraction() test_operation = Subtraction()
constant_operation = Constant((3+5j)) constant_operation = Constant((3+5j))
constant_operation_2 = Constant((4+6j)) constant_operation_2 = Constant((4+6j))
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (-1-1j) assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == (-1-1j)
# Multiplication tests. # Multiplication tests.
def test_multiplication(): def test_multiplication():
test_operation = Multiplication() test_operation = Multiplication()
constant_operation = Constant(5) constant_operation = Constant(5)
constant_operation_2 = Constant(3) constant_operation_2 = Constant(3)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 15 assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 15
def test_multiplication_negative(): def test_multiplication_negative():
test_operation = Multiplication() test_operation = Multiplication()
constant_operation = Constant(-5) constant_operation = Constant(-5)
constant_operation_2 = Constant(-3) constant_operation_2 = Constant(-3)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 15 assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 15
def test_multiplication_complex(): def test_multiplication_complex():
test_operation = Multiplication() test_operation = Multiplication()
constant_operation = Constant((3+5j)) constant_operation = Constant((3+5j))
constant_operation_2 = Constant((4+6j)) constant_operation_2 = Constant((4+6j))
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (-18+38j) assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == (-18+38j)
# Division tests. # Division tests.
def test_division(): def test_division():
test_operation = Division() test_operation = Division()
constant_operation = Constant(30) constant_operation = Constant(30)
constant_operation_2 = Constant(5) constant_operation_2 = Constant(5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 6 assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 6
def test_division_negative(): def test_division_negative():
test_operation = Division() test_operation = Division()
constant_operation = Constant(-30) constant_operation = Constant(-30)
constant_operation_2 = Constant(-5) constant_operation_2 = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 6 assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 6
def test_division_complex(): def test_division_complex():
test_operation = Division() test_operation = Division()
constant_operation = Constant((60+40j)) constant_operation = Constant((60+40j))
constant_operation_2 = Constant((10+20j)) constant_operation_2 = Constant((10+20j))
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (2.8-1.6j) assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == (2.8-1.6j)
# SquareRoot tests. # SquareRoot tests.
def test_squareroot(): def test_squareroot():
test_operation = SquareRoot() test_operation = SquareRoot()
constant_operation = Constant(36) constant_operation = Constant(36)
assert test_operation.evaluate(constant_operation.evaluate()) == 6 assert test_operation.evaluate(constant_operation.evaluate()) == 6
def test_squareroot_negative(): def test_squareroot_negative():
test_operation = SquareRoot() test_operation = SquareRoot()
constant_operation = Constant(-36) constant_operation = Constant(-36)
assert test_operation.evaluate(constant_operation.evaluate()) == 6j assert test_operation.evaluate(constant_operation.evaluate()) == 6j
def test_squareroot_complex(): def test_squareroot_complex():
test_operation = SquareRoot() test_operation = SquareRoot()
constant_operation = Constant((48+64j)) constant_operation = Constant((48+64j))
assert test_operation.evaluate(constant_operation.evaluate()) == (8+4j) assert test_operation.evaluate(constant_operation.evaluate()) == (8+4j)
# ComplexConjugate tests. # ComplexConjugate tests.
def test_complexconjugate(): def test_complexconjugate():
test_operation = ComplexConjugate() test_operation = ComplexConjugate()
constant_operation = Constant(3+4j) constant_operation = Constant(3+4j)
assert test_operation.evaluate(constant_operation.evaluate()) == (3-4j) assert test_operation.evaluate(constant_operation.evaluate()) == (3-4j)
def test_test_complexconjugate_negative(): def test_test_complexconjugate_negative():
test_operation = ComplexConjugate() test_operation = ComplexConjugate()
constant_operation = Constant(-3-4j) constant_operation = Constant(-3-4j)
assert test_operation.evaluate(constant_operation.evaluate()) == (-3+4j) assert test_operation.evaluate(constant_operation.evaluate()) == (-3+4j)
# Max tests. # Max tests.
def test_max(): def test_max():
test_operation = Max() test_operation = Max()
constant_operation = Constant(30) constant_operation = Constant(30)
constant_operation_2 = Constant(5) constant_operation_2 = Constant(5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 30 assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 30
def test_max_negative(): def test_max_negative():
test_operation = Max() test_operation = Max()
constant_operation = Constant(-30) constant_operation = Constant(-30)
constant_operation_2 = Constant(-5) constant_operation_2 = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -5 assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == -5
# Min tests. # Min tests.
def test_min(): def test_min():
test_operation = Min() test_operation = Min()
constant_operation = Constant(30) constant_operation = Constant(30)
constant_operation_2 = Constant(5) constant_operation_2 = Constant(5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 5 assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == 5
def test_min_negative(): def test_min_negative():
test_operation = Min() test_operation = Min()
constant_operation = Constant(-30) constant_operation = Constant(-30)
constant_operation_2 = Constant(-5) constant_operation_2 = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -30 assert test_operation.evaluate(
constant_operation.evaluate(), constant_operation_2.evaluate()) == -30
# Absolute tests. # Absolute tests.
def test_absolute(): def test_absolute():
test_operation = Absolute() test_operation = Absolute()
constant_operation = Constant(30) constant_operation = Constant(30)
assert test_operation.evaluate(constant_operation.evaluate()) == 30 assert test_operation.evaluate(constant_operation.evaluate()) == 30
def test_absolute_negative(): def test_absolute_negative():
test_operation = Absolute() test_operation = Absolute()
constant_operation = Constant(-5) constant_operation = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate()) == 5 assert test_operation.evaluate(constant_operation.evaluate()) == 5
def test_absolute_complex(): def test_absolute_complex():
test_operation = Absolute() test_operation = Absolute()
constant_operation = Constant((3+4j)) constant_operation = Constant((3+4j))
assert test_operation.evaluate(constant_operation.evaluate()) == 5.0 assert test_operation.evaluate(constant_operation.evaluate()) == 5.0
# ConstantMultiplication tests. # ConstantMultiplication tests.
def test_constantmultiplication(): def test_constantmultiplication():
test_operation = ConstantMultiplication(5) test_operation = ConstantMultiplication(5)
constant_operation = Constant(20) constant_operation = Constant(20)
assert test_operation.evaluate(constant_operation.evaluate()) == 100 assert test_operation.evaluate(constant_operation.evaluate()) == 100
def test_constantmultiplication_negative(): def test_constantmultiplication_negative():
test_operation = ConstantMultiplication(5) test_operation = ConstantMultiplication(5)
constant_operation = Constant(-5) constant_operation = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate()) == -25 assert test_operation.evaluate(constant_operation.evaluate()) == -25
def test_constantmultiplication_complex(): def test_constantmultiplication_complex():
test_operation = ConstantMultiplication(3+2j) test_operation = ConstantMultiplication(3+2j)
constant_operation = Constant((3+4j)) constant_operation = Constant((3+4j))
assert test_operation.evaluate(constant_operation.evaluate()) == (1+18j) assert test_operation.evaluate(constant_operation.evaluate()) == (1+18j)
# ConstantAddition tests. # ConstantAddition tests.
def test_constantaddition(): def test_constantaddition():
test_operation = ConstantAddition(5) test_operation = ConstantAddition(5)
constant_operation = Constant(20) constant_operation = Constant(20)
assert test_operation.evaluate(constant_operation.evaluate()) == 25 assert test_operation.evaluate(constant_operation.evaluate()) == 25
def test_constantaddition_negative(): def test_constantaddition_negative():
test_operation = ConstantAddition(4) test_operation = ConstantAddition(4)
constant_operation = Constant(-5) constant_operation = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate()) == -1 assert test_operation.evaluate(constant_operation.evaluate()) == -1
def test_constantaddition_complex(): def test_constantaddition_complex():
test_operation = ConstantAddition(3+2j) test_operation = ConstantAddition(3+2j)
constant_operation = Constant((3+4j)) constant_operation = Constant((3+4j))
assert test_operation.evaluate(constant_operation.evaluate()) == (6+6j) assert test_operation.evaluate(constant_operation.evaluate()) == (6+6j)
# ConstantSubtraction tests. # ConstantSubtraction tests.
def test_constantsubtraction(): def test_constantsubtraction():
test_operation = ConstantSubtraction(5) test_operation = ConstantSubtraction(5)
constant_operation = Constant(20) constant_operation = Constant(20)
assert test_operation.evaluate(constant_operation.evaluate()) == 15 assert test_operation.evaluate(constant_operation.evaluate()) == 15
def test_constantsubtraction_negative(): def test_constantsubtraction_negative():
test_operation = ConstantSubtraction(4) test_operation = ConstantSubtraction(4)
constant_operation = Constant(-5) constant_operation = Constant(-5)
assert test_operation.evaluate(constant_operation.evaluate()) == -9 assert test_operation.evaluate(constant_operation.evaluate()) == -9
def test_constantsubtraction_complex(): def test_constantsubtraction_complex():
test_operation = ConstantSubtraction(4+6j) test_operation = ConstantSubtraction(4+6j)
constant_operation = Constant((3+4j)) constant_operation = Constant((3+4j))
assert test_operation.evaluate(constant_operation.evaluate()) == (-1-2j) assert test_operation.evaluate(constant_operation.evaluate()) == (-1-2j)
# ConstantDivision tests. # ConstantDivision tests.
def test_constantdivision(): def test_constantdivision():
test_operation = ConstantDivision(5) test_operation = ConstantDivision(5)
constant_operation = Constant(20) constant_operation = Constant(20)
assert test_operation.evaluate(constant_operation.evaluate()) == 4 assert test_operation.evaluate(constant_operation.evaluate()) == 4
def test_constantdivision_negative(): def test_constantdivision_negative():
test_operation = ConstantDivision(4) test_operation = ConstantDivision(4)
constant_operation = Constant(-20) constant_operation = Constant(-20)
assert test_operation.evaluate(constant_operation.evaluate()) == -5 assert test_operation.evaluate(constant_operation.evaluate()) == -5
def test_constantdivision_complex(): def test_constantdivision_complex():
test_operation = ConstantDivision(2+2j) test_operation = ConstantDivision(2+2j)
constant_operation = Constant((10+10j)) constant_operation = Constant((10+10j))
assert test_operation.evaluate(constant_operation.evaluate()) == (5+0j) assert test_operation.evaluate(constant_operation.evaluate()) == (5+0j)
def test_butterfly():
test_operation = Butterfly()
assert list(test_operation.evaluate(2, 3)) == [5, -1]
def test_butterfly_negative():
test_operation = Butterfly()
assert list(test_operation.evaluate(-2, -3)) == [-5, 1]
def test_buttefly_complex():
test_operation = Butterfly()
assert list(test_operation.evaluate(2+1j, 3-2j)) == [5-1j, -1+3j]
from b_asic.core_operations import Constant, Addition from b_asic.core_operations import Constant, Addition, ConstantAddition, Butterfly
from b_asic.signal import Signal from b_asic.signal import Signal
from b_asic.port import InputPort, OutputPort from b_asic.port import InputPort, OutputPort
import pytest import pytest
class TestTraverse: class TestTraverse:
def test_traverse_single_tree(self, operation): def test_traverse_single_tree(self, operation):
"""Traverse a tree consisting of one operation.""" """Traverse a tree consisting of one operation."""
...@@ -20,8 +21,10 @@ class TestTraverse: ...@@ -20,8 +21,10 @@ class TestTraverse:
def test_traverse_type(self, large_operation_tree): def test_traverse_type(self, large_operation_tree):
traverse = list(large_operation_tree.traverse()) traverse = list(large_operation_tree.traverse())
assert len(list(filter(lambda type_: isinstance(type_, Addition), traverse))) == 3 assert len(
assert len(list(filter(lambda type_: isinstance(type_, Constant), traverse))) == 4 list(filter(lambda type_: isinstance(type_, Addition), traverse))) == 3
assert len(
list(filter(lambda type_: isinstance(type_, Constant), traverse))) == 4
def test_traverse_loop(self, operation_tree): def test_traverse_loop(self, operation_tree):
add_oper_signal = Signal() add_oper_signal = Signal()
...@@ -29,3 +32,43 @@ class TestTraverse: ...@@ -29,3 +32,43 @@ class TestTraverse:
operation_tree._input_ports[0].remove_signal(add_oper_signal) operation_tree._input_ports[0].remove_signal(add_oper_signal)
operation_tree._input_ports[0].add_signal(add_oper_signal) operation_tree._input_ports[0].add_signal(add_oper_signal)
assert len(list(operation_tree.traverse())) == 2 assert len(list(operation_tree.traverse())) == 2
class TestEvaluateOutput:
def test_evaluate_output_two_real_inputs(self):
"""Test evaluate_output for two real numbered inputs."""
add1 = Addition()
assert list(add1.evaluate_output(0, [1, 2])) == [3]
def test_evaluate_output_addition_two_complex_inputs(self):
"""Test evaluate_output for two complex numbered inputs."""
add1 = Addition()
assert list(add1.evaluate_output(0, [1+1j, 2])) == [3+1j]
def test_evaluate_output_one_real_input(self):
"""Test evaluate_output for one real numbered inputs."""
c_add1 = ConstantAddition(5)
assert list(c_add1.evaluate_output(0, [1])) == [6]
def test_evaluate_output_one_complex_input(self):
"""Test evaluate_output for one complex numbered inputs."""
c_add1 = ConstantAddition(5)
assert list(c_add1.evaluate_output(0, [1+1j])) == [6+1j]
def test_evaluate_output_two_real_inputs_two_outputs(self):
"""Test evaluate_output for two real inputs and two outputs."""
bfly1 = Butterfly()
assert list(bfly1.evaluate_output(0, [6, 9])) == [15, -3]
assert list(bfly1.evaluate_output(1, [6, 9])) == [15, -3]
def test_evaluate_output_two_complex_inputs_two_outputs(self):
"""Test evaluate_output for two complex inputs and two outputs."""
bfly1 = Butterfly()
assert list(bfly1.evaluate_output(0, [3+2j, 4+2j])) == [7+4j, -1]
assert list(bfly1.evaluate_output(1, [3+2j, 4+2j])) == [7+4j, -1]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment