From af8e63324f70c159cda113f4ebd8cc7ea087953d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ivar=20H=C3=A4rnqvist?= <ivarhar@outlook.com> Date: Thu, 9 Apr 2020 01:15:34 +0200 Subject: [PATCH] implement simulation --- b_asic/graph_component.py | 77 ++++++++++++++++++++++--- b_asic/operation.py | 93 +++++++++++++++--------------- b_asic/signal_flow_graph.py | 76 +++++++++++++----------- b_asic/simulation.py | 89 +++++++++++++++++++--------- b_asic/special_operations.py | 46 +++++++++++++-- test/conftest.py | 1 + test/fixtures/operation_tree.py | 14 ++--- test/fixtures/port.py | 4 +- test/fixtures/signal.py | 2 + test/fixtures/signal_flow_graph.py | 28 +++++++++ test/test_abstract_operation.py | 6 +- test/test_core_operations.py | 92 ++++++++++++++--------------- test/test_graph_id_generator.py | 3 +- test/test_inputport.py | 3 +- test/test_operation.py | 5 +- test/test_outputport.py | 5 +- test/test_sfg.py | 10 ++-- test/test_signal.py | 6 +- test/test_simulation.py | 37 ++++++++++++ 19 files changed, 400 insertions(+), 197 deletions(-) create mode 100644 test/fixtures/signal_flow_graph.py create mode 100644 test/test_simulation.py diff --git a/b_asic/graph_component.py b/b_asic/graph_component.py index 52eba17c..efcbbd47 100644 --- a/b_asic/graph_component.py +++ b/b_asic/graph_component.py @@ -4,11 +4,14 @@ TODO: More info. """ from abc import ABC, abstractmethod -from copy import copy -from typing import NewType +from copy import copy, deepcopy +from typing import NewType, Any, Optional, Dict + Name = NewType("Name", str) TypeName = NewType("TypeName", str) +GraphID = NewType("GraphID", str) +GraphIDNumber = NewType("GraphIDNumber", int) class GraphComponent(ABC): @@ -19,24 +22,57 @@ class GraphComponent(ABC): @property @abstractmethod def type_name(self) -> TypeName: - """Return the type name of the graph component""" + """Get the type name of this graph component""" raise NotImplementedError @property @abstractmethod def name(self) -> Name: - """Return the name of the graph component.""" + """Get the name of this graph component.""" raise NotImplementedError @name.setter @abstractmethod def name(self, name: Name) -> None: - """Set the name of the graph component to the entered name.""" + """Set the name of this graph component to the given name.""" + raise NotImplementedError + + @property + @abstractmethod + def graph_id(self) -> GraphID: + """Get the graph id of this graph component.""" + raise NotImplementedError + + @graph_id.setter + @abstractmethod + def graph_id(self, graph_id: GraphID) -> None: + """Set the graph id of this graph component to the given id. + Note that this id will be ignored if this component is used to create a new graph, + and that a new local id will be generated for it instead.""" raise NotImplementedError @abstractmethod - def copy_unconnected(self) -> "GraphComponent": - """Get a copy of this graph component, except without any connected components.""" + def params(self) -> Dict[str, Optional[Any]]: + """Get a dictionary of all parameter values.""" + raise NotImplementedError + + @abstractmethod + def param(self, name: str) -> Optional[Any]: + """Get the value of a parameter. + Returns None if the parameter is not defined. + """ + raise NotImplementedError + + @abstractmethod + def set_param(self, name: str, value: Any) -> None: + """Set the value of a parameter. + Adds the parameter if it is not already defined. + """ + raise NotImplementedError + + @abstractmethod + def copy_component(self) -> "GraphComponent": + """Get a new instance of this graph component type with the same name, id and parameters.""" raise NotImplementedError @@ -47,9 +83,13 @@ class AbstractGraphComponent(GraphComponent): """ _name: Name + _graph_id: GraphID + _parameters: Dict[str, Optional[Any]] def __init__(self, name: Name = ""): self._name = name + self._graph_id = "" + self._parameters = {} @property def name(self) -> Name: @@ -58,8 +98,29 @@ class AbstractGraphComponent(GraphComponent): @name.setter def name(self, name: Name) -> None: self._name = name + + @property + def graph_id(self) -> GraphID: + return self._graph_id + + @graph_id.setter + def graph_id(self, graph_id: GraphID) -> None: + self._graph_id = graph_id + + @property + def params(self) -> Dict[str, Optional[Any]]: + return self._parameters.copy() + + def param(self, name: str) -> Optional[Any]: + return self._parameters.get(name) + + def set_param(self, name: str, value: Any) -> None: + self._parameters[name] = value - def copy_unconnected(self) -> GraphComponent: + def copy_component(self) -> GraphComponent: new_comp = self.__class__() new_comp.name = copy(self.name) + new_comp.graph_id = copy(self.graph_id) + for name, value in self.params.items(): + new_comp.set_param(copy(name), deepcopy(value)) # pylint: disable=no-member return new_comp \ No newline at end of file diff --git a/b_asic/operation.py b/b_asic/operation.py index d644dbd3..a8dc7a96 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -6,14 +6,20 @@ TODO: More info. import collections from abc import abstractmethod -from copy import deepcopy from numbers import Number -from typing import List, Sequence, Iterable, Dict, Optional, Any, Set, Generator, Union +from typing import List, Sequence, Iterable, MutableMapping, Optional, Any, Set, Generator, Union from collections import deque from b_asic.graph_component import GraphComponent, AbstractGraphComponent, Name from b_asic.port import SignalSourceProvider, InputPort, OutputPort +def results_key(output_count: int, prefix: str, index: int): + key = prefix + if output_count != 1: + if key: + key += "." + key += str(index) + return key class Operation(GraphComponent, SignalSourceProvider): """Operation interface. @@ -87,32 +93,19 @@ class Operation(GraphComponent, SignalSourceProvider): raise NotImplementedError @abstractmethod - def params(self) -> Dict[str, Optional[Any]]: - """Get a dictionary of all parameter values.""" - raise NotImplementedError - - @abstractmethod - def param(self, name: str) -> Optional[Any]: - """Get the value of a parameter. - Returns None if the parameter is not defined. - """ - raise NotImplementedError - - @abstractmethod - def set_param(self, name: str, value: Any) -> None: - """Set the value of a parameter. - Adds the parameter if it is not already defined. + def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[str, Number]] = None, registers: Optional[MutableMapping[str, Number]] = None, prefix: str = "") -> Number: + """Evaluate the output at the given index of this operation with the given input values. + The results parameter will be used to store any intermediate results for caching. + The registers parameter will be used to get the current value of any intermediate registers that are encountered, and be updated with their new values. + The prefix parameter will be used as a prefix for the key string when storing results/registers. + See also: evaluate_outputs. """ raise NotImplementedError @abstractmethod - def evaluate_output(self, i: int, input_values: Sequence[Number]) -> Sequence[Optional[Number]]: - """Evaluate the output at index i of this operation with the given input values. - The returned sequence contains results corresponding to each output of this operation, - where a value of None means it was not evaluated. - The value at index i is guaranteed to have been evaluated, while the others may or may not - have been evaluated depending on what is the most efficient. - For example, Butterfly().evaluate_output(1, [5, 4]) may result in either (9, 1) or (None, 1). + def evaluate_outputs(self, input_values: Sequence[Number], results: MutableMapping[str, Number], registers: MutableMapping[str, Number], prefix: str = "") -> Sequence[Number]: + """Evaluate all outputs of this operation given the input values. + See evaluate_output for more information. """ raise NotImplementedError @@ -146,13 +139,11 @@ class AbstractOperation(Operation, AbstractGraphComponent): _input_ports: List[InputPort] _output_ports: List[OutputPort] - _parameters: Dict[str, Optional[Any]] def __init__(self, input_count: int, output_count: int, name: Name = "", input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None): super().__init__(name) self._input_ports = [] self._output_ports = [] - self._parameters = {} # Allocate input ports. for i in range(input_count): @@ -232,27 +223,39 @@ class AbstractOperation(Operation, AbstractGraphComponent): def output(self, i: int) -> OutputPort: return self._output_ports[i] - @property - def params(self) -> Dict[str, Optional[Any]]: - return self._parameters.copy() - - def param(self, name: str) -> Optional[Any]: - return self._parameters.get(name) + def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[str, Number]] = None, registers: Optional[MutableMapping[str, Number]] = None, prefix: str = "") -> Number: + if index < 0 or index >= self.output_count: + raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {index})") + if results is None: + results = {} + if registers is None: + registers = {} + + key = results_key(self.output_count, prefix, index) + if key in results: + return results[key] - def set_param(self, name: str, value: Any) -> None: - self._parameters[name] = value - - def evaluate_output(self, i: int, input_values: Sequence[Number]) -> Sequence[Optional[Number]]: result = self.evaluate(*input_values) if isinstance(result, collections.Sequence): if len(result) != self.output_count: - raise RuntimeError("Operation evaluated to incorrect number of outputs") - return result - if isinstance(result, Number): + raise RuntimeError(f"Operation evaluated to incorrect number of outputs (expected {self.output_count}, got {len(result)})") + elif isinstance(result, Number): if self.output_count != 1: - raise RuntimeError("Operation evaluated to incorrect number of outputs") - return [result] - raise RuntimeError("Operation evaluated to invalid type") + raise RuntimeError(f"Operation evaluated to incorrect number of outputs (expected {self.output_count}, got 1)") + result = (result,) + else: + raise RuntimeError(f"Operation evaluated to invalid type (expected Sequence/Number, got {result.__class__.__name__})") + + if self.output_count == 1: + results[key] = result[index] + else: + for i, value in enumerate(result): + results[results_key(self.output_count, prefix, i)] = value + return result[index] + + + def evaluate_outputs(self, input_values: Sequence[Number], results: MutableMapping[str, Number], registers: MutableMapping[str, Number], prefix: str = "") -> Sequence[Number]: + return [self.evaluate_output(i, input_values, results, registers, prefix) for i in range(self.output_count)] def split(self) -> Iterable[Operation]: # Import here to avoid circular imports. @@ -298,9 +301,3 @@ class AbstractOperation(Operation, AbstractGraphComponent): diff = "more" if self.output_count > 1 else "less" raise TypeError(f"{self.__class__.__name__} cannot be used as an input source because it has {diff} than 1 output") return self.output(0) - - def copy_unconnected(self) -> GraphComponent: - new_comp: AbstractOperation = super().copy_unconnected() - for name, value in self.params.items(): - new_comp.set_param(name, deepcopy(value)) # pylint: disable=no-member - return new_comp diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index a011653f..b796f160 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -3,21 +3,17 @@ B-ASIC Signal Flow Graph Module. TODO: More info. """ -from typing import NewType, List, Iterable, Sequence, Dict, Optional, DefaultDict, Set +from typing import List, Iterable, Sequence, Dict, MutableMapping, Optional, DefaultDict, Set from numbers import Number from collections import defaultdict from b_asic.port import SignalSourceProvider, OutputPort -from b_asic.operation import Operation, AbstractOperation +from b_asic.operation import Operation, AbstractOperation, results_key from b_asic.signal import Signal -from b_asic.graph_component import GraphComponent, Name, TypeName +from b_asic.graph_component import GraphID, GraphIDNumber, GraphComponent, Name, TypeName from b_asic.special_operations import Input, Output -GraphID = NewType("GraphID", str) -GraphIDNumber = NewType("GraphIDNumber", int) - - class GraphIDGenerator: """A class that generates Graph IDs for objects.""" @@ -47,7 +43,7 @@ class SFG(AbstractOperation): _original_output_signals: Dict[Signal, int] def __init__(self, input_signals: Sequence[Signal] = [], output_signals: Sequence[Signal] = [], \ - inputs: Sequence[Input] = [], outputs: Sequence[Output] = [], operations: Sequence[Operation] = [], \ + inputs: Sequence[Input] = [], outputs: Sequence[Output] = [], \ id_number_offset: GraphIDNumber = 0, name: Name = "", \ input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None): super().__init__( @@ -97,33 +93,35 @@ class SFG(AbstractOperation): if s.source.operation not in self._original_components_added: self._add_operation_copy_recursively(s.source.operation) - # Search the graph outwards from each operation. - for op in operations: - if op not in self._original_components_added: - self._add_operation_copy_recursively(op) - @property def type_name(self) -> TypeName: return "sfg" def evaluate(self, *args): - if len(args) != self.input_count: - raise ValueError("Wrong number of inputs supplied to SFG for evaluation") - for arg, op in zip(args, self._input_operations): - op.value = arg - - result = [] - for op in self._output_operations: - result.append(self._evaluate_source(op.input(0).signals[0].source)) - + result = self.evaluate_outputs(args, {}, {}, "") n = len(result) return None if n == 0 else result[0] if n == 1 else result - def evaluate_output(self, i: int, input_values: Sequence[Number]) -> Sequence[Optional[Number]]: - assert i >= 0 and i < self.output_count, "Output index out of range" - result = [None] * self.output_count - result[i] = self._evaluate_source(self._output_operations[i].input(0).signals[0].source) - return result + def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[str, Number]] = None, registers: Optional[MutableMapping[str, Number]] = None, prefix: str = "") -> Number: + if index < 0 or index >= self.output_count: + raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {index})") + if len(input_values) != self.input_count: + raise ValueError(f"Wrong number of inputs supplied to SFG for evaluation (expected {self.input_count}, got {len(input_values)})") + if results is None: + results = {} + if registers is None: + registers = {} + + # Set the values of our input operations to the given input values. + for op, arg in zip(self._input_operations, input_values): + op.value = arg + + key = results_key(self.output_count, prefix, index) + if key in results: + return results[key] + value = self._evaluate_source(self._output_operations[index].input(0).signals[0].source, results, registers, prefix) + results[key] = value + return value def split(self) -> Iterable[Operation]: return filter(lambda comp: isinstance(comp, Operation), self._components_by_id.values()) @@ -156,8 +154,10 @@ class SFG(AbstractOperation): assert original_comp not in self._original_components_added, "Tried to add duplicate SFG component" self._original_components_added.add(original_comp) - new_comp = original_comp.copy_unconnected() - self._components_by_id[self._graph_id_generator.next_id(new_comp.type_name)] = new_comp + new_comp = original_comp.copy_component() + new_id = self._graph_id_generator.next_id(new_comp.type_name) + new_comp.graph_id = new_id + self._components_by_id[new_id] = new_comp self._components_by_name[new_comp.name].append(new_comp) return new_comp @@ -201,9 +201,15 @@ class SFG(AbstractOperation): return new_op - def _evaluate_source(self, src: OutputPort) -> Number: - input_values = [] - for input_port in src.operation.inputs: - input_src = input_port.signals[0].source - input_values.append(self._evaluate_source(input_src)) - return src.operation.evaluate_output(src.index, input_values) \ No newline at end of file + def _evaluate_source(self, src: OutputPort, results: MutableMapping[str, Number], registers: MutableMapping[str, Number], prefix: str) -> Number: + op_prefix = prefix + if op_prefix: + op_prefix += "." + op_prefix += src.operation.graph_id + key = results_key(src.operation.output_count, op_prefix, src.index) + if key in results: + return results[key] + input_values = [self._evaluate_source(input_port.signals[0].source, results, registers, prefix) for input_port in src.operation.inputs] + value = src.operation.evaluate_output(src.index, input_values, results, registers, op_prefix) + results[key] = value + return value \ No newline at end of file diff --git a/b_asic/simulation.py b/b_asic/simulation.py index a2ce11b3..bd7d53a6 100644 --- a/b_asic/simulation.py +++ b/b_asic/simulation.py @@ -3,41 +3,74 @@ B-ASIC Simulation Module. TODO: More info. """ +from collections import defaultdict from numbers import Number -from typing import List, Dict +from typing import List, Dict, DefaultDict, Callable, Sequence, Mapping +from b_asic.signal_flow_graph import SFG -class OperationState: - """Simulation state of an operation. + +class Simulation: + """Simulation. TODO: More info. """ - output_values: List[Number] - iteration: int + _sfg: SFG + _results: DefaultDict[int, Dict[str, Number]] + _registers: Dict[str, Number] + _iteration: int + _input_functions: Sequence[Callable[[int], Number]] + _current_input_values: Sequence[Number] + _latest_output_values: Sequence[Number] - def __init__(self): - self.output_values = [] - self.iteration = 0 + def __init__(self, sfg: SFG, input_functions: Sequence[Callable[[int], Number]]): + if len(input_functions) != sfg.input_count: + raise ValueError(f"Wrong number of inputs supplied to simulation (expected {len(sfg.input_count)}, got {len(input_functions)})") + self._sfg = sfg + self._results = defaultdict(dict) + self._registers = {} + self._iteration = 0 + self._input_functions = list(input_functions) + self._current_input_values = [0 for _ in range(self._sfg.input_count)] + self._latest_output_values = [0 for _ in range(self._sfg.output_count)] + def input_functions(self, input_functions: Sequence[Callable[[int], Number]]) -> None: + """Set the input functions used to get values for the inputs to the internal SFG.""" + if len(input_functions) != len(self._input_functions): + raise ValueError(f"Wrong number of inputs supplied to simulation (expected {len(self._input_functions)}, got {len(input_functions)})") + self._input_functions = input_functions -class SimulationState: - """Simulation state. - TODO: More info. - """ + def input_function(self, index: int, input_function: Callable[[int], Number]) -> None: + """Set the input function used to get values for the specific input at the given index to the internal SFG.""" + if index < 0 or index >= len(self._input_functions): + raise IndexError(f"Input index out of range (expected 0-{self._input_functions - 1}, got {index})") + self._input_functions[index] = input_function + + def run(self) -> Sequence[Number]: + """Run one iteration of the simulation and return the resulting output values.""" + return self.run_for(1) + + def run_until(self, iteration: int) -> Sequence[Number]: + """Run the simulation until its iteration is greater than or equal to the given iteration + and return the resulting output values.""" + while self._iteration < iteration: + self._current_input_values = [self._input_functions[i](self._iteration) for i in range(self._sfg.input_count)] + self._latest_output_values = self._sfg.evaluate_outputs(self._current_input_values, self._results[self._iteration], self._registers) + self._iteration += 1 + return self._latest_output_values + + def run_for(self, iterations: int) -> Sequence[Number]: + """Run a given number of iterations of the simulation and return the resulting output values.""" + return self.run_until(self._iteration + iterations) + + @property + def iteration(self) -> int: + """Get the current iteration number of the simulation.""" + return self._iteration - operation_states: Dict[int, OperationState] - iteration: int - - def __init__(self): - op_state = OperationState() - self.operation_states = {1: op_state} - self.iteration = 0 - - # @property - # #def iteration(self): - # return self.iteration - # @iteration.setter - # def iteration(self, new_iteration: int): - # self.iteration = new_iteration - # - # TODO: More stuff + @property + def results(self) -> Mapping[int, Mapping[str, Number]]: + """Get a mapping of all results, including intermediate values, calculated for each iteration up until now. + The outer mapping maps from iteration number to value mapping. The value mapping maps output port identifiers to values. + Example: {0: {"c1": 3, "c2": 4, "bfly1.0": 7, "bfly1.1": -1, "0": 7}}""" + return self._results diff --git a/b_asic/special_operations.py b/b_asic/special_operations.py index 465c0086..36ebf2d8 100644 --- a/b_asic/special_operations.py +++ b/b_asic/special_operations.py @@ -4,7 +4,7 @@ TODO: More info. """ from numbers import Number -from typing import Optional +from typing import Optional, Sequence, MutableMapping from b_asic.operation import AbstractOperation from b_asic.graph_component import Name, TypeName @@ -44,11 +44,49 @@ class Output(AbstractOperation): """ def __init__(self, src0: Optional[SignalSourceProvider] = None, name: Name = ""): - super().__init__(input_count = 1, output_count = 0, name = name, input_sources=[src0]) + super().__init__(input_count = 1, output_count = 0, name = name, input_sources = [src0]) @property def type_name(self) -> TypeName: return "out" - def evaluate(self): - return None \ No newline at end of file + def evaluate(self, _): + return None + + +class Register(AbstractOperation): + """Delay operation. + TODO: More info. + """ + + def __init__(self, initial_value: Number = 0, src0: Optional[SignalSourceProvider] = None, name: Name = ""): + super().__init__(input_count = 1, output_count = 0, name = name, input_sources = [src0]) + self.set_param("initial_value", initial_value) + + @property + def type_name(self) -> TypeName: + return "reg" + + def evaluate(self, a): + return self.param("initial_value") + + def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[str, Number]] = None, registers: Optional[MutableMapping[str, Number]] = None, prefix: str = ""): + if index != 0: + raise IndexError(f"Output index out of range (expected 0-0, got {index})") + if len(input_values) != 1: + raise ValueError(f"Wrong number of inputs supplied to SFG for evaluation (expected 1, got {len(input_values)})") + if results is None: + results = {} + if registers is None: + registers = {} + + if prefix in results: + return results[prefix] + + if prefix in registers: + return registers[prefix] + + value = registers.get(prefix, self.param("initial_value")) + registers[prefix] = input_values[0] + results[prefix] = value + return value \ No newline at end of file diff --git a/test/conftest.py b/test/conftest.py index 64f39843..48b49489 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,4 +1,5 @@ from test.fixtures.signal import signal, signals from test.fixtures.operation_tree import * from test.fixtures.port import * +from test.fixtures.signal_flow_graph import * import pytest diff --git a/test/fixtures/operation_tree.py b/test/fixtures/operation_tree.py index e5274f02..32758675 100644 --- a/test/fixtures/operation_tree.py +++ b/test/fixtures/operation_tree.py @@ -1,15 +1,15 @@ -from b_asic.core_operations import Addition, Constant -from b_asic.signal import Signal - import pytest +from b_asic import Addition, Constant, Signal + + @pytest.fixture def operation(): return Constant(2) @pytest.fixture def operation_tree(): - """Return a addition operation connected with 2 constants. + """Valid addition operation connected with 2 constants. 2>--+ | 2+3=5> @@ -20,7 +20,7 @@ def operation_tree(): @pytest.fixture def large_operation_tree(): - """Return an addition operation connected with a large operation tree with 2 other additions and 4 constants. + """Valid addition operation connected with a large operation tree with 2 other additions and 4 constants. 2>--+ | 2+3=5>--+ @@ -37,7 +37,7 @@ def large_operation_tree(): @pytest.fixture def operation_graph_with_cycle(): - """Return an invalid addition operation connected with an operation graph containing a cycle. + """Invalid addition operation connected with an operation graph containing a cycle. +---+ | | ?+7=?>-------+ @@ -49,4 +49,4 @@ def operation_graph_with_cycle(): c1 = Constant(7) add1 = Addition(None, c1) add1.input(0).connect(add1) - return Addition(add1, c1) \ No newline at end of file + return Addition(add1, c1) diff --git a/test/fixtures/port.py b/test/fixtures/port.py index 63632ecd..fa528b8d 100644 --- a/test/fixtures/port.py +++ b/test/fixtures/port.py @@ -1,5 +1,7 @@ import pytest -from b_asic.port import InputPort, OutputPort + +from b_asic import InputPort, OutputPort + @pytest.fixture def input_port(): diff --git a/test/fixtures/signal.py b/test/fixtures/signal.py index 0c5692fe..4dba99e2 100644 --- a/test/fixtures/signal.py +++ b/test/fixtures/signal.py @@ -1,6 +1,8 @@ import pytest + from b_asic import Signal + @pytest.fixture def signal(): """Return a signal with no connections.""" diff --git a/test/fixtures/signal_flow_graph.py b/test/fixtures/signal_flow_graph.py new file mode 100644 index 00000000..c8e1dc9a --- /dev/null +++ b/test/fixtures/signal_flow_graph.py @@ -0,0 +1,28 @@ +import pytest + +from b_asic import SFG, Input, Output, Addition + + +@pytest.fixture +def sfg_two_inputs_two_outputs(): + """Valid SFG containing two inputs and two outputs. + . . + in1>------+ +---------------out1> + . | | . + . in1+in2=add1>--+ . + . | | . + in2>------+ | . + | . add1+in2=add2>---out2> + | . | . + +------------------+ . + . . + out1 = in1 + in2 + out2 = in1 + 2 * in2 + """ + in1=Input() + in2=Input() + add1=Addition(in1, in2) + add2=Addition(add1, in2) + out1=Output(add1) + out2=Output(add2) + return SFG(inputs = [in1, in2], outputs = [out1, out2]) \ No newline at end of file diff --git a/test/test_abstract_operation.py b/test/test_abstract_operation.py index 626a2dc3..ab53dabf 100644 --- a/test/test_abstract_operation.py +++ b/test/test_abstract_operation.py @@ -2,11 +2,11 @@ B-ASIC test suite for the AbstractOperation class. """ -from b_asic.core_operations import Addition, ConstantAddition, Subtraction, ConstantSubtraction, \ - Multiplication, ConstantMultiplication, Division, ConstantDivision - import pytest +from b_asic import Addition, ConstantAddition, Subtraction, ConstantSubtraction, \ + Multiplication, ConstantMultiplication, Division, ConstantDivision + def test_addition_overload(): """Tests addition overloading for both operation and number argument.""" diff --git a/test/test_core_operations.py b/test/test_core_operations.py index 93f388fb..2e7506c7 100644 --- a/test/test_core_operations.py +++ b/test/test_core_operations.py @@ -2,7 +2,7 @@ B-ASIC test suite for the core operations. """ -from b_asic.core_operations import Constant, Addition, Subtraction, \ +from b_asic import Constant, Addition, Subtraction, \ Multiplication, Division, SquareRoot, ComplexConjugate, Max, Min, \ Absolute, ConstantMultiplication, ConstantAddition, ConstantSubtraction, \ ConstantDivision, Butterfly @@ -11,199 +11,199 @@ from b_asic.core_operations import Constant, Addition, Subtraction, \ class TestConstant: def test_constant_positive(self): test_operation = Constant(3) - assert test_operation.evaluate_output(0, [])[0] == 3 + assert test_operation.evaluate_output(0, []) == 3 def test_constant_negative(self): test_operation = Constant(-3) - assert test_operation.evaluate_output(0, [])[0] == -3 + assert test_operation.evaluate_output(0, []) == -3 def test_constant_complex(self): test_operation = Constant(3+4j) - assert test_operation.evaluate_output(0, [])[0] == 3+4j + assert test_operation.evaluate_output(0, []) == 3+4j class TestAddition: def test_addition_positive(self): test_operation = Addition() - assert test_operation.evaluate_output(0, [3, 5])[0] == 8 + assert test_operation.evaluate_output(0, [3, 5]) == 8 def test_addition_negative(self): test_operation = Addition() - assert test_operation.evaluate_output(0, [-3, -5])[0] == -8 + assert test_operation.evaluate_output(0, [-3, -5]) == -8 def test_addition_complex(self): test_operation = Addition() - assert test_operation.evaluate_output(0, [3+5j, 4+6j])[0] == 7+11j + assert test_operation.evaluate_output(0, [3+5j, 4+6j]) == 7+11j class TestSubtraction: def test_subtraction_positive(self): test_operation = Subtraction() - assert test_operation.evaluate_output(0, [5, 3])[0] == 2 + assert test_operation.evaluate_output(0, [5, 3]) == 2 def test_subtraction_negative(self): test_operation = Subtraction() - assert test_operation.evaluate_output(0, [-5, -3])[0] == -2 + assert test_operation.evaluate_output(0, [-5, -3]) == -2 def test_subtraction_complex(self): test_operation = Subtraction() - assert test_operation.evaluate_output(0, [3+5j, 4+6j])[0] == -1-1j + assert test_operation.evaluate_output(0, [3+5j, 4+6j]) == -1-1j class TestMultiplication: def test_multiplication_positive(self): test_operation = Multiplication() - assert test_operation.evaluate_output(0, [5, 3])[0] == 15 + assert test_operation.evaluate_output(0, [5, 3]) == 15 def test_multiplication_negative(self): test_operation = Multiplication() - assert test_operation.evaluate_output(0, [-5, -3])[0] == 15 + assert test_operation.evaluate_output(0, [-5, -3]) == 15 def test_multiplication_complex(self): test_operation = Multiplication() - assert test_operation.evaluate_output(0, [3+5j, 4+6j])[0] == -18+38j + assert test_operation.evaluate_output(0, [3+5j, 4+6j]) == -18+38j class TestDivision: def test_division_positive(self): test_operation = Division() - assert test_operation.evaluate_output(0, [30, 5])[0] == 6 + assert test_operation.evaluate_output(0, [30, 5]) == 6 def test_division_negative(self): test_operation = Division() - assert test_operation.evaluate_output(0, [-30, -5])[0] == 6 + assert test_operation.evaluate_output(0, [-30, -5]) == 6 def test_division_complex(self): test_operation = Division() - assert test_operation.evaluate_output(0, [60+40j, 10+20j])[0] == 2.8-1.6j + assert test_operation.evaluate_output(0, [60+40j, 10+20j]) == 2.8-1.6j class TestSquareRoot: def test_squareroot_positive(self): test_operation = SquareRoot() - assert test_operation.evaluate_output(0, [36])[0] == 6 + assert test_operation.evaluate_output(0, [36]) == 6 def test_squareroot_negative(self): test_operation = SquareRoot() - assert test_operation.evaluate_output(0, [-36])[0] == 6j + assert test_operation.evaluate_output(0, [-36]) == 6j def test_squareroot_complex(self): test_operation = SquareRoot() - assert test_operation.evaluate_output(0, [48+64j])[0] == 8+4j + assert test_operation.evaluate_output(0, [48+64j]) == 8+4j class TestComplexConjugate: def test_complexconjugate_positive(self): test_operation = ComplexConjugate() - assert test_operation.evaluate_output(0, [3+4j])[0] == 3-4j + assert test_operation.evaluate_output(0, [3+4j]) == 3-4j def test_test_complexconjugate_negative(self): test_operation = ComplexConjugate() - assert test_operation.evaluate_output(0, [-3-4j])[0] == -3+4j + assert test_operation.evaluate_output(0, [-3-4j]) == -3+4j class TestMax: def test_max_positive(self): test_operation = Max() - assert test_operation.evaluate_output(0, [30, 5])[0] == 30 + assert test_operation.evaluate_output(0, [30, 5]) == 30 def test_max_negative(self): test_operation = Max() - assert test_operation.evaluate_output(0, [-30, -5])[0] == -5 + assert test_operation.evaluate_output(0, [-30, -5]) == -5 class TestMin: def test_min_positive(self): test_operation = Min() - assert test_operation.evaluate_output(0, [30, 5])[0] == 5 + assert test_operation.evaluate_output(0, [30, 5]) == 5 def test_min_negative(self): test_operation = Min() - assert test_operation.evaluate_output(0, [-30, -5])[0] == -30 + assert test_operation.evaluate_output(0, [-30, -5]) == -30 class TestAbsolute: def test_absolute_positive(self): test_operation = Absolute() - assert test_operation.evaluate_output(0, [30])[0] == 30 + assert test_operation.evaluate_output(0, [30]) == 30 def test_absolute_negative(self): test_operation = Absolute() - assert test_operation.evaluate_output(0, [-5])[0] == 5 + assert test_operation.evaluate_output(0, [-5]) == 5 def test_absolute_complex(self): test_operation = Absolute() - assert test_operation.evaluate_output(0, [3+4j])[0] == 5.0 + assert test_operation.evaluate_output(0, [3+4j]) == 5.0 class TestConstantMultiplication: def test_constantmultiplication_positive(self): test_operation = ConstantMultiplication(5) - assert test_operation.evaluate_output(0, [20])[0] == 100 + assert test_operation.evaluate_output(0, [20]) == 100 def test_constantmultiplication_negative(self): test_operation = ConstantMultiplication(5) - assert test_operation.evaluate_output(0, [-5])[0] == -25 + assert test_operation.evaluate_output(0, [-5]) == -25 def test_constantmultiplication_complex(self): test_operation = ConstantMultiplication(3+2j) - assert test_operation.evaluate_output(0, [3+4j])[0] == 1+18j + assert test_operation.evaluate_output(0, [3+4j]) == 1+18j class TestConstantAddition: def test_constantaddition_positive(self): test_operation = ConstantAddition(5) - assert test_operation.evaluate_output(0, [20])[0] == 25 + assert test_operation.evaluate_output(0, [20]) == 25 def test_constantaddition_negative(self): test_operation = ConstantAddition(4) - assert test_operation.evaluate_output(0, [-5])[0] == -1 + assert test_operation.evaluate_output(0, [-5]) == -1 def test_constantaddition_complex(self): test_operation = ConstantAddition(3+2j) - assert test_operation.evaluate_output(0, [3+2j])[0] == 6+4j + assert test_operation.evaluate_output(0, [3+2j]) == 6+4j class TestConstantSubtraction: def test_constantsubtraction_positive(self): test_operation = ConstantSubtraction(5) - assert test_operation.evaluate_output(0, [20])[0] == 15 + assert test_operation.evaluate_output(0, [20]) == 15 def test_constantsubtraction_negative(self): test_operation = ConstantSubtraction(4) - assert test_operation.evaluate_output(0, [-5])[0] == -9 + assert test_operation.evaluate_output(0, [-5]) == -9 def test_constantsubtraction_complex(self): test_operation = ConstantSubtraction(4+6j) - assert test_operation.evaluate_output(0, [3+4j])[0] == -1-2j + assert test_operation.evaluate_output(0, [3+4j]) == -1-2j class TestConstantDivision: def test_constantdivision_positive(self): test_operation = ConstantDivision(5) - assert test_operation.evaluate_output(0, [20])[0] == 4 + assert test_operation.evaluate_output(0, [20]) == 4 def test_constantdivision_negative(self): test_operation = ConstantDivision(4) - assert test_operation.evaluate_output(0, [-20])[0] == -5 + assert test_operation.evaluate_output(0, [-20]) == -5 def test_constantdivision_complex(self): test_operation = ConstantDivision(2+2j) - assert test_operation.evaluate_output(0, [10+10j])[0] == 5 + assert test_operation.evaluate_output(0, [10+10j]) == 5 class TestButterfly: def test_butterfly_positive(self): test_operation = Butterfly() - assert test_operation.evaluate_output(0, [2, 3])[0] == 5 - assert test_operation.evaluate_output(1, [2, 3])[1] == -1 + assert test_operation.evaluate_output(0, [2, 3]) == 5 + assert test_operation.evaluate_output(1, [2, 3]) == -1 def test_butterfly_negative(self): test_operation = Butterfly() - assert test_operation.evaluate_output(0, [-2, -3])[0] == -5 - assert test_operation.evaluate_output(1, [-2, -3])[1] == 1 + assert test_operation.evaluate_output(0, [-2, -3]) == -5 + assert test_operation.evaluate_output(1, [-2, -3]) == 1 def test_buttefly_complex(self): test_operation = Butterfly() - assert test_operation.evaluate_output(0, [2+1j, 3-2j])[0] == 5-1j - assert test_operation.evaluate_output(1, [2+1j, 3-2j])[1] == -1+3j + assert test_operation.evaluate_output(0, [2+1j, 3-2j]) == 5-1j + assert test_operation.evaluate_output(1, [2+1j, 3-2j]) == -1+3j diff --git a/test/test_graph_id_generator.py b/test/test_graph_id_generator.py index b8e0cdeb..72c923b6 100644 --- a/test/test_graph_id_generator.py +++ b/test/test_graph_id_generator.py @@ -2,9 +2,10 @@ B-ASIC test suite for graph id generator. """ -from b_asic.signal_flow_graph import GraphIDGenerator, GraphID import pytest +from b_asic import GraphIDGenerator, GraphID + @pytest.fixture def graph_id_generator(): return GraphIDGenerator() diff --git a/test/test_inputport.py b/test/test_inputport.py index b43bf8e3..055eab28 100644 --- a/test/test_inputport.py +++ b/test/test_inputport.py @@ -4,8 +4,7 @@ B-ASIC test suite for Inputport import pytest -from b_asic import InputPort, OutputPort -from b_asic import Signal +from b_asic import InputPort, OutputPort, Signal @pytest.fixture def inp_port(): diff --git a/test/test_operation.py b/test/test_operation.py index c3a05bb5..bb09be0b 100644 --- a/test/test_operation.py +++ b/test/test_operation.py @@ -1,9 +1,6 @@ -from b_asic.core_operations import Constant, Addition, ConstantAddition, Butterfly -from b_asic.signal import Signal -from b_asic.port import InputPort, OutputPort - import pytest +from b_asic import Constant, Addition, ConstantAddition, Butterfly, Signal, InputPort, OutputPort class TestTraverse: def test_traverse_single_tree(self, operation): diff --git a/test/test_outputport.py b/test/test_outputport.py index 21f08764..189c8922 100644 --- a/test/test_outputport.py +++ b/test/test_outputport.py @@ -1,9 +1,11 @@ """ B-ASIC test suite for OutputPort. """ -from b_asic import OutputPort, InputPort, Signal import pytest +from b_asic import OutputPort, InputPort, Signal + + @pytest.fixture def output_port(): return OutputPort(None, 0) @@ -16,6 +18,7 @@ def input_port(): def list_of_input_ports(): return [InputPort(None, i) for i in range(0, 3)] + class TestConnect: def test_multiple_ports(self, output_port, list_of_input_ports): """Can multiple ports connect to an output port?""" diff --git a/test/test_sfg.py b/test/test_sfg.py index 91f50ea1..76c5178e 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -1,9 +1,7 @@ import pytest -from b_asic import SFG -from b_asic.signal import Signal -from b_asic.core_operations import Addition, Constant -from b_asic.special_operations import Input, Output +from b_asic import SFG, Signal, Output + class TestConstructor: def test_outputs_construction(self, operation_tree): @@ -28,8 +26,8 @@ class TestConstructor: class TestEvaluation: def test_evaluate_output(self, operation_tree): sfg = SFG(outputs=[Output(operation_tree)]) - assert sfg.evaluate_output(0, [])[0] == 5 + assert sfg.evaluate_output(0, []) == 5 def test_evaluate_output_large(self, large_operation_tree): sfg = SFG(outputs=[Output(large_operation_tree)]) - assert sfg.evaluate_output(0, [])[0] == 14 \ No newline at end of file + assert sfg.evaluate_output(0, []) == 14 \ No newline at end of file diff --git a/test/test_signal.py b/test/test_signal.py index 9a45086a..94ec1d3d 100644 --- a/test/test_signal.py +++ b/test/test_signal.py @@ -2,11 +2,11 @@ B-ASIC test suit for the signal module which consists of the Signal class. """ -from b_asic.port import InputPort, OutputPort -from b_asic.signal import Signal - import pytest +from b_asic import InputPort, OutputPort, Signal + + def test_signal_creation_and_disconnction_and_connection_changing(): in_port = InputPort(None, 0) out_port = OutputPort(None, 1) diff --git a/test/test_simulation.py b/test/test_simulation.py new file mode 100644 index 00000000..4c053310 --- /dev/null +++ b/test/test_simulation.py @@ -0,0 +1,37 @@ +from b_asic import SFG, Output, Simulation + + +class TestSimulation: + def test_simulate(self, sfg_two_inputs_two_outputs): + simulation = Simulation(sfg_two_inputs_two_outputs, [lambda n: n + 3, lambda n: 1 + n * 2]) + output = simulation.run_for(101) + assert output[0] == 304 + assert output[1] == 505 + + assert simulation.results[0]["in1"] == 3 + assert simulation.results[0]["in2"] == 1 + assert simulation.results[0]["add1"] == 4 + assert simulation.results[0]["add2"] == 5 + assert simulation.results[0]["0"] == 4 + assert simulation.results[0]["1"] == 5 + + assert simulation.results[1]["in1"] == 4 + assert simulation.results[1]["in2"] == 3 + assert simulation.results[1]["add1"] == 7 + assert simulation.results[1]["add2"] == 10 + assert simulation.results[1]["0"] == 7 + assert simulation.results[1]["1"] == 10 + + assert simulation.results[2]["in1"] == 5 + assert simulation.results[2]["in2"] == 5 + assert simulation.results[2]["add1"] == 10 + assert simulation.results[2]["add2"] == 15 + assert simulation.results[2]["0"] == 10 + assert simulation.results[2]["1"] == 15 + + assert simulation.results[3]["in1"] == 6 + assert simulation.results[3]["in2"] == 7 + assert simulation.results[3]["add1"] == 13 + assert simulation.results[3]["add2"] == 20 + assert simulation.results[3]["0"] == 13 + assert simulation.results[3]["1"] == 20 -- GitLab