diff --git a/.gitignore b/.gitignore index 0bdf5476787722ba2007a85c8c73255d56178262..3395e18ae0404b87b2abc5759426b6ef38a49d09 100644 --- a/.gitignore +++ b/.gitignore @@ -1,33 +1,33 @@ -.vs/ -.vscode/ -build*/ -bin*/ -logs/ -dist/ -CMakeLists.txt.user* -*.autosave -*.creator -*.creator.user* -\#*\# -/.emacs.desktop -/.emacs.desktop.lock -*.elc -auto-save-list -tramp -.\#* -*~ -.fuse_hudden* -.directory -.Trash-* -.nfs* -Thumbs.db -Thumbs.db:encryptable -ehthumbs.db -ehthumbs_vista.db -$RECYCLE.BIN/ -*.stackdump -[Dd]esktop.ini -*.egg-info -__pycache__/ -env/ +.vs/ +.vscode/ +build*/ +bin*/ +logs/ +dist/ +CMakeLists.txt.user* +*.autosave +*.creator +*.creator.user* +\#*\# +/.emacs.desktop +/.emacs.desktop.lock +*.elc +auto-save-list +tramp +.\#* +*~ +.fuse_hudden* +.directory +.Trash-* +.nfs* +Thumbs.db +Thumbs.db:encryptable +ehthumbs.db +ehthumbs_vista.db +$RECYCLE.BIN/ +*.stackdump +[Dd]esktop.ini +*.egg-info +__pycache__/ +env/ venv/ \ No newline at end of file diff --git a/b_asic/__init__.py b/b_asic/__init__.py index 752bac07b25b905a175c71b62ae19aaa38dd899a..bd3574ba07b2556fae03ff8fa22deecfd2656705 100644 --- a/b_asic/__init__.py +++ b/b_asic/__init__.py @@ -2,9 +2,8 @@ Better ASIC Toolbox. TODO: More info. """ -from _b_asic import * -from b_asic.basic_operation import * from b_asic.core_operations import * +from b_asic.graph_component import * from b_asic.operation import * from b_asic.precedence_chart import * from b_asic.port import * @@ -12,3 +11,4 @@ from b_asic.schema import * from b_asic.signal_flow_graph import * from b_asic.signal import * from b_asic.simulation import * +from b_asic.special_operations import * diff --git a/b_asic/basic_operation.py b/b_asic/basic_operation.py deleted file mode 100644 index 93a272223cd24500a49e976ebf34db16b58104bc..0000000000000000000000000000000000000000 --- a/b_asic/basic_operation.py +++ /dev/null @@ -1,109 +0,0 @@ -"""@package docstring -B-ASIC Basic Operation Module. -TODO: More info. -""" - -from abc import abstractmethod -from typing import List, Dict, Optional, Any -from numbers import Number - -from b_asic.port import InputPort, OutputPort -from b_asic.signal import Signal -from b_asic.operation import Operation -from b_asic.simulation import SimulationState, OperationState - - -class BasicOperation(Operation): - """Generic abstract operation class which most implementations will derive from. - TODO: More info. - """ - - _input_ports: List[InputPort] - _output_ports: List[OutputPort] - _parameters: Dict[str, Optional[Any]] - - def __init__(self): - """Construct a BasicOperation.""" - self._input_ports = [] - self._output_ports = [] - self._parameters = {} - - @abstractmethod - def evaluate(self, inputs: list) -> list: - """Evaluate the operation and generate a list of output values given a list of input values.""" - pass - - def inputs(self) -> List[InputPort]: - return self._input_ports.copy() - - def outputs(self) -> List[OutputPort]: - return self._output_ports.copy() - - def input_count(self) -> int: - return len(self._input_ports) - - def output_count(self) -> int: - return len(self._output_ports) - - def input(self, i: int) -> InputPort: - return self._input_ports[i] - - def output(self, i: int) -> OutputPort: - return self._output_ports[i] - - def params(self) -> Dict[str, Optional[Any]]: - return self._parameters.copy() - - def param(self, name: str) -> Optional[Any]: - return self._parameters.get(name) - - def set_param(self, name: str, value: Any) -> None: - assert name in self._parameters # TODO: Error message. - self._parameters[name] = value - - def evaluate_outputs(self, state: SimulationState) -> List[Number]: - # TODO: Check implementation. - input_count: int = self.input_count() - output_count: int = self.output_count() - assert input_count == len(self._input_ports) # TODO: Error message. - assert output_count == len(self._output_ports) # TODO: Error message. - - self_state: OperationState = state.operation_states[self.identifier()] - - while self_state.iteration < state.iteration: - input_values: List[Number] = [0] * input_count - for i in range(input_count): - source: Signal = self._input_ports[i].signal - input_values[i] = source.operation.evaluate_outputs(state)[source.port_index] - - self_state.output_values = self.evaluate(input_values) - assert len(self_state.output_values) == output_count # TODO: Error message. - self_state.iteration += 1 - for i in range(output_count): - for signal in self._output_ports[i].signals(): - destination: Signal = signal.destination - destination.evaluate_outputs(state) - - return self_state.output_values - - def split(self) -> List[Operation]: - # TODO: Check implementation. - results = self.evaluate(self._input_ports) - if all(isinstance(e, Operation) for e in results): - return results - return [self] - - @property - def neighbours(self) -> List[Operation]: - neighbours: List[Operation] = [] - for port in self._input_ports: - for signal in port.signals: - neighbours.append(signal.source.operation) - - for port in self._output_ports: - for signal in port.signals: - neighbours.append(signal.destination.operation) - - return neighbours - - # TODO: More stuff. diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index 45919b8bdb4f892cf9ff30a8305d585678b22220..a1a149d787f831405558b774993b1b0ef86fe0be 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -4,73 +4,268 @@ TODO: More info. """ from numbers import Number +from typing import Optional +from numpy import conjugate, sqrt, abs as np_abs -from b_asic.port import InputPort, OutputPort -from b_asic.operation import Operation -from b_asic.basic_operation import BasicOperation -from b_asic.graph_id import GraphIDType +from b_asic.port import SignalSourceProvider, InputPort, OutputPort +from b_asic.operation import AbstractOperation +from b_asic.graph_component import Name, TypeName -class Input(Operation): - """Input operation. - TODO: More info. - """ +class Constant(AbstractOperation): + """Constant value operation. + TODO: More info. + """ - # TODO: Implement all functions. - pass + def __init__(self, value: Number = 0, name: Name = ""): + super().__init__(input_count = 0, output_count = 1, name = name) + self.set_param("value", value) + @property + def type_name(self) -> TypeName: + return "c" -class Constant(BasicOperation): - """Constant value operation. - TODO: More info. - """ + def evaluate(self): + return self.param("value") + + @property + def value(self) -> Number: + """TODO: docstring""" + return self.param("value") - def __init__(self, value: Number): - """Construct a Constant.""" - super().__init__() - self._output_ports = [OutputPort(1, self)] # TODO: Generate appropriate ID for ports. - self._parameters["value"] = value + @value.setter + def value(self, value: Number): + """TODO: docstring""" + return self.set_param("value", value) - def evaluate(self, inputs: list) -> list: - return [self.param("value")] - def type_name(self) -> GraphIDType: - return "const" +class Addition(AbstractOperation): + """Binary addition operation. + TODO: More info. + """ -class Addition(BasicOperation): - """Binary addition operation. - TODO: More info. - """ + def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""): + super().__init__(input_count = 2, output_count = 1, name = name, input_sources = [src0, src1]) - def __init__(self): - """Construct an Addition.""" - super().__init__() - self._input_ports = [InputPort(1, self), InputPort(1, self)] # TODO: Generate appropriate ID for ports. - self._output_ports = [OutputPort(1, self)] # TODO: Generate appropriate ID for ports. + @property + def type_name(self) -> TypeName: + return "add" - def evaluate(self, inputs: list) -> list: - return [inputs[0] + inputs[1]] + def evaluate(self, a, b): + return a + b - def type_name(self) -> GraphIDType: - return "add" +class Subtraction(AbstractOperation): + """Binary subtraction operation. + TODO: More info. + """ -class ConstantMultiplication(BasicOperation): - """Unary constant multiplication operation. - TODO: More info. - """ + def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""): + super().__init__(input_count = 2, output_count = 1, name = name, input_sources = [src0, src1]) - def __init__(self, coefficient: Number): - """Construct a ConstantMultiplication.""" - super().__init__() - self._input_ports = [InputPort(1), self] # TODO: Generate appropriate ID for ports. - self._output_ports = [OutputPort(1, self)] # TODO: Generate appropriate ID for ports. - self._parameters["coefficient"] = coefficient + @property + def type_name(self) -> TypeName: + return "sub" - def evaluate(self, inputs: list) -> list: - return [inputs[0] * self.param("coefficient")] + def evaluate(self, a, b): + return a - b - def type_name(self) -> GraphIDType: - return "const_mul" -# TODO: More operations. +class Multiplication(AbstractOperation): + """Binary multiplication operation. + TODO: More info. + """ + + def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""): + super().__init__(input_count = 2, output_count = 1, name = name, input_sources = [src0, src1]) + + @property + def type_name(self) -> TypeName: + return "mul" + + def evaluate(self, a, b): + return a * b + + +class Division(AbstractOperation): + """Binary division operation. + TODO: More info. + """ + + def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""): + super().__init__(input_count = 2, output_count = 1, name = name, input_sources = [src0, src1]) + + @property + def type_name(self) -> TypeName: + return "div" + + def evaluate(self, a, b): + return a / b + + +class SquareRoot(AbstractOperation): + """Unary square root operation. + TODO: More info. + """ + + def __init__(self, src0: Optional[SignalSourceProvider] = None, name: Name = ""): + super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) + + @property + def type_name(self) -> TypeName: + return "sqrt" + + def evaluate(self, a): + return sqrt(complex(a)) + + +class ComplexConjugate(AbstractOperation): + """Unary complex conjugate operation. + TODO: More info. + """ + + def __init__(self, src0: Optional[SignalSourceProvider] = None, name: Name = ""): + super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) + + @property + def type_name(self) -> TypeName: + return "conj" + + def evaluate(self, a): + return conjugate(a) + + +class Max(AbstractOperation): + """Binary max operation. + TODO: More info. + """ + + def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""): + super().__init__(input_count = 2, output_count = 1, name = name, input_sources = [src0, src1]) + + @property + def type_name(self) -> TypeName: + return "max" + + def evaluate(self, a, b): + assert not isinstance(a, complex) and not isinstance(b, complex), \ + ("core_operations.Max does not support complex numbers.") + return a if a > b else b + + +class Min(AbstractOperation): + """Binary min operation. + TODO: More info. + """ + + def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""): + super().__init__(input_count = 2, output_count = 1, name = name, input_sources = [src0, src1]) + + @property + def type_name(self) -> TypeName: + return "min" + + def evaluate(self, a, b): + assert not isinstance(a, complex) and not isinstance(b, complex), \ + ("core_operations.Min does not support complex numbers.") + return a if a < b else b + + +class Absolute(AbstractOperation): + """Unary absolute value operation. + TODO: More info. + """ + + def __init__(self, src0: Optional[SignalSourceProvider] = None, name: Name = ""): + super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) + + @property + def type_name(self) -> TypeName: + return "abs" + + def evaluate(self, a): + return np_abs(a) + + +class ConstantMultiplication(AbstractOperation): + """Unary constant multiplication operation. + TODO: More info. + """ + + def __init__(self, value: Number, src0: Optional[SignalSourceProvider] = None, name: Name = ""): + super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) + self.set_param("value", value) + + @property + def type_name(self) -> TypeName: + return "cmul" + + def evaluate(self, a): + return a * self.param("value") + + +class ConstantAddition(AbstractOperation): + """Unary constant addition operation. + TODO: More info. + """ + + def __init__(self, value: Number, src0: Optional[SignalSourceProvider] = None, name: Name = ""): + super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) + self.set_param("value", value) + + @property + def type_name(self) -> TypeName: + return "cadd" + + def evaluate(self, a): + return a + self.param("value") + + +class ConstantSubtraction(AbstractOperation): + """Unary constant subtraction operation. + TODO: More info. + """ + + def __init__(self, value: Number, src0: Optional[SignalSourceProvider] = None, name: Name = ""): + super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) + self.set_param("value", value) + + @property + def type_name(self) -> TypeName: + return "csub" + + def evaluate(self, a): + return a - self.param("value") + + +class ConstantDivision(AbstractOperation): + """Unary constant division operation. + TODO: More info. + """ + + def __init__(self, value: Number, src0: Optional[SignalSourceProvider] = None, name: Name = ""): + super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) + self.set_param("value", value) + + @property + def type_name(self) -> TypeName: + return "cdiv" + + def evaluate(self, a): + return a / self.param("value") + +class Butterfly(AbstractOperation): + """Butterfly operation that returns two outputs. + The first output is a + b and the second output is a - b. + TODO: More info. + """ + + def __init__(self, src0: Optional[SignalSourceProvider] = None, src1: Optional[SignalSourceProvider] = None, name: Name = ""): + super().__init__(input_count = 2, output_count = 2, name = name, input_sources = [src0, src1]) + + def evaluate(self, a, b): + return a + b, a - b + + @property + def type_name(self) -> TypeName: + return "bfly" diff --git a/b_asic/graph_component.py b/b_asic/graph_component.py new file mode 100644 index 0000000000000000000000000000000000000000..52eba17c7e343842f636870d5d9a8fa694b713da --- /dev/null +++ b/b_asic/graph_component.py @@ -0,0 +1,65 @@ +"""@package docstring +B-ASIC Operation Module. +TODO: More info. +""" + +from abc import ABC, abstractmethod +from copy import copy +from typing import NewType + +Name = NewType("Name", str) +TypeName = NewType("TypeName", str) + + +class GraphComponent(ABC): + """Graph component interface. + TODO: More info. + """ + + @property + @abstractmethod + def type_name(self) -> TypeName: + """Return the type name of the graph component""" + raise NotImplementedError + + @property + @abstractmethod + def name(self) -> Name: + """Return the name of the graph component.""" + raise NotImplementedError + + @name.setter + @abstractmethod + def name(self, name: Name) -> None: + """Set the name of the graph component to the entered name.""" + raise NotImplementedError + + @abstractmethod + def copy_unconnected(self) -> "GraphComponent": + """Get a copy of this graph component, except without any connected components.""" + raise NotImplementedError + + +class AbstractGraphComponent(GraphComponent): + """Abstract Graph Component class which is a component of a signal flow graph. + + TODO: More info. + """ + + _name: Name + + def __init__(self, name: Name = ""): + self._name = name + + @property + def name(self) -> Name: + return self._name + + @name.setter + def name(self, name: Name) -> None: + self._name = name + + def copy_unconnected(self) -> GraphComponent: + new_comp = self.__class__() + new_comp.name = copy(self.name) + return new_comp \ No newline at end of file diff --git a/b_asic/graph_id.py b/b_asic/graph_id.py deleted file mode 100644 index 0fd1855b6b353105fd4ef7cf55084906263e47c7..0000000000000000000000000000000000000000 --- a/b_asic/graph_id.py +++ /dev/null @@ -1,26 +0,0 @@ -"""@package docstring -B-ASIC Graph ID module for handling IDs of different objects in a graph. -TODO: More info -""" - -from collections import defaultdict -from typing import NewType, DefaultDict - -GraphID = NewType("GraphID", str) -GraphIDType = NewType("GraphIDType", str) -GraphIDNumber = NewType("GraphIDNumber", int) - - -class GraphIDGenerator: - """A class that generates Graph IDs for objects.""" - - _next_id_number: DefaultDict[GraphIDType, GraphIDNumber] - - def __init__(self): - self._next_id_number = defaultdict(lambda: 1) # Initalises every key element to 1 - - def get_next_id(self, graph_id_type: GraphIDType) -> GraphID: - """Returns the next graph id for a certain graph id type.""" - graph_id = graph_id_type + str(self._next_id_number[graph_id_type]) - self._next_id_number[graph_id_type] += 1 # Increase the current id number - return graph_id diff --git a/b_asic/operation.py b/b_asic/operation.py index 923690aa0916306a6d719a567c21bc746a2d4b4e..d644dbd3c566406be1723b18493281aa98cf6ff6 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -3,93 +3,304 @@ B-ASIC Operation Module. TODO: More info. """ -from abc import ABC, abstractmethod +import collections + +from abc import abstractmethod +from copy import deepcopy from numbers import Number -from typing import List, Dict, Optional, Any, TYPE_CHECKING - -if TYPE_CHECKING: - from b_asic.port import InputPort, OutputPort - from b_asic.simulation import SimulationState - from b_asic.graph_id import GraphIDType - - -class Operation(ABC): - """Operation interface. - TODO: More info. - """ - - @abstractmethod - def inputs(self) -> "List[InputPort]": - """Get a list of all input ports.""" - pass - - @abstractmethod - def outputs(self) -> "List[OutputPort]": - """Get a list of all output ports.""" - pass - - @abstractmethod - def input_count(self) -> int: - """Get the number of input ports.""" - pass - - @abstractmethod - def output_count(self) -> int: - """Get the number of output ports.""" - pass - - @abstractmethod - def input(self, i: int) -> "InputPort": - """Get the input port at index i.""" - pass - - @abstractmethod - def output(self, i: int) -> "OutputPort": - """Get the output port at index i.""" - pass - - @abstractmethod - def params(self) -> Dict[str, Optional[Any]]: - """Get a dictionary of all parameter values.""" - pass - - @abstractmethod - def param(self, name: str) -> Optional[Any]: - """Get the value of a parameter. - Returns None if the parameter is not defined. - """ - pass - - @abstractmethod - def set_param(self, name: str, value: Any) -> None: - """Set the value of a parameter. - The parameter must be defined. - """ - pass - - @abstractmethod - def evaluate_outputs(self, state: "SimulationState") -> List[Number]: - """Simulate the circuit until its iteration count matches that of the simulation state, - then return the resulting output vector. - """ - pass - - @abstractmethod - def split(self) -> "List[Operation]": - """Split the operation into multiple operations. - If splitting is not possible, this may return a list containing only the operation itself. - """ - pass - - @abstractmethod - def type_name(self) -> "GraphIDType": - """Returns a string representing the operation name of the operation.""" - pass - - @abstractmethod - def neighbours(self) -> "List[Operation]": - """Return all operations that are connected by signals to this operation. - If no neighbours are found this returns an empty list - """ - - # TODO: More stuff. +from typing import List, Sequence, Iterable, Dict, Optional, Any, Set, Generator, Union +from collections import deque + +from b_asic.graph_component import GraphComponent, AbstractGraphComponent, Name +from b_asic.port import SignalSourceProvider, InputPort, OutputPort + + +class Operation(GraphComponent, SignalSourceProvider): + """Operation interface. + TODO: More info. + """ + + @abstractmethod + def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Addition, ConstantAddition]": + """Overloads the addition operator to make it return a new Addition operation + object that is connected to the self and other objects. If other is a number then + returns a ConstantAddition operation object instead. + """ + raise NotImplementedError + + @abstractmethod + def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Subtraction, ConstantSubtraction]": + """Overloads the subtraction operator to make it return a new Subtraction operation + object that is connected to the self and other objects. If other is a number then + returns a ConstantSubtraction operation object instead. + """ + raise NotImplementedError + + @abstractmethod + def __mul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]": + """Overloads the multiplication operator to make it return a new Multiplication operation + object that is connected to the self and other objects. If other is a number then + returns a ConstantMultiplication operation object instead. + """ + raise NotImplementedError + + @abstractmethod + def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Division, ConstantDivision]": + """Overloads the division operator to make it return a new Division operation + object that is connected to the self and other objects. If other is a number then + returns a ConstantDivision operation object instead. + """ + raise NotImplementedError + + @property + @abstractmethod + def inputs(self) -> List[InputPort]: + """Get a list of all input ports.""" + raise NotImplementedError + + @property + @abstractmethod + def outputs(self) -> List[OutputPort]: + """Get a list of all output ports.""" + raise NotImplementedError + + @property + @abstractmethod + def input_count(self) -> int: + """Get the number of input ports.""" + raise NotImplementedError + + @property + @abstractmethod + def output_count(self) -> int: + """Get the number of output ports.""" + raise NotImplementedError + + @abstractmethod + def input(self, i: int) -> InputPort: + """Get the input port at index i.""" + raise NotImplementedError + + @abstractmethod + def output(self, i: int) -> OutputPort: + """Get the output port at index i.""" + raise NotImplementedError + + @abstractmethod + def params(self) -> Dict[str, Optional[Any]]: + """Get a dictionary of all parameter values.""" + raise NotImplementedError + + @abstractmethod + def param(self, name: str) -> Optional[Any]: + """Get the value of a parameter. + Returns None if the parameter is not defined. + """ + raise NotImplementedError + + @abstractmethod + def set_param(self, name: str, value: Any) -> None: + """Set the value of a parameter. + Adds the parameter if it is not already defined. + """ + raise NotImplementedError + + @abstractmethod + def evaluate_output(self, i: int, input_values: Sequence[Number]) -> Sequence[Optional[Number]]: + """Evaluate the output at index i of this operation with the given input values. + The returned sequence contains results corresponding to each output of this operation, + where a value of None means it was not evaluated. + The value at index i is guaranteed to have been evaluated, while the others may or may not + have been evaluated depending on what is the most efficient. + For example, Butterfly().evaluate_output(1, [5, 4]) may result in either (9, 1) or (None, 1). + """ + raise NotImplementedError + + @abstractmethod + def split(self) -> Iterable["Operation"]: + """Split the operation into multiple operations. + If splitting is not possible, this may return a list containing only the operation itself. + """ + raise NotImplementedError + + @property + @abstractmethod + def neighbors(self) -> Iterable["Operation"]: + """Return all operations that are connected by signals to this operation. + If no neighbors are found, this returns an empty list. + """ + raise NotImplementedError + + @abstractmethod + def traverse(self) -> Generator["Operation", None, None]: + """Get a generator that recursively iterates through all operations that are connected by signals to this operation, + as well as the ones that they are connected to. + """ + raise NotImplementedError + + +class AbstractOperation(Operation, AbstractGraphComponent): + """Generic abstract operation class which most implementations will derive from. + TODO: More info. + """ + + _input_ports: List[InputPort] + _output_ports: List[OutputPort] + _parameters: Dict[str, Optional[Any]] + + def __init__(self, input_count: int, output_count: int, name: Name = "", input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None): + super().__init__(name) + self._input_ports = [] + self._output_ports = [] + self._parameters = {} + + # Allocate input ports. + for i in range(input_count): + self._input_ports.append(InputPort(self, i)) + + # Allocate output ports. + for i in range(output_count): + self._output_ports.append(OutputPort(self, i)) + + # Connect given input sources, if any. + if input_sources is not None: + source_count = len(input_sources) + if source_count != input_count: + raise ValueError(f"Operation expected {input_count} input sources but only got {source_count}") + for i, src in enumerate(input_sources): + if src is not None: + self._input_ports[i].connect(src.source) + + @abstractmethod + def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ + """Evaluate the operation and generate a list of output values given a + list of input values. + """ + raise NotImplementedError + + def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Addition, ConstantAddition]": + # Import here to avoid circular imports. + from b_asic.core_operations import Addition, ConstantAddition + + if isinstance(src, Number): + return ConstantAddition(src, self) + return Addition(self, src) + + def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Subtraction, ConstantSubtraction]": + # Import here to avoid circular imports. + from b_asic.core_operations import Subtraction, ConstantSubtraction + + if isinstance(src, Number): + return ConstantSubtraction(src, self) + return Subtraction(self, src) + + def __mul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]": + # Import here to avoid circular imports. + from b_asic.core_operations import Multiplication, ConstantMultiplication + + if isinstance(src, Number): + return ConstantMultiplication(src, self) + return Multiplication(self, src) + + def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Division, ConstantDivision]": + # Import here to avoid circular imports. + from b_asic.core_operations import Division, ConstantDivision + + if isinstance(src, Number): + return ConstantDivision(src, self) + return Division(self, src) + + @property + def inputs(self) -> List[InputPort]: + return self._input_ports.copy() + + @property + def outputs(self) -> List[OutputPort]: + return self._output_ports.copy() + + @property + def input_count(self) -> int: + return len(self._input_ports) + + @property + def output_count(self) -> int: + return len(self._output_ports) + + def input(self, i: int) -> InputPort: + return self._input_ports[i] + + def output(self, i: int) -> OutputPort: + return self._output_ports[i] + + @property + def params(self) -> Dict[str, Optional[Any]]: + return self._parameters.copy() + + def param(self, name: str) -> Optional[Any]: + return self._parameters.get(name) + + def set_param(self, name: str, value: Any) -> None: + self._parameters[name] = value + + def evaluate_output(self, i: int, input_values: Sequence[Number]) -> Sequence[Optional[Number]]: + result = self.evaluate(*input_values) + if isinstance(result, collections.Sequence): + if len(result) != self.output_count: + raise RuntimeError("Operation evaluated to incorrect number of outputs") + return result + if isinstance(result, Number): + if self.output_count != 1: + raise RuntimeError("Operation evaluated to incorrect number of outputs") + return [result] + raise RuntimeError("Operation evaluated to invalid type") + + def split(self) -> Iterable[Operation]: + # Import here to avoid circular imports. + from b_asic.special_operations import Input + try: + result = self.evaluate([Input()] * self.input_count) + if isinstance(result, collections.Sequence) and all(isinstance(e, Operation) for e in result): + return result + if isinstance(result, Operation): + return [result] + except TypeError: + pass + except ValueError: + pass + return [self] + + @property + def neighbors(self) -> Iterable[Operation]: + neighbors = [] + for port in self._input_ports: + for signal in port.signals: + neighbors.append(signal.source.operation) + for port in self._output_ports: + for signal in port.signals: + neighbors.append(signal.destination.operation) + return neighbors + + def traverse(self) -> Generator[Operation, None, None]: + # Breadth first search. + visited = {self} + queue = deque([self]) + while queue: + operation = queue.popleft() + yield operation + for n_operation in operation.neighbors: + if n_operation not in visited: + visited.add(n_operation) + queue.append(n_operation) + + @property + def source(self) -> OutputPort: + if self.output_count != 1: + diff = "more" if self.output_count > 1 else "less" + raise TypeError(f"{self.__class__.__name__} cannot be used as an input source because it has {diff} than 1 output") + return self.output(0) + + def copy_unconnected(self) -> GraphComponent: + new_comp: AbstractOperation = super().copy_unconnected() + for name, value in self.params.items(): + new_comp.set_param(name, deepcopy(value)) # pylint: disable=no-member + return new_comp diff --git a/b_asic/port.py b/b_asic/port.py index 4c6fb244b5882e4a6bbc9cadc9e4d38016bc5748..4f249e3cf81d19943996e2056499a323d6c10a73 100644 --- a/b_asic/port.py +++ b/b_asic/port.py @@ -4,132 +4,198 @@ TODO: More info. """ from abc import ABC, abstractmethod -from typing import NewType, Optional, List +from copy import copy +from typing import NewType, Optional, List, Iterable, TYPE_CHECKING from b_asic.signal import Signal -from b_asic.operation import Operation -PortId = NewType("PortId", int) +if TYPE_CHECKING: + from b_asic.operation import Operation class Port(ABC): - """Abstract port class. - TODO: More info. - """ - - _port_id: PortId - _operation: Operation - - def __init__(self, port_id: PortId, operation: Operation): - self._port_id = port_id - self._operation = operation - - @property - def identifier(self) -> PortId: - """Get the unique identifier.""" - return self._port_id - - @property - def operation(self) -> Operation: - """Get the connected operation.""" - return self._operation - - @property - @abstractmethod - def signals(self) -> List[Signal]: - """Get a list of all connected signals.""" - pass - - @property - @abstractmethod - def signal(self, i: int = 0) -> Signal: - """Get the connected signal at index i.""" - pass - - @abstractmethod - def signal_count(self) -> int: - """Get the number of connected signals.""" - pass - - @abstractmethod - def connect(self, signal: Signal) -> None: - """Connect a signal.""" - pass - - @abstractmethod - def disconnect(self, i: int = 0) -> None: - """Disconnect a signal.""" - pass - - - # TODO: More stuff. - - -class InputPort(Port): - """Input port. - TODO: More info. - """ - _source_signal: Optional[Signal] - - def __init__(self, port_id: PortId, operation: Operation): - super().__init__(port_id, operation) - self._source_signal = None - - @property - def signals(self) -> List[Signal]: - return [] if self._source_signal is None else [self._source_signal] - - @property - def signal(self, i: int = 0) -> Signal: - assert 0 <= i < self.signal_count() # TODO: Error message. - assert self._source_signal is not None # TODO: Error message. - return self._source_signal - - def signal_count(self) -> int: - return 0 if self._source_signal is None else 1 - - def connect(self, signal: Signal) -> None: - self._source_signal = signal - signal.destination = self - - def disconnect(self, i: int = 0) -> None: - assert 0 <= i < self.signal_count() # TODO: Error message. - self._source_signal.disconnect_source() - self._source_signal = None - - # TODO: More stuff. - - -class OutputPort(Port): - """Output port. - TODO: More info. - """ - - _destination_signals: List[Signal] - - def __init__(self, port_id: PortId, operation: Operation): - super().__init__(port_id, operation) - self._destination_signals = [] - - @property - def signals(self) -> List[Signal]: - return self._destination_signals.copy() - - @property - def signal(self, i: int = 0) -> Signal: - assert 0 <= i < self.signal_count() # TODO: Error message. - return self._destination_signals[i] - - def signal_count(self) -> int: - return len(self._destination_signals) - - def connect(self, signal: Signal) -> None: - assert signal not in self._destination_signals # TODO: Error message. - self._destination_signals.append(signal) - signal.source = self - - def disconnect(self, i: int = 0) -> None: - assert 0 <= i < self.signal_count() # TODO: Error message. - del self._destination_signals[i] - - # TODO: More stuff. + """Port Interface. + + TODO: More documentaiton? + """ + + @property + @abstractmethod + def operation(self) -> "Operation": + """Return the connected operation.""" + raise NotImplementedError + + @property + @abstractmethod + def index(self) -> int: + """Return the index of the port.""" + raise NotImplementedError + + @property + @abstractmethod + def signal_count(self) -> int: + """Return the number of connected signals.""" + raise NotImplementedError + + @property + @abstractmethod + def signals(self) -> Iterable[Signal]: + """Return all connected signals.""" + raise NotImplementedError + + @abstractmethod + def add_signal(self, signal: Signal) -> None: + """Connect this port to the entered signal. If the entered signal isn't connected to + this port then connect the entered signal to the port aswell. + """ + raise NotImplementedError + + @abstractmethod + def remove_signal(self, signal: Signal) -> None: + """Remove the signal that was entered from the Ports signals. + If the entered signal still is connected to this port then disconnect the + entered signal from the port aswell. + + Keyword arguments: + - signal: Signal to remove. + """ + raise NotImplementedError + + @abstractmethod + def clear(self) -> None: + """Removes all connected signals from the Port.""" + raise NotImplementedError + + +class AbstractPort(Port): + """Abstract port class. + + Handles functionality for port id and saves the connection to the parent operation. + """ + + _operation: "Operation" + _index: int + + def __init__(self, operation: "Operation", index: int): + self._operation = operation + self._index = index + + @property + def operation(self) -> "Operation": + return self._operation + + @property + def index(self) -> int: + return self._index + + +class SignalSourceProvider(ABC): + """Signal source provider interface. + TODO: More info. + """ + + @property + @abstractmethod + def source(self) -> "OutputPort": + """Get the main source port provided by this object.""" + raise NotImplementedError + + +class InputPort(AbstractPort): + """Input port. + TODO: More info. + """ + + _source_signal: Optional[Signal] + _value_length: Optional[int] + + def __init__(self, operation: "Operation", index: int): + super().__init__(operation, index) + self._source_signal = None + self._value_length = None + + @property + def signal_count(self) -> int: + return 0 if self._source_signal is None else 1 + + @property + def signals(self) -> Iterable[Signal]: + return [] if self._source_signal is None else [self._source_signal] + + def add_signal(self, signal: Signal) -> None: + assert self._source_signal is None, "Input port may have only one signal added." + assert signal is not self._source_signal, "Attempted to add already connected signal." + self._source_signal = signal + signal.set_destination(self) + + def remove_signal(self, signal: Signal) -> None: + assert signal is self._source_signal, "Attempted to remove already removed signal." + self._source_signal = None + signal.remove_destination() + + def clear(self) -> None: + if self._source_signal is not None: + self.remove_signal(self._source_signal) + + @property + def connected_source(self) -> Optional["OutputPort"]: + """Get the output port that is currently connected to this input port, + or None if it is unconnected. + """ + return None if self._source_signal is None else self._source_signal.source + + def connect(self, src: SignalSourceProvider) -> Signal: + """Connect the provided signal source to this input port by creating a new signal. + Returns the new signal. + """ + assert self._source_signal is None, "Attempted to connect already connected input port." + return Signal(src.source, self) # self._source_signal is set by the signal constructor. + + @property + def value_length(self) -> Optional[int]: + """Get the number of bits that this port should truncate received values to.""" + return self._value_length + + @value_length.setter + def value_length(self, bits: Optional[int]) -> None: + """Set the number of bits that this port should truncate received values to.""" + assert bits is None or (isinstance(bits, int) and bits >= 0), "Value length must be non-negative." + self._value_length = bits + + +class OutputPort(AbstractPort, SignalSourceProvider): + """Output port. + TODO: More info. + """ + + _destination_signals: List[Signal] + + def __init__(self, operation: "Operation", index: int): + super().__init__(operation, index) + self._destination_signals = [] + + @property + def signal_count(self) -> int: + return len(self._destination_signals) + + @property + def signals(self) -> Iterable[Signal]: + return self._destination_signals + + def add_signal(self, signal: Signal) -> None: + assert signal not in self._destination_signals, "Attempted to add already connected signal." + self._destination_signals.append(signal) + signal.set_source(self) + + def remove_signal(self, signal: Signal) -> None: + assert signal in self._destination_signals, "Attempted to remove already removed signal." + self._destination_signals.remove(signal) + signal.remove_source() + + def clear(self) -> None: + for signal in copy(self._destination_signals): + self.remove_signal(signal) + + @property + def source(self) -> "OutputPort": + return self \ No newline at end of file diff --git a/b_asic/precedence_chart.py b/b_asic/precedence_chart.py index 93b86164fec041c20d9b170839897ecff96ccfdf..be55a123e0ab4330057c0bb62581e45195f5e5ba 100644 --- a/b_asic/precedence_chart.py +++ b/b_asic/precedence_chart.py @@ -7,15 +7,15 @@ from b_asic.signal_flow_graph import SFG class PrecedenceChart: - """Precedence chart constructed from a signal flow graph. - TODO: More info. - """ + """Precedence chart constructed from a signal flow graph. + TODO: More info. + """ - sfg: SFG - # TODO: More members. + sfg: SFG + # TODO: More members. - def __init__(self, sfg: SFG): - self.sfg = sfg - # TODO: Implement. + def __init__(self, sfg: SFG): + self.sfg = sfg + # TODO: Implement. - # TODO: More stuff. + # TODO: More stuff. diff --git a/b_asic/schema.py b/b_asic/schema.py index 41938263d144a066822befc5ad0a3a2ab41839c4..e5068cdc080c5c5004c44c885ac48f52ba44c1f3 100644 --- a/b_asic/schema.py +++ b/b_asic/schema.py @@ -7,15 +7,15 @@ from b_asic.precedence_chart import PrecedenceChart class Schema: - """Schema constructed from a precedence chart. - TODO: More info. - """ + """Schema constructed from a precedence chart. + TODO: More info. + """ - pc: PrecedenceChart - # TODO: More members. + pc: PrecedenceChart + # TODO: More members. - def __init__(self, pc: PrecedenceChart): - self.pc = pc - # TODO: Implement. + def __init__(self, pc: PrecedenceChart): + self.pc = pc + # TODO: Implement. - # TODO: More stuff. + # TODO: More stuff. diff --git a/b_asic/signal.py b/b_asic/signal.py index 17078138e5ff75889cd6afe8759584f04749edfa..67e1d0f908ba57f5d355e77794993587343e63cf 100644 --- a/b_asic/signal.py +++ b/b_asic/signal.py @@ -1,38 +1,90 @@ """@package docstring B-ASIC Signal Module. """ -from typing import TYPE_CHECKING, Optional +from typing import Optional, TYPE_CHECKING + +from b_asic.graph_component import AbstractGraphComponent, TypeName, Name + if TYPE_CHECKING: - from b_asic import OutputPort, InputPort + from b_asic.port import InputPort, OutputPort + + +class Signal(AbstractGraphComponent): + """A connection between two ports.""" + + _source: Optional["OutputPort"] + _destination: Optional["InputPort"] + + def __init__(self, source: Optional["OutputPort"] = None, \ + destination: Optional["InputPort"] = None, name: Name = ""): + super().__init__(name) + self._source = None + self._destination = None + if source is not None: + self.set_source(source) + if destination is not None: + self.set_destination(destination) + + @property + def source(self) -> Optional["OutputPort"]: + """Return the source OutputPort of the signal.""" + return self._source -class Signal: - """A connection between two ports.""" - _source: "OutputPort" - _destination: "InputPort" + @property + def destination(self) -> Optional["InputPort"]: + """Return the destination "InputPort" of the signal.""" + return self._destination - def __init__(self, src: Optional["OutputPort"] = None, dest: Optional["InputPort"] = None): - self._source = src - self._destination = dest + def set_source(self, src: "OutputPort") -> None: + """Disconnect the previous source OutputPort of the signal and + connect to the entered source OutputPort. Also connect the entered + source port to the signal if it hasn't already been connected. - @property - def source(self) -> "InputPort": - return self._source + Keyword arguments: + - src: OutputPort to connect as source to the signal. + """ + if src is not self._source: + self.remove_source() + self._source = src + if self not in src.signals: + src.add_signal(self) - @property - def destination(self) -> "OutputPort": - return self._destination + def set_destination(self, dest: "InputPort") -> None: + """Disconnect the previous destination InputPort of the signal and + connect to the entered destination InputPort. Also connect the entered + destination port to the signal if it hasn't already been connected. - @source.setter - def source(self, src: "Outputport") -> None: - self._source = src + Keywords argments: + - dest: InputPort to connect as destination to the signal. + """ + if dest is not self._destination: + self.remove_destination() + self._destination = dest + if self not in dest.signals: + dest.add_signal(self) - @destination.setter - def destination(self, dest: "InputPort") -> None: - self._destination = dest + @property + def type_name(self) -> TypeName: + return "s" - def disconnect_source(self) -> None: - self._source = None + def remove_source(self) -> None: + """Disconnect the source OutputPort of the signal. If the source port + still is connected to this signal then also disconnect the source port.""" + src = self._source + if src is not None: + self._source = None + if self in src.signals: + src.remove_signal(self) - def disconnect_destination(self) -> None: - self._destination = None + def remove_destination(self) -> None: + """Disconnect the destination InputPort of the signal.""" + dest = self._destination + if dest is not None: + self._destination = None + if self in dest.signals: + dest.remove_signal(self) + def dangling(self) -> bool: + """Returns true if the signal is missing either a source or a destination, + else false.""" + return self._source is None or self._destination is None diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index f7d4be640dee3605855f02a159c54b574774e565..a011653f4db4d85c5a9e91e0ce72d62472d658e5 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -3,69 +3,207 @@ B-ASIC Signal Flow Graph Module. TODO: More info. """ -from typing import List, Dict, Union, Optional +from typing import NewType, List, Iterable, Sequence, Dict, Optional, DefaultDict, Set +from numbers import Number +from collections import defaultdict -from b_asic.operation import Operation -from b_asic.basic_operation import BasicOperation +from b_asic.port import SignalSourceProvider, OutputPort +from b_asic.operation import Operation, AbstractOperation from b_asic.signal import Signal -from b_asic.simulation import SimulationState, OperationState -from typing import List -from b_asic.graph_id import GraphIDGenerator, GraphID - - -class SFG(BasicOperation): - """Signal flow graph. - TODO: More info. - """ - - _graph_objects_by_id: Dict[GraphID, Union[Operation, Signal]] - _graph_id_generator: GraphIDGenerator - - def __init__(self, input_destinations: List[Signal], output_sources: List[Signal]): - super().__init__() - # TODO: Allocate input/output ports with appropriate IDs. - - self._graph_objects_by_id = dict # Map Operation ID to Operation objects - self._graph_id_generator = GraphIDGenerator() - - # TODO: Traverse the graph between the inputs/outputs and add to self._operations. - # TODO: Connect ports with signals with appropriate IDs. - - def evaluate(self, inputs: list) -> list: - return [] # TODO: Implement - - def add_operation(self, operation: Operation) -> GraphID: - """Adds the entered operation to the SFG's dictionary of graph objects and - returns a generated GraphID for it. - - Keyword arguments: - operation: Operation to add to the graph. - """ - return self._add_graph_obj(operation, operation.type_name()) - - def add_signal(self, signal: Signal) -> GraphID: - """Adds the entered signal to the SFG's dictionary of graph objects and returns - a generated GraphID for it. - - Keyword argumentst: - signal: Signal to add to the graph. - """ - return self._add_graph_obj(signal, 'sig') - - def find_by_id(self, graph_id: GraphID) -> Optional[Operation]: - """Finds a graph object based on the entered Graph ID and returns it. If no graph - object with the entered ID was found then returns None. - - Keyword arguments: - graph_id: Graph ID of the wanted object. - """ - if graph_id in self._graph_objects_by_id: - return self._graph_objects_by_id[graph_id] - - return None - - def _add_graph_obj(self, obj: Union[Operation, Signal], operation_id_type: str): - graph_id = self._graph_id_generator.get_next_id(operation_id_type) - self._graph_objects_by_id[graph_id] = obj - return graph_id - +from b_asic.graph_component import GraphComponent, Name, TypeName +from b_asic.special_operations import Input, Output + + +GraphID = NewType("GraphID", str) +GraphIDNumber = NewType("GraphIDNumber", int) + + +class GraphIDGenerator: + """A class that generates Graph IDs for objects.""" + + _next_id_number: DefaultDict[TypeName, GraphIDNumber] + + def __init__(self, id_number_offset: GraphIDNumber = 0): + self._next_id_number = defaultdict(lambda: id_number_offset) + + def next_id(self, type_name: TypeName) -> GraphID: + """Return the next graph id for a certain graph id type.""" + self._next_id_number[type_name] += 1 + return type_name + str(self._next_id_number[type_name]) + + +class SFG(AbstractOperation): + """Signal flow graph. + TODO: More info. + """ + + _components_by_id: Dict[GraphID, GraphComponent] + _components_by_name: DefaultDict[Name, List[GraphComponent]] + _graph_id_generator: GraphIDGenerator + _input_operations: List[Input] + _output_operations: List[Output] + _original_components_added: Set[GraphComponent] + _original_input_signals: Dict[Signal, int] + _original_output_signals: Dict[Signal, int] + + def __init__(self, input_signals: Sequence[Signal] = [], output_signals: Sequence[Signal] = [], \ + inputs: Sequence[Input] = [], outputs: Sequence[Output] = [], operations: Sequence[Operation] = [], \ + id_number_offset: GraphIDNumber = 0, name: Name = "", \ + input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None): + super().__init__( + input_count = len(input_signals) + len(inputs), + output_count = len(output_signals) + len(outputs), + name = name, + input_sources = input_sources) + + self._components_by_id = dict() + self._components_by_name = defaultdict(list) + self._graph_id_generator = GraphIDGenerator(id_number_offset) + self._input_operations = [] + self._output_operations = [] + self._original_components_added = set() + self._original_input_signals = {} + self._original_output_signals = {} + + # Setup input operations and signals. + for i, s in enumerate(input_signals): + self._input_operations.append(self._add_component_copy_unconnected(Input())) + self._original_input_signals[s] = i + for i, op in enumerate(inputs, len(input_signals)): + self._input_operations.append(self._add_component_copy_unconnected(op)) + for s in op.output(0).signals: + self._original_input_signals[s] = i + + # Setup output operations and signals. + for i, s in enumerate(output_signals): + self._output_operations.append(self._add_component_copy_unconnected(Output())) + self._original_output_signals[s] = i + for i, op in enumerate(outputs, len(output_signals)): + self._output_operations.append(self._add_component_copy_unconnected(op)) + for s in op.input(0).signals: + self._original_output_signals[s] = i + + # Search the graph inwards from each input signal. + for s, i in self._original_input_signals.items(): + if s.destination is None: + raise ValueError(f"Input signal #{i} is missing destination in SFG") + if s.destination.operation not in self._original_components_added: + self._add_operation_copy_recursively(s.destination.operation) + + # Search the graph inwards from each output signal. + for s, i in self._original_output_signals.items(): + if s.source is None: + raise ValueError(f"Output signal #{i} is missing source in SFG") + if s.source.operation not in self._original_components_added: + self._add_operation_copy_recursively(s.source.operation) + + # Search the graph outwards from each operation. + for op in operations: + if op not in self._original_components_added: + self._add_operation_copy_recursively(op) + + @property + def type_name(self) -> TypeName: + return "sfg" + + def evaluate(self, *args): + if len(args) != self.input_count: + raise ValueError("Wrong number of inputs supplied to SFG for evaluation") + for arg, op in zip(args, self._input_operations): + op.value = arg + + result = [] + for op in self._output_operations: + result.append(self._evaluate_source(op.input(0).signals[0].source)) + + n = len(result) + return None if n == 0 else result[0] if n == 1 else result + + def evaluate_output(self, i: int, input_values: Sequence[Number]) -> Sequence[Optional[Number]]: + assert i >= 0 and i < self.output_count, "Output index out of range" + result = [None] * self.output_count + result[i] = self._evaluate_source(self._output_operations[i].input(0).signals[0].source) + return result + + def split(self) -> Iterable[Operation]: + return filter(lambda comp: isinstance(comp, Operation), self._components_by_id.values()) + + @property + def components(self) -> Iterable[GraphComponent]: + """Get all components of this graph.""" + return self._components_by_id.values() + + def find_by_id(self, graph_id: GraphID) -> Optional[GraphComponent]: + """Find a graph object based on the entered Graph ID and return it. If no graph + object with the entered ID was found then return None. + + Keyword arguments: + graph_id: Graph ID of the wanted object. + """ + return self._components_by_id.get(graph_id, None) + + def find_by_name(self, name: Name) -> List[GraphComponent]: + """Find all graph objects that have the entered name and return them + in a list. If no graph object with the entered name was found then return an + empty list. + + Keyword arguments: + name: Name of the wanted object. + """ + return self._components_by_name.get(name, []) + + def _add_component_copy_unconnected(self, original_comp: GraphComponent) -> GraphComponent: + assert original_comp not in self._original_components_added, "Tried to add duplicate SFG component" + self._original_components_added.add(original_comp) + + new_comp = original_comp.copy_unconnected() + self._components_by_id[self._graph_id_generator.next_id(new_comp.type_name)] = new_comp + self._components_by_name[new_comp.name].append(new_comp) + return new_comp + + def _add_operation_copy_recursively(self, original_op: Operation) -> Operation: + # Add a copy of the operation without any connections. + new_op = self._add_component_copy_unconnected(original_op) + + # Connect input ports. + for original_input_port, new_input_port in zip(original_op.inputs, new_op.inputs): + if original_input_port.signal_count < 1: + raise ValueError("Unconnected input port in SFG") + for original_signal in original_input_port.signals: + if original_signal in self._original_input_signals: # Check if the signal is one of the SFG's input signals. + new_signal = self._add_component_copy_unconnected(original_signal) + new_signal.set_destination(new_input_port) + new_signal.set_source(self._input_operations[self._original_input_signals[original_signal]].output(0)) + elif original_signal not in self._original_components_added: # Only add the signal if it wasn't already added. + new_signal = self._add_component_copy_unconnected(original_signal) + new_signal.set_destination(new_input_port) + if original_signal.source is None: + raise ValueError("Dangling signal without source in SFG") + # Recursively add the connected operation. + new_connected_op = self._add_operation_copy_recursively(original_signal.source.operation) + new_signal.set_source(new_connected_op.output(original_signal.source.index)) + + # Connect output ports. + for original_output_port, new_output_port in zip(original_op.outputs, new_op.outputs): + for original_signal in original_output_port.signals: + if original_signal in self._original_output_signals: # Check if the signal is one of the SFG's output signals. + new_signal = self._add_component_copy_unconnected(original_signal) + new_signal.set_source(new_output_port) + new_signal.set_destination(self._output_operations[self._original_output_signals[original_signal]].input(0)) + elif original_signal not in self._original_components_added: # Only add the signal if it wasn't already added. + new_signal = self._add_component_copy_unconnected(original_signal) + new_signal.set_source(new_output_port) + if original_signal.destination is None: + raise ValueError("Dangling signal without destination in SFG") + # Recursively add the connected operation. + new_connected_op = self._add_operation_copy_recursively(original_signal.destination.operation) + new_signal.set_destination(new_connected_op.input(original_signal.destination.index)) + + return new_op + + def _evaluate_source(self, src: OutputPort) -> Number: + input_values = [] + for input_port in src.operation.inputs: + input_src = input_port.signals[0].source + input_values.append(self._evaluate_source(input_src)) + return src.operation.evaluate_output(src.index, input_values) \ No newline at end of file diff --git a/b_asic/simulation.py b/b_asic/simulation.py index c4f7f8f366a5298ab7104b14741d61a3cc42f7c9..a2ce11b3263d517cba79c92093e594d712c5b8f3 100644 --- a/b_asic/simulation.py +++ b/b_asic/simulation.py @@ -4,32 +4,40 @@ TODO: More info. """ from numbers import Number -from typing import List +from typing import List, Dict class OperationState: - """Simulation state of an operation. - TODO: More info. - """ + """Simulation state of an operation. + TODO: More info. + """ - output_values: List[Number] - iteration: int + output_values: List[Number] + iteration: int - def __init__(self): - self.output_values = [] - self.iteration = 0 + def __init__(self): + self.output_values = [] + self.iteration = 0 class SimulationState: - """Simulation state. - TODO: More info. - """ - - # operation_states: Dict[OperationId, OperationState] - iteration: int - - def __init__(self): - self.operation_states = {} - self.iteration = 0 - - # TODO: More stuff. + """Simulation state. + TODO: More info. + """ + + operation_states: Dict[int, OperationState] + iteration: int + + def __init__(self): + op_state = OperationState() + self.operation_states = {1: op_state} + self.iteration = 0 + + # @property + # #def iteration(self): + # return self.iteration + # @iteration.setter + # def iteration(self, new_iteration: int): + # self.iteration = new_iteration + # + # TODO: More stuff diff --git a/b_asic/special_operations.py b/b_asic/special_operations.py new file mode 100644 index 0000000000000000000000000000000000000000..465c0086d0120b10e27f769a216874b2e08dd53c --- /dev/null +++ b/b_asic/special_operations.py @@ -0,0 +1,54 @@ +"""@package docstring +B-ASIC Special Operations Module. +TODO: More info. +""" + +from numbers import Number +from typing import Optional + +from b_asic.operation import AbstractOperation +from b_asic.graph_component import Name, TypeName +from b_asic.port import SignalSourceProvider + + +class Input(AbstractOperation): + """Input operation. + TODO: More info. + """ + + def __init__(self, name: Name = ""): + super().__init__(input_count = 0, output_count = 1, name = name) + self.set_param("value", 0) + + @property + def type_name(self) -> TypeName: + return "in" + + def evaluate(self): + return self.param("value") + + @property + def value(self) -> Number: + """TODO: docstring""" + return self.param("value") + + @value.setter + def value(self, value: Number): + """TODO: docstring""" + self.set_param("value", value) + + +class Output(AbstractOperation): + """Output operation. + TODO: More info. + """ + + def __init__(self, src0: Optional[SignalSourceProvider] = None, name: Name = ""): + super().__init__(input_count = 1, output_count = 0, name = name, input_sources=[src0]) + + @property + def type_name(self) -> TypeName: + return "out" + + def evaluate(self): + return None \ No newline at end of file diff --git a/b_asic/traverse_tree.py b/b_asic/traverse_tree.py deleted file mode 100644 index dc00371eaddbbaba0592d31325dbdda9efad09f7..0000000000000000000000000000000000000000 --- a/b_asic/traverse_tree.py +++ /dev/null @@ -1,43 +0,0 @@ -"""@package docstring -B-ASIC Operation Tree Traversing Module. -""" - -from typing import List, Optional -from collections import deque - -from b_asic.operation import Operation - - -class Traverse: - """Traverse operation tree.""" - - def __init__(self, operation: Operation): - """Construct a TraverseTree.""" - self._initial_operation = operation - - def _breadth_first_search(self, start: Operation) -> List[Operation]: - """Use breadth first search to traverse the operation tree.""" - visited: List[Operation] = [start] - queue = deque([start]) - while queue: - operation = queue.popleft() - for n_operation in operation.neighbours: - if n_operation not in visited: - visited.append(n_operation) - queue.append(n_operation) - - return visited - - def traverse(self, type_: Optional[Operation] = None) -> List[Operation]: - """Traverse the the operation tree and return operation where type matches. - If the type is None then return the entire tree. - - Keyword arguments: - type_-- the operation type to search for (default None) - """ - - operations: List[Operation] = self._breadth_first_search(self._initial_operation) - if type_ is not None: - operations = [oper for oper in operations if isinstance(oper, type_)] - - return operations diff --git a/src/main.cpp b/src/main.cpp index 75a77ef58b86cd29238205a078cec780a6ba9a36..bc4e83c69e7d331bbacfa37d8b22baec35833682 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,21 +1,21 @@ -#include <pybind11/pybind11.h> - -namespace py = pybind11; - -namespace asic { - -int add(int a, int b) { - return a + b; -} - -int sub(int a, int b) { - return a - b; -} - -} // namespace asic - -PYBIND11_MODULE(_b_asic, m) { - m.doc() = "Better ASIC Toolbox Extension Module."; - m.def("add", &asic::add, "A function which adds two numbers.", py::arg("a"), py::arg("b")); - m.def("sub", &asic::sub, "A function which subtracts two numbers.", py::arg("a"), py::arg("b")); +#include <pybind11/pybind11.h> + +namespace py = pybind11; + +namespace asic { + +int add(int a, int b) { + return a + b; +} + +int sub(int a, int b) { + return a - b; +} + +} // namespace asic + +PYBIND11_MODULE(_b_asic, m) { + m.doc() = "Better ASIC Toolbox Extension Module."; + m.def("add", &asic::add, "A function which adds two numbers.", py::arg("a"), py::arg("b")); + m.def("sub", &asic::sub, "A function which subtracts two numbers.", py::arg("a"), py::arg("b")); } \ No newline at end of file diff --git a/test/conftest.py b/test/conftest.py index 986af94cc7341f48ba736e6f9d934c8eb706c079..64f39843c53a4369781a269fd7fc30ad9aa1d255 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,2 +1,4 @@ +from test.fixtures.signal import signal, signals +from test.fixtures.operation_tree import * +from test.fixtures.port import * import pytest -from test.fixtures.signal import * \ No newline at end of file diff --git a/test/fixtures/operation_tree.py b/test/fixtures/operation_tree.py new file mode 100644 index 0000000000000000000000000000000000000000..94a1e42f724fdf7f14dbd13debaccc850fbbf552 --- /dev/null +++ b/test/fixtures/operation_tree.py @@ -0,0 +1,30 @@ +from b_asic.core_operations import Addition, Constant +from b_asic.signal import Signal + +import pytest + +@pytest.fixture +def operation(): + return Constant(2) + +@pytest.fixture +def operation_tree(): + """Return a addition operation connected with 2 constants. + ---C---+ + +--A + ---C---+ + """ + return Addition(Constant(2), Constant(3)) + +@pytest.fixture +def large_operation_tree(): + """Return an addition operation connected with a large operation tree with 2 other additions and 4 constants. + ---C---+ + +--A---+ + ---C---+ | + +---A + ---C---+ | + +--A---+ + ---C---+ + """ + return Addition(Addition(Constant(2), Constant(3)), Addition(Constant(4), Constant(5))) diff --git a/test/fixtures/port.py b/test/fixtures/port.py new file mode 100644 index 0000000000000000000000000000000000000000..63632ecdb3a9d81a7f27759cd7166af3163c9e94 --- /dev/null +++ b/test/fixtures/port.py @@ -0,0 +1,10 @@ +import pytest +from b_asic.port import InputPort, OutputPort + +@pytest.fixture +def input_port(): + return InputPort(None, 0) + +@pytest.fixture +def output_port(): + return OutputPort(None, 0) diff --git a/test/fixtures/signal.py b/test/fixtures/signal.py index 9139e93a529cc7a371426b9b97b4d31ddf95d2f7..0c5692feb3203f37876e48df0ab7f2caa69c4d45 100644 --- a/test/fixtures/signal.py +++ b/test/fixtures/signal.py @@ -3,8 +3,10 @@ from b_asic import Signal @pytest.fixture def signal(): + """Return a signal with no connections.""" return Signal() @pytest.fixture def signals(): - return [Signal() for _ in range(0,3)] + """Return 3 signals with no connections.""" + return [Signal() for _ in range(0, 3)] diff --git a/test/graph_id/conftest.py b/test/graph_id/conftest.py deleted file mode 100644 index 5871ed8eef2f90304e1f64c12ba17e1915250724..0000000000000000000000000000000000000000 --- a/test/graph_id/conftest.py +++ /dev/null @@ -1 +0,0 @@ -import pytest diff --git a/test/graph_id/test_graph_id_generator.py b/test/graph_id/test_graph_id_generator.py deleted file mode 100644 index 7aeb6cad27e43233a88eb69e58bd89f78a863c5b..0000000000000000000000000000000000000000 --- a/test/graph_id/test_graph_id_generator.py +++ /dev/null @@ -1,29 +0,0 @@ -""" -B-ASIC test suite for graph id generator. -""" - -from b_asic.graph_id import GraphIDGenerator, GraphID - -import pytest - -def test_empty_string_generator(): - """Test the graph id generator for an empty string type.""" - graph_id_generator = GraphIDGenerator() - assert graph_id_generator.get_next_id("") == "1" - assert graph_id_generator.get_next_id("") == "2" - - -def test_normal_string_generator(): - """"Test the graph id generator for a normal string type.""" - graph_id_generator = GraphIDGenerator() - assert graph_id_generator.get_next_id("add") == "add1" - assert graph_id_generator.get_next_id("add") == "add2" - -def test_different_strings_generator(): - """Test the graph id generator for different strings.""" - graph_id_generator = GraphIDGenerator() - assert graph_id_generator.get_next_id("sub") == "sub1" - assert graph_id_generator.get_next_id("mul") == "mul1" - assert graph_id_generator.get_next_id("sub") == "sub2" - assert graph_id_generator.get_next_id("mul") == "mul2" - \ No newline at end of file diff --git a/test/port/test_inputport.py b/test/port/test_inputport.py deleted file mode 100644 index f0e70cb761ddcd76b2144b1f2c26e606922d213e..0000000000000000000000000000000000000000 --- a/test/port/test_inputport.py +++ /dev/null @@ -1,20 +0,0 @@ -""" -B-ASIC test suite for Inputport -""" - -from b_asic import InputPort - -import pytest - -def test_connect_multiple_signals(signals): - """ - test if only one signal can connect to an input port - """ - inp_port = InputPort(0, None) - - for s in signals: - inp_port.connect(s) - - assert inp_port.signal_count() == 1 - assert inp_port.signals[0] == signals[-1] - diff --git a/test/port/test_outputport.py b/test/port/test_outputport.py deleted file mode 100644 index 5c76bb480fa63488073f6dab9f82c3f3ce00b4f3..0000000000000000000000000000000000000000 --- a/test/port/test_outputport.py +++ /dev/null @@ -1,18 +0,0 @@ -""" -B-ASIC test suite for InputPort -TODO: More info -""" -from b_asic import OutputPort -import pytest - -def test_connect_multiple_signals(signals): - """ - test if multiple signals can connect to an output port - """ - outp_port = OutputPort(0, None) - - for s in signals: - outp_port.connect(s) - - assert outp_port.signal_count() == 3 - assert outp_port.signals == signals \ No newline at end of file diff --git a/test/signal_flow_graph/conftest.py b/test/signal_flow_graph/conftest.py deleted file mode 100644 index 5871ed8eef2f90304e1f64c12ba17e1915250724..0000000000000000000000000000000000000000 --- a/test/signal_flow_graph/conftest.py +++ /dev/null @@ -1 +0,0 @@ -import pytest diff --git a/test/signal_flow_graph/test_signal_flow_graph.py b/test/signal_flow_graph/test_signal_flow_graph.py deleted file mode 100644 index 921e8906ff277b85f7d53e68cd55be338c778419..0000000000000000000000000000000000000000 --- a/test/signal_flow_graph/test_signal_flow_graph.py +++ /dev/null @@ -1,3 +0,0 @@ -from b_asic.signal_flow_graph import SFG -from b_asic.core_operations import Addition, Constant -from b_asic.signal import Signal diff --git a/test/test_abstract_operation.py b/test/test_abstract_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..626a2dc3e5e26fb76d9266dcdd31940681df5c6e --- /dev/null +++ b/test/test_abstract_operation.py @@ -0,0 +1,77 @@ +""" +B-ASIC test suite for the AbstractOperation class. +""" + +from b_asic.core_operations import Addition, ConstantAddition, Subtraction, ConstantSubtraction, \ + Multiplication, ConstantMultiplication, Division, ConstantDivision + +import pytest + + +def test_addition_overload(): + """Tests addition overloading for both operation and number argument.""" + add1 = Addition(None, None, "add1") + add2 = Addition(None, None, "add2") + + add3 = add1 + add2 + + assert isinstance(add3, Addition) + assert add3.input(0).signals == add1.output(0).signals + assert add3.input(1).signals == add2.output(0).signals + + add4 = add3 + 5 + + assert isinstance(add4, ConstantAddition) + assert add4.input(0).signals == add3.output(0).signals + + +def test_subtraction_overload(): + """Tests subtraction overloading for both operation and number argument.""" + add1 = Addition(None, None, "add1") + add2 = Addition(None, None, "add2") + + sub1 = add1 - add2 + + assert isinstance(sub1, Subtraction) + assert sub1.input(0).signals == add1.output(0).signals + assert sub1.input(1).signals == add2.output(0).signals + + sub2 = sub1 - 5 + + assert isinstance(sub2, ConstantSubtraction) + assert sub2.input(0).signals == sub1.output(0).signals + + +def test_multiplication_overload(): + """Tests multiplication overloading for both operation and number argument.""" + add1 = Addition(None, None, "add1") + add2 = Addition(None, None, "add2") + + mul1 = add1 * add2 + + assert isinstance(mul1, Multiplication) + assert mul1.input(0).signals == add1.output(0).signals + assert mul1.input(1).signals == add2.output(0).signals + + mul2 = mul1 * 5 + + assert isinstance(mul2, ConstantMultiplication) + assert mul2.input(0).signals == mul1.output(0).signals + + +def test_division_overload(): + """Tests division overloading for both operation and number argument.""" + add1 = Addition(None, None, "add1") + add2 = Addition(None, None, "add2") + + div1 = add1 / add2 + + assert isinstance(div1, Division) + assert div1.input(0).signals == add1.output(0).signals + assert div1.input(1).signals == add2.output(0).signals + + div2 = div1 / 5 + + assert isinstance(div2, ConstantDivision) + assert div2.input(0).signals == div1.output(0).signals + diff --git a/test/test_core_operations.py b/test/test_core_operations.py new file mode 100644 index 0000000000000000000000000000000000000000..854ccf85f447e430af303dc9a45c8946ac8d7828 --- /dev/null +++ b/test/test_core_operations.py @@ -0,0 +1,314 @@ +""" +B-ASIC test suite for the core operations. +""" + +from b_asic.core_operations import Constant, Addition, Subtraction, \ + Multiplication, Division, SquareRoot, ComplexConjugate, Max, Min, \ + Absolute, ConstantMultiplication, ConstantAddition, ConstantSubtraction, \ + ConstantDivision, Butterfly + +# Constant tests. + + +def test_constant(): + constant_operation = Constant(3) + assert constant_operation.evaluate() == 3 + + +def test_constant_negative(): + constant_operation = Constant(-3) + assert constant_operation.evaluate() == -3 + + +def test_constant_complex(): + constant_operation = Constant(3+4j) + assert constant_operation.evaluate() == 3+4j + +# Addition tests. + + +def test_addition(): + test_operation = Addition() + constant_operation = Constant(3) + constant_operation_2 = Constant(5) + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == 8 + + +def test_addition_negative(): + test_operation = Addition() + constant_operation = Constant(-3) + constant_operation_2 = Constant(-5) + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == -8 + + +def test_addition_complex(): + test_operation = Addition() + constant_operation = Constant((3+5j)) + constant_operation_2 = Constant((4+6j)) + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == (7+11j) + +# Subtraction tests. + + +def test_subtraction(): + test_operation = Subtraction() + constant_operation = Constant(5) + constant_operation_2 = Constant(3) + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == 2 + + +def test_subtraction_negative(): + test_operation = Subtraction() + constant_operation = Constant(-5) + constant_operation_2 = Constant(-3) + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == -2 + + +def test_subtraction_complex(): + test_operation = Subtraction() + constant_operation = Constant((3+5j)) + constant_operation_2 = Constant((4+6j)) + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == (-1-1j) + +# Multiplication tests. + + +def test_multiplication(): + test_operation = Multiplication() + constant_operation = Constant(5) + constant_operation_2 = Constant(3) + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == 15 + + +def test_multiplication_negative(): + test_operation = Multiplication() + constant_operation = Constant(-5) + constant_operation_2 = Constant(-3) + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == 15 + + +def test_multiplication_complex(): + test_operation = Multiplication() + constant_operation = Constant((3+5j)) + constant_operation_2 = Constant((4+6j)) + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == (-18+38j) + +# Division tests. + + +def test_division(): + test_operation = Division() + constant_operation = Constant(30) + constant_operation_2 = Constant(5) + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == 6 + + +def test_division_negative(): + test_operation = Division() + constant_operation = Constant(-30) + constant_operation_2 = Constant(-5) + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == 6 + + +def test_division_complex(): + test_operation = Division() + constant_operation = Constant((60+40j)) + constant_operation_2 = Constant((10+20j)) + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == (2.8-1.6j) + +# SquareRoot tests. + + +def test_squareroot(): + test_operation = SquareRoot() + constant_operation = Constant(36) + assert test_operation.evaluate(constant_operation.evaluate()) == 6 + + +def test_squareroot_negative(): + test_operation = SquareRoot() + constant_operation = Constant(-36) + assert test_operation.evaluate(constant_operation.evaluate()) == 6j + + +def test_squareroot_complex(): + test_operation = SquareRoot() + constant_operation = Constant((48+64j)) + assert test_operation.evaluate(constant_operation.evaluate()) == (8+4j) + +# ComplexConjugate tests. + + +def test_complexconjugate(): + test_operation = ComplexConjugate() + constant_operation = Constant(3+4j) + assert test_operation.evaluate(constant_operation.evaluate()) == (3-4j) + + +def test_test_complexconjugate_negative(): + test_operation = ComplexConjugate() + constant_operation = Constant(-3-4j) + assert test_operation.evaluate(constant_operation.evaluate()) == (-3+4j) + +# Max tests. + + +def test_max(): + test_operation = Max() + constant_operation = Constant(30) + constant_operation_2 = Constant(5) + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == 30 + + +def test_max_negative(): + test_operation = Max() + constant_operation = Constant(-30) + constant_operation_2 = Constant(-5) + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == -5 + +# Min tests. + + +def test_min(): + test_operation = Min() + constant_operation = Constant(30) + constant_operation_2 = Constant(5) + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == 5 + + +def test_min_negative(): + test_operation = Min() + constant_operation = Constant(-30) + constant_operation_2 = Constant(-5) + assert test_operation.evaluate( + constant_operation.evaluate(), constant_operation_2.evaluate()) == -30 + +# Absolute tests. + + +def test_absolute(): + test_operation = Absolute() + constant_operation = Constant(30) + assert test_operation.evaluate(constant_operation.evaluate()) == 30 + + +def test_absolute_negative(): + test_operation = Absolute() + constant_operation = Constant(-5) + assert test_operation.evaluate(constant_operation.evaluate()) == 5 + + +def test_absolute_complex(): + test_operation = Absolute() + constant_operation = Constant((3+4j)) + assert test_operation.evaluate(constant_operation.evaluate()) == 5.0 + +# ConstantMultiplication tests. + + +def test_constantmultiplication(): + test_operation = ConstantMultiplication(5) + constant_operation = Constant(20) + assert test_operation.evaluate(constant_operation.evaluate()) == 100 + + +def test_constantmultiplication_negative(): + test_operation = ConstantMultiplication(5) + constant_operation = Constant(-5) + assert test_operation.evaluate(constant_operation.evaluate()) == -25 + + +def test_constantmultiplication_complex(): + test_operation = ConstantMultiplication(3+2j) + constant_operation = Constant((3+4j)) + assert test_operation.evaluate(constant_operation.evaluate()) == (1+18j) + +# ConstantAddition tests. + + +def test_constantaddition(): + test_operation = ConstantAddition(5) + constant_operation = Constant(20) + assert test_operation.evaluate(constant_operation.evaluate()) == 25 + + +def test_constantaddition_negative(): + test_operation = ConstantAddition(4) + constant_operation = Constant(-5) + assert test_operation.evaluate(constant_operation.evaluate()) == -1 + + +def test_constantaddition_complex(): + test_operation = ConstantAddition(3+2j) + constant_operation = Constant((3+4j)) + assert test_operation.evaluate(constant_operation.evaluate()) == (6+6j) + +# ConstantSubtraction tests. + + +def test_constantsubtraction(): + test_operation = ConstantSubtraction(5) + constant_operation = Constant(20) + assert test_operation.evaluate(constant_operation.evaluate()) == 15 + + +def test_constantsubtraction_negative(): + test_operation = ConstantSubtraction(4) + constant_operation = Constant(-5) + assert test_operation.evaluate(constant_operation.evaluate()) == -9 + + +def test_constantsubtraction_complex(): + test_operation = ConstantSubtraction(4+6j) + constant_operation = Constant((3+4j)) + assert test_operation.evaluate(constant_operation.evaluate()) == (-1-2j) + +# ConstantDivision tests. + + +def test_constantdivision(): + test_operation = ConstantDivision(5) + constant_operation = Constant(20) + assert test_operation.evaluate(constant_operation.evaluate()) == 4 + + +def test_constantdivision_negative(): + test_operation = ConstantDivision(4) + constant_operation = Constant(-20) + assert test_operation.evaluate(constant_operation.evaluate()) == -5 + + +def test_constantdivision_complex(): + test_operation = ConstantDivision(2+2j) + constant_operation = Constant((10+10j)) + assert test_operation.evaluate(constant_operation.evaluate()) == (5+0j) + + +def test_butterfly(): + test_operation = Butterfly() + assert list(test_operation.evaluate(2, 3)) == [5, -1] + + +def test_butterfly_negative(): + test_operation = Butterfly() + assert list(test_operation.evaluate(-2, -3)) == [-5, 1] + + +def test_buttefly_complex(): + test_operation = Butterfly() + assert list(test_operation.evaluate(2+1j, 3-2j)) == [5-1j, -1+3j] diff --git a/test/test_graph_id_generator.py b/test/test_graph_id_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..b8e0cdebb7f1cc32297bacff89314244dda7cd6f --- /dev/null +++ b/test/test_graph_id_generator.py @@ -0,0 +1,28 @@ +""" +B-ASIC test suite for graph id generator. +""" + +from b_asic.signal_flow_graph import GraphIDGenerator, GraphID +import pytest + +@pytest.fixture +def graph_id_generator(): + return GraphIDGenerator() + +class TestGetNextId: + def test_empty_string_generator(self, graph_id_generator): + """Test the graph id generator for an empty string type.""" + assert graph_id_generator.next_id("") == "1" + assert graph_id_generator.next_id("") == "2" + + def test_normal_string_generator(self, graph_id_generator): + """"Test the graph id generator for a normal string type.""" + assert graph_id_generator.next_id("add") == "add1" + assert graph_id_generator.next_id("add") == "add2" + + def test_different_strings_generator(self, graph_id_generator): + """Test the graph id generator for different strings.""" + assert graph_id_generator.next_id("sub") == "sub1" + assert graph_id_generator.next_id("mul") == "mul1" + assert graph_id_generator.next_id("sub") == "sub2" + assert graph_id_generator.next_id("mul") == "mul2" diff --git a/test/test_inputport.py b/test/test_inputport.py new file mode 100644 index 0000000000000000000000000000000000000000..b43bf8e3d11eb3286c087c6a8bbb0b46956e51fb --- /dev/null +++ b/test/test_inputport.py @@ -0,0 +1,101 @@ +""" +B-ASIC test suite for Inputport +""" + +import pytest + +from b_asic import InputPort, OutputPort +from b_asic import Signal + +@pytest.fixture +def inp_port(): + return InputPort(None, 0) + +@pytest.fixture +def out_port(): + return OutputPort(None, 0) + +@pytest.fixture +def out_port2(): + return OutputPort(None, 1) + +@pytest.fixture +def dangling_sig(): + return Signal() + +@pytest.fixture +def s_w_source(out_port): + return Signal(source=out_port) + +@pytest.fixture +def sig_with_dest(inp_port): + return Signal(destination=inp_port) + +@pytest.fixture +def connected_sig(inp_port, out_port): + return Signal(source=out_port, destination=inp_port) + +def test_connect_then_disconnect(inp_port, out_port): + """Test connect unused port to port.""" + s1 = inp_port.connect(out_port) + + assert inp_port.connected_source == out_port + assert inp_port.signals == [s1] + assert out_port.signals == [s1] + assert s1.source is out_port + assert s1.destination is inp_port + + inp_port.remove_signal(s1) + + assert inp_port.connected_source is None + assert inp_port.signals == [] + assert out_port.signals == [s1] + assert s1.source is out_port + assert s1.destination is None + +def test_connect_used_port_to_new_port(inp_port, out_port, out_port2): + """Does connecting multiple ports to an inputport throw error?""" + inp_port.connect(out_port) + with pytest.raises(Exception): + inp_port.connect(out_port2) + +def test_add_signal_then_disconnect(inp_port, s_w_source): + """Can signal be connected then disconnected properly?""" + inp_port.add_signal(s_w_source) + + assert inp_port.connected_source == s_w_source.source + assert inp_port.signals == [s_w_source] + assert s_w_source.source.signals == [s_w_source] + assert s_w_source.destination is inp_port + + inp_port.remove_signal(s_w_source) + + assert inp_port.connected_source is None + assert inp_port.signals == [] + assert s_w_source.source.signals == [s_w_source] + assert s_w_source.destination is None + +def test_set_value_length_pos_int(inp_port): + inp_port.value_length = 10 + assert inp_port.value_length == 10 + +def test_set_value_length_zero(inp_port): + inp_port.value_length = 0 + assert inp_port.value_length == 0 + +def test_set_value_length_neg_int(inp_port): + with pytest.raises(Exception): + inp_port.value_length = -10 + +def test_set_value_length_complex(inp_port): + with pytest.raises(Exception): + inp_port.value_length = (2+4j) + +def test_set_value_length_float(inp_port): + with pytest.raises(Exception): + inp_port.value_length = 3.2 + +def test_set_value_length_pos_then_none(inp_port): + inp_port.value_length = 10 + inp_port.value_length = None + assert inp_port.value_length is None diff --git a/test/test_operation.py b/test/test_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..c3a05bb5a08fa443753c2bafcf2b035274098455 --- /dev/null +++ b/test/test_operation.py @@ -0,0 +1,31 @@ +from b_asic.core_operations import Constant, Addition, ConstantAddition, Butterfly +from b_asic.signal import Signal +from b_asic.port import InputPort, OutputPort + +import pytest + + +class TestTraverse: + def test_traverse_single_tree(self, operation): + """Traverse a tree consisting of one operation.""" + constant = Constant(None) + assert list(constant.traverse()) == [constant] + + def test_traverse_tree(self, operation_tree): + """Traverse a basic addition tree with two constants.""" + assert len(list(operation_tree.traverse())) == 3 + + def test_traverse_large_tree(self, large_operation_tree): + """Traverse a larger tree.""" + assert len(list(large_operation_tree.traverse())) == 7 + + def test_traverse_type(self, large_operation_tree): + traverse = list(large_operation_tree.traverse()) + assert len( + list(filter(lambda type_: isinstance(type_, Addition), traverse))) == 3 + assert len( + list(filter(lambda type_: isinstance(type_, Constant), traverse))) == 4 + + def test_traverse_loop(self, operation_tree): + # TODO: Construct a graph that contains a loop and make sure you can traverse it properly. + assert True diff --git a/test/test_outputport.py b/test/test_outputport.py new file mode 100644 index 0000000000000000000000000000000000000000..21f08764ac4d7f9497dc02615cce343120598959 --- /dev/null +++ b/test/test_outputport.py @@ -0,0 +1,84 @@ +""" +B-ASIC test suite for OutputPort. +""" +from b_asic import OutputPort, InputPort, Signal +import pytest + +@pytest.fixture +def output_port(): + return OutputPort(None, 0) + +@pytest.fixture +def input_port(): + return InputPort(None, 0) + +@pytest.fixture +def list_of_input_ports(): + return [InputPort(None, i) for i in range(0, 3)] + +class TestConnect: + def test_multiple_ports(self, output_port, list_of_input_ports): + """Can multiple ports connect to an output port?""" + for port in list_of_input_ports: + port.connect(output_port) + + assert output_port.signal_count == len(list_of_input_ports) + + def test_same_port(self, output_port, list_of_input_ports): + """Check error handing.""" + list_of_input_ports[0].connect(output_port) + with pytest.raises(Exception): + list_of_input_ports[0].connect(output_port) + + assert output_port.signal_count == 1 + +class TestAddSignal: + def test_dangling(self, output_port): + s = Signal() + output_port.add_signal(s) + + assert output_port.signal_count == 1 + assert output_port.signals == [s] + +class TestDisconnect: + def test_others_clear(self, output_port, list_of_input_ports): + """Can multiple ports disconnect from OutputPort?""" + for port in list_of_input_ports: + port.connect(output_port) + + for port in list_of_input_ports: + port.clear() + + assert output_port.signal_count == 3 + assert all(s.dangling() for s in output_port.signals) + + def test_self_clear(self, output_port, list_of_input_ports): + """Can an OutputPort disconnect from multiple ports?""" + for port in list_of_input_ports: + port.connect(output_port) + + output_port.clear() + + assert output_port.signal_count == 0 + assert output_port.signals == [] + +class TestRemoveSignal: + def test_one_signal(self, output_port, input_port): + s = input_port.connect(output_port) + output_port.remove_signal(s) + + assert output_port.signal_count == 0 + assert output_port.signals == [] + + def test_multiple_signals(self, output_port, list_of_input_ports): + """Can multiple signals disconnect from OutputPort?""" + sigs = [] + + for port in list_of_input_ports: + sigs.append(port.connect(output_port)) + + for s in sigs: + output_port.remove_signal(s) + + assert output_port.signal_count == 0 + assert output_port.signals == [] diff --git a/test/test_sfg.py b/test/test_sfg.py new file mode 100644 index 0000000000000000000000000000000000000000..d3daf2e96db0cd87350b148e0969febb2397a0fd --- /dev/null +++ b/test/test_sfg.py @@ -0,0 +1,32 @@ +from b_asic import SFG +from b_asic.signal import Signal +from b_asic.core_operations import Addition, Constant +from b_asic.special_operations import Input, Output + +class TestConstructor: + def test_outputs_construction(self, operation_tree): + outp = Output(operation_tree) + sfg = SFG(outputs=[outp]) + + assert len(list(sfg.components)) == 7 + assert sfg.input_count == 0 + assert sfg.output_count == 1 + + def test_signals_construction(self, operation_tree): + outs = Signal(source=operation_tree.output(0)) + sfg = SFG(output_signals=[outs]) + + assert len(list(sfg.components)) == 7 + assert sfg.input_count == 0 + assert sfg.output_count == 1 + + def test_operations_construction(self, operation_tree): + sfg1 = SFG(operations=[operation_tree]) + sfg2 = SFG(operations=[operation_tree.input(1).signals[0].source.operation]) + + assert len(list(sfg1.components)) == 5 + assert len(list(sfg2.components)) == 5 + assert sfg1.input_count == 0 + assert sfg2.input_count == 0 + assert sfg1.output_count == 0 + assert sfg2.output_count == 0 diff --git a/test/test_signal.py b/test/test_signal.py new file mode 100644 index 0000000000000000000000000000000000000000..9a45086a99e55089c9e25100cdd56399ca46a5cc --- /dev/null +++ b/test/test_signal.py @@ -0,0 +1,62 @@ +""" +B-ASIC test suit for the signal module which consists of the Signal class. +""" + +from b_asic.port import InputPort, OutputPort +from b_asic.signal import Signal + +import pytest + +def test_signal_creation_and_disconnction_and_connection_changing(): + in_port = InputPort(None, 0) + out_port = OutputPort(None, 1) + s = Signal(out_port, in_port) + + assert in_port.signals == [s] + assert out_port.signals == [s] + assert s.source is out_port + assert s.destination is in_port + + in_port1 = InputPort(None, 0) + s.set_destination(in_port1) + + assert in_port.signals == [] + assert in_port1.signals == [s] + assert out_port.signals == [s] + assert s.source is out_port + assert s.destination is in_port1 + + s.remove_source() + + assert out_port.signals == [] + assert in_port1.signals == [s] + assert s.source is None + assert s.destination is in_port1 + + s.remove_destination() + + assert out_port.signals == [] + assert in_port1.signals == [] + assert s.source is None + assert s.destination is None + + out_port1 = OutputPort(None, 0) + s.set_source(out_port1) + + assert out_port1.signals == [s] + assert s.source is out_port1 + assert s.destination is None + + s.set_source(out_port) + + assert out_port.signals == [s] + assert out_port1.signals == [] + assert s.source is out_port + assert s.destination is None + + s.set_destination(in_port) + + assert out_port.signals == [s] + assert in_port.signals == [s] + assert s.source is out_port + assert s.destination is in_port diff --git a/test/traverse/test_traverse_tree.py b/test/traverse/test_traverse_tree.py deleted file mode 100644 index 57e8a67befc512146859a8999152ff5c679b4588..0000000000000000000000000000000000000000 --- a/test/traverse/test_traverse_tree.py +++ /dev/null @@ -1,78 +0,0 @@ -""" -TODO: - - Rewrite to more clean code, not so repetitive - - Update when signals and id's has been merged. -""" - -from b_asic.core_operations import Constant, Addition -from b_asic.signal import Signal -from b_asic.port import InputPort, OutputPort -from b_asic.traverse_tree import Traverse - -import pytest - -@pytest.fixture -def operation(): - return Constant(2) - -def create_operation(_type, dest_oper, index, **kwargs): - oper = _type(**kwargs) - oper_signal = Signal() - oper._output_ports[0].connect(oper_signal) - - dest_oper._input_ports[index].connect(oper_signal) - return oper - -@pytest.fixture -def operation_tree(): - add_oper = Addition() - - const_oper = create_operation(Constant, add_oper, 0, value=2) - const_oper_2 = create_operation(Constant, add_oper, 1, value=3) - - return add_oper - -@pytest.fixture -def large_operation_tree(): - add_oper = Addition() - add_oper_2 = Addition() - - const_oper = create_operation(Constant, add_oper, 0, value=2) - const_oper_2 = create_operation(Constant, add_oper, 1, value=3) - - const_oper_3 = create_operation(Constant, add_oper_2, 0, value=4) - const_oper_4 = create_operation(Constant, add_oper_2, 1, value=5) - - add_oper_3 = Addition() - add_oper_signal = Signal(add_oper, add_oper_3) - add_oper._output_ports[0].connect(add_oper_signal) - add_oper_3._input_ports[0].connect(add_oper_signal) - - add_oper_2_signal = Signal(add_oper_2, add_oper_3) - add_oper_2._output_ports[0].connect(add_oper_2_signal) - add_oper_3._input_ports[1].connect(add_oper_2_signal) - return const_oper - -def test_traverse_single_tree(operation): - traverse = Traverse(operation) - assert traverse.traverse() == [operation] - -def test_traverse_tree(operation_tree): - traverse = Traverse(operation_tree) - assert len(traverse.traverse()) == 3 - -def test_traverse_large_tree(large_operation_tree): - traverse = Traverse(large_operation_tree) - assert len(traverse.traverse()) == 7 - -def test_traverse_type(large_operation_tree): - traverse = Traverse(large_operation_tree) - assert len(traverse.traverse(Addition)) == 3 - assert len(traverse.traverse(Constant)) == 4 - -def test_traverse_loop(operation_tree): - add_oper_signal = Signal() - operation_tree._output_ports[0].connect(add_oper_signal) - operation_tree._input_ports[0].connect(add_oper_signal) - traverse = Traverse(operation_tree) - assert len(traverse.traverse()) == 2 \ No newline at end of file