diff --git a/b_asic/operation.py b/b_asic/operation.py index 7ec719ffbb850490b33573dd3977814355fe62b8..49b7833794e68ab8d296719ee392eae5f6f9936c 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -7,20 +7,14 @@ import collections from abc import abstractmethod from numbers import Number -from typing import List, Sequence, Iterable, MutableMapping, Optional, Any, Set, Union +from typing import NewType, List, Sequence, Iterable, Mapping, MutableMapping, Optional, Any, Set, Union from math import trunc from b_asic.graph_component import GraphComponent, AbstractGraphComponent, Name from b_asic.port import SignalSourceProvider, InputPort, OutputPort from b_asic.signal import Signal -def results_key(output_count: int, prefix: str, index: int): - key = prefix - if output_count != 1: - if key: - key += "." - key += str(index) - return key +ResultKey = NewType("ResultKey", str) class Operation(GraphComponent, SignalSourceProvider): """Operation interface. @@ -97,28 +91,54 @@ class Operation(GraphComponent, SignalSourceProvider): @abstractmethod def input_signals(self) -> Iterable[Signal]: """Get all the signals that are connected to this operation's input ports, - in no particular order.""" + in no particular order. + """ raise NotImplementedError @property @abstractmethod def output_signals(self) -> Iterable[Signal]: """Get all the signals that are connected to this operation's output ports, - in no particular order.""" + in no particular order. + """ + raise NotImplementedError + + @abstractmethod + def key(self, index: int, prefix: str = "") -> ResultKey: + """Get the key used to access the result of a certain output of this operation + from the results parameter passed to current_output(s) or evaluate_output(s). + """ + raise NotImplementedError + + @abstractmethod + def current_output(self, index: int, results: Optional[MutableMapping[ResultKey, Optional[Number]]] = None, registers: Optional[Mapping[ResultKey, Number]] = None, prefix: str = "") -> Optional[Number]: + """Get the current output at the given index of this operation, if available. + The results parameter will be used to store any results (including intermediate results) for caching. + The registers parameter will be used for lookup. + The prefix parameter will be used as a prefix for the key string when looking for registers. + See also: current_outputs, evaluate_output, evaluate_outputs. + """ raise NotImplementedError @abstractmethod - def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[str, Optional[Number]]] = None, registers: Optional[MutableMapping[str, Number]] = None, prefix: str = "") -> Number: + def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[ResultKey, Optional[Number]]] = None, registers: Optional[MutableMapping[ResultKey, Number]] = None, prefix: str = "") -> Number: """Evaluate the output at the given index of this operation with the given input values. - The results parameter will be used to store any intermediate results for caching. + The results parameter will be used to store any results (including intermediate results) for caching. The registers parameter will be used to get the current value of any intermediate registers that are encountered, and be updated with their new values. The prefix parameter will be used as a prefix for the key string when storing results/registers. - See also: evaluate_outputs. + See also: evaluate_outputs, current_output, current_outputs. + """ + raise NotImplementedError + + @abstractmethod + def current_outputs(self, results: Optional[MutableMapping[ResultKey, Optional[Number]]] = None, registers: Optional[Mapping[ResultKey, Number]] = None, prefix: str = "") -> Sequence[Optional[Number]]: + """Get all current outputs of this operation, if available. + See current_output for more information. """ raise NotImplementedError @abstractmethod - def evaluate_outputs(self, input_values: Sequence[Number], results: MutableMapping[str, Number], registers: MutableMapping[str, Number], prefix: str = "") -> Sequence[Number]: + def evaluate_outputs(self, input_values: Sequence[Number], results: MutableMapping[ResultKey, Number], registers: MutableMapping[ResultKey, Number], prefix: str = "") -> Sequence[Number]: """Evaluate all outputs of this operation given the input values. See evaluate_output for more information. """ @@ -154,34 +174,17 @@ class AbstractOperation(Operation, AbstractGraphComponent): if src is not None: self._input_ports[i].connect(src.source) - def truncate_input(self, index: int, value: Number, bits: int): + def truncate_input(self, index: int, value: Number, bits: int) -> Number: + """Truncate the value to be used as input at the given index to a certain bit length.""" n = value if not isinstance(n, int): n = trunc(value) return n & ((2 ** bits) - 1) - @abstractmethod - def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ - """Evaluate the operation and generate a list of output values given a - list of input values. - """ - raise NotImplementedError - - def _results_key(self, prefix: str, index: int) -> str: - return results_key(self.output_count, prefix, index) - - def _find_result(self, key: str, results: MutableMapping[str, Optional[Number]]) -> Optional[Number]: - if key in results: - value = results[key] - if value is None: - raise RuntimeError(f"Direct feedback loop detected when evaluating operation.") - return value - return None - - def _truncate_inputs(self, input_values: Sequence[Number]): + def truncate_inputs(self, input_values: Sequence[Number]) -> Sequence[Number]: + """Truncate the values to be used as inputs to the bit lengths specified by the respective signals connected to each input.""" args = [] - for i in range(self.input_count): - input_port = self.input(i) + for i, input_port in enumerate(self.inputs): if input_port.signal_count >= 1: bits = input_port.signals[0].bits if bits is None: @@ -194,6 +197,11 @@ class AbstractOperation(Operation, AbstractGraphComponent): args.append(input_values[i]) return args + @abstractmethod + def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ + """Evaluate the operation and generate a list of output values given a list of input values.""" + raise NotImplementedError + def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Addition, ConstantAddition]": # Import here to avoid circular imports. from b_asic.core_operations import Addition, ConstantAddition @@ -263,8 +271,22 @@ class AbstractOperation(Operation, AbstractGraphComponent): for s in p.signals: result.append(s) return result + + def key(self, index: int, prefix: str = "") -> ResultKey: + key = prefix + if self.output_count != 1: + if key: + key += "." + key += str(index) + elif not key: + key = str(index) + return key + + def current_output(self, index: int, results: Optional[MutableMapping[ResultKey, Optional[Number]]] = None, registers: Optional[Mapping[ResultKey, Number]] = None, prefix: str = "") -> Optional[Number]: + results[self.key(index, prefix)] = None + return None - def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[str, Optional[Number]]] = None, registers: Optional[MutableMapping[str, Number]] = None, prefix: str = "") -> Number: + def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[ResultKey, Optional[Number]]] = None, registers: Optional[MutableMapping[ResultKey, Number]] = None, prefix: str = "") -> Number: if index < 0 or index >= self.output_count: raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {index})") if len(input_values) != self.input_count: @@ -274,12 +296,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): if registers is None: registers = {} - key = self._results_key(prefix, index) - result = self._find_result(key, results) - if result is not None: - return result - results[key] = None - values = self.evaluate(*self._truncate_inputs(input_values)) + values = self.evaluate(*self.truncate_inputs(input_values)) if isinstance(values, collections.abc.Sequence): if len(values) != self.output_count: raise RuntimeError(f"Operation evaluated to incorrect number of outputs (expected {self.output_count}, got {len(values)})") @@ -291,13 +308,16 @@ class AbstractOperation(Operation, AbstractGraphComponent): raise RuntimeError(f"Operation evaluated to invalid type (expected Sequence/Number, got {values.__class__.__name__})") if self.output_count == 1: - results[results_key(self.output_count, prefix, index)] = values[index] + results[self.key(index, prefix)] = values[index] else: for i in range(self.output_count): - results[results_key(self.output_count, prefix, i)] = values[i] + results[self.key(i, prefix)] = values[i] return values[index] - def evaluate_outputs(self, input_values: Sequence[Number], results: MutableMapping[str, Number], registers: MutableMapping[str, Number], prefix: str = "") -> Sequence[Number]: + def current_outputs(self, results: Optional[MutableMapping[ResultKey, Optional[Number]]] = None, registers: Optional[Mapping[ResultKey, Number]] = None, prefix: str = "") -> Sequence[Optional[Number]]: + return [self.current_output(i, results, registers, prefix) for i in range(self.output_count)] + + def evaluate_outputs(self, input_values: Sequence[Number], results: MutableMapping[ResultKey, Number], registers: MutableMapping[ResultKey, Number], prefix: str = "") -> Sequence[Number]: return [self.evaluate_output(i, input_values, results, registers, prefix) for i in range(self.output_count)] def split(self) -> Iterable[Operation]: diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index d2df8ceeda8300ebed6b67b1c98bd291baf89c5f..4b5d2450b4ea350489862e6b06cd951940f12d54 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -8,7 +8,7 @@ from numbers import Number from collections import defaultdict, deque from b_asic.port import SignalSourceProvider, OutputPort -from b_asic.operation import Operation, AbstractOperation, results_key +from b_asic.operation import Operation, AbstractOperation, ResultKey from b_asic.signal import Signal from b_asic.graph_component import GraphID, GraphIDNumber, GraphComponent, Name, TypeName from b_asic.special_operations import Input, Output @@ -166,7 +166,7 @@ class SFG(AbstractOperation): n = len(result) return None if n == 0 else result[0] if n == 1 else result - def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[str, Optional[Number]]] = None, registers: Optional[MutableMapping[str, Number]] = None, prefix: str = "") -> Number: + def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[ResultKey, Optional[Number]]] = None, registers: Optional[MutableMapping[ResultKey, Number]] = None, prefix: str = "") -> Number: if index < 0 or index >= self.output_count: raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {index})") if len(input_values) != self.input_count: @@ -175,19 +175,13 @@ class SFG(AbstractOperation): results = {} if registers is None: registers = {} - - key = self._results_key(prefix, index) - result = self._find_result(key, results) - if result is not None: - return result - results[key] = None # Set the values of our input operations to the given input values. - for op, arg in zip(self._input_operations, self._truncate_inputs(input_values)): + for op, arg in zip(self._input_operations, self.truncate_inputs(input_values)): op.value = arg value = self._evaluate_source(self._output_operations[index].input(0).signals[0].source, results, registers, prefix) - results[results_key(self.output_count, prefix, index)] = value + results[self.key(index, prefix)] = value return value def split(self) -> Iterable[Operation]: @@ -326,15 +320,21 @@ class SFG(AbstractOperation): # Add connected operation to the queue of operations to visit. op_stack.append(original_connected_op) - def _evaluate_source(self, src: OutputPort, results: MutableMapping[str, Number], registers: MutableMapping[str, Number], prefix: str) -> Number: + def _evaluate_source(self, src: OutputPort, results: MutableMapping[ResultKey, Number], registers: MutableMapping[ResultKey, Number], prefix: str) -> Number: src_prefix = prefix if src_prefix: src_prefix += "." src_prefix += src.operation.graph_id - # TODO: Handle registers. - + key = src.operation.key(src.index, src_prefix) + if key in results: + value = results[key] + if value is None: + raise RuntimeError(f"Direct feedback loop detected when evaluating operation.") + return value + + src.operation.current_output(src.index, results, registers, src_prefix) input_values = [self._evaluate_source(input_port.signals[0].source, results, registers, prefix) for input_port in src.operation.inputs] value = src.operation.evaluate_output(src.index, input_values, results, registers, src_prefix) - results[results_key(src.operation.output_count, src_prefix, src.index)] = value + results[key] = value return value diff --git a/b_asic/simulation.py b/b_asic/simulation.py index de8a526993a0bad4f709936693b40ac741553d65..d36cc56876a2cce5dc82a36f7e77cf9ba6863927 100644 --- a/b_asic/simulation.py +++ b/b_asic/simulation.py @@ -7,6 +7,7 @@ from collections import defaultdict from numbers import Number from typing import List, Dict, DefaultDict, Callable, Sequence, Mapping, Union, Optional +from b_asic.operation import ResultKey from b_asic.signal_flow_graph import SFG @@ -23,18 +24,18 @@ class Simulation: _current_input_values: Sequence[Number] _latest_output_values: Sequence[Number] - def __init__(self, sfg: SFG, input_providers: Optional[Sequence[Union[Sequence[Number], Callable[[int], Number]]]] = None): + def __init__(self, sfg: SFG, input_providers: Optional[Sequence[Union[None, Sequence[Number], Callable[[int], Number]]]] = None): self._sfg = sfg self._results = defaultdict(dict) self._registers = {} self._iteration = 0 - self._input_functions = [] + self._input_functions = [lambda n: 0 for _ in range(self._sfg.input_count)] self._current_input_values = [] self._latest_output_values = [0 for _ in range(self._sfg.output_count)] if input_providers is not None: self.set_inputs(input_providers) - def set_inputs(self, input_providers: Sequence[Union[Sequence[Number], Callable[[int], Number]]]) -> None: + def set_inputs(self, input_providers: Sequence[Union[None, Sequence[Number], Callable[[int], Number]]]) -> None: """Set the input functions used to get values for the inputs to the internal SFG.""" if len(input_providers) != self._sfg.input_count: raise ValueError(f"Wrong number of inputs supplied to simulation (expected {self._sfg.input_count}, got {len(input_providers)})") @@ -42,15 +43,16 @@ class Simulation: for index, input_provider in enumerate(input_providers): self.set_input(index, input_provider) - def set_input(self, index: int, input_provider: Union[Sequence[Number], Callable[[int], Number]]) -> None: + def set_input(self, index: int, input_provider: Union[None, Sequence[Number], Callable[[int], Number]]) -> None: """Set the input function used to get values for the specific input at the given index to the internal SFG.""" if index < 0 or index >= len(self._input_functions): - raise IndexError(f"Input index out of range (expected 0-{self._input_functions - 1}, got {index})") + raise IndexError(f"Input index out of range (expected 0-{len(self._input_functions) - 1}, got {index})") - if callable(input_provider): - self._input_functions[index] = input_provider - else: - self._input_functions[index] = lambda n: input_provider[n] + if input_provider is not None: + if callable(input_provider): + self._input_functions[index] = input_provider + else: + self._input_functions[index] = lambda n: input_provider[n] def run(self) -> Sequence[Number]: """Run one iteration of the simulation and return the resulting output values.""" @@ -75,7 +77,7 @@ class Simulation: return self._iteration @property - def results(self) -> Mapping[int, Mapping[str, Number]]: + def results(self) -> Mapping[int, Mapping[ResultKey, Number]]: """Get a mapping of all results, including intermediate values, calculated for each iteration up until now. The outer mapping maps from iteration number to value mapping. The value mapping maps output port identifiers to values. Example: {0: {"c1": 3, "c2": 4, "bfly1.0": 7, "bfly1.1": -1, "0": 7}}""" diff --git a/b_asic/special_operations.py b/b_asic/special_operations.py index 951bbde4648c780c9183e56730bbb8d02c4b5b21..fcc3976faffdc4e5f837a00b31d63f31325f835e 100644 --- a/b_asic/special_operations.py +++ b/b_asic/special_operations.py @@ -4,9 +4,9 @@ TODO: More info. """ from numbers import Number -from typing import Optional, Sequence, MutableMapping +from typing import Optional, Sequence, Mapping, MutableMapping -from b_asic.operation import AbstractOperation +from b_asic.operation import AbstractOperation, ResultKey from b_asic.graph_component import Name, TypeName from b_asic.port import SignalSourceProvider @@ -59,7 +59,7 @@ class Register(AbstractOperation): TODO: More info. """ - def __init__(self, initial_value: Number = 0, src0: Optional[SignalSourceProvider] = None, name: Name = ""): + def __init__(self, src0: Optional[SignalSourceProvider] = None, initial_value: Number = 0, name: Name = ""): super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) self.set_param("initial_value", initial_value) @@ -70,20 +70,27 @@ class Register(AbstractOperation): def evaluate(self, a): return self.param("initial_value") - def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[str, Optional[Number]]] = None, registers: Optional[MutableMapping[str, Number]] = None, prefix: str = ""): + def current_output(self, index: int, results: Optional[MutableMapping[ResultKey, Optional[Number]]] = None, registers: Optional[Mapping[ResultKey, Number]] = None, prefix: str = "") -> Optional[Number]: + key = self.key(index, prefix) + value = self.param("initial_value") + if registers is not None: + value = registers.get(key, value) + results[key] = value + return value + + def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[ResultKey, Optional[Number]]] = None, registers: Optional[MutableMapping[ResultKey, Number]] = None, prefix: str = ""): if index != 0: raise IndexError(f"Output index out of range (expected 0-0, got {index})") if len(input_values) != 1: raise ValueError(f"Wrong number of inputs supplied to SFG for evaluation (expected 1, got {len(input_values)})") - if results is None: - results = {} - if registers is None: - registers = {} - - if prefix in results: - return results[prefix] - value = registers.get(prefix, self.param("initial_value")) - registers[prefix] = self._truncate_inputs(input_values)[0] - results[prefix] = value + key = self.key(index, prefix) + if registers is not None: + value = registers.get(key, self.param("initial_value")) + results[key] = value + registers[key] = self.truncate_inputs(input_values)[0] + return value + + value = self.param("initial_value") + results[key] = value return value \ No newline at end of file diff --git a/test/fixtures/signal_flow_graph.py b/test/fixtures/signal_flow_graph.py index c12b41aca09e47df3d4ae3c2b1edd096f3f44b81..df7d82b397eeaa936517371361d09eebc33a067f 100644 --- a/test/fixtures/signal_flow_graph.py +++ b/test/fixtures/signal_flow_graph.py @@ -38,7 +38,7 @@ def sfg_nested(): mac_in2 = Input() mac_in3 = Input() mac_out1 = Output(mac_in1 + mac_in2 * mac_in3) - MAC = SFG(inputs = [mac_in1, mac_in2, mac_in3], outputs=[mac_out1]) + MAC = SFG(inputs = [mac_in1, mac_in2, mac_in3], outputs = [mac_out1]) in1 = Input() in2 = Input() @@ -46,7 +46,17 @@ def sfg_nested(): mac2 = MAC(in1, in2, mac1) mac3 = MAC(in1, mac1, mac2) out1 = Output(mac3) - return SFG(inputs = [in1, in2], outputs=[out1]) + return SFG(inputs = [in1, in2], outputs = [out1]) + +@pytest.fixture +def sfg_delay(): + """Valid SFG with one input and one output. + out1 = in1' + """ + in1 = Input() + reg1 = Register(in1) + out1 = Output(reg1) + return SFG(inputs = [in1], outputs = [out1]) @pytest.fixture def sfg_accumulator(): diff --git a/test/test_simulation.py b/test/test_simulation.py index 6d3868a83fc62cacffa1fa70a6669bdded0bdfae..a8103aed800736fedf2b425bf3c23fab377a4094 100644 --- a/test/test_simulation.py +++ b/test/test_simulation.py @@ -104,7 +104,21 @@ class TestSimulation: assert output1[0] == 11405 assert output2[0] == 4221 - def test_simulate_with_register(self, sfg_accumulator): + def test_simulate_delay(self, sfg_delay): + simulation = Simulation(sfg_delay) + simulation.set_input(0, [5, -2, 25, -6, 7, 0]) + simulation.run_for(6) + + print(simulation.results) + + assert simulation.results[0]["0"] == 0 + assert simulation.results[1]["0"] == 5 + assert simulation.results[2]["0"] == -2 + assert simulation.results[3]["0"] == 25 + assert simulation.results[4]["0"] == -6 + assert simulation.results[5]["0"] == 7 + + def test_simulate_accumulator(self, sfg_accumulator): data_in = np.array([5, -2, 25, -6, 7, 0]) reset = np.array([0, 0, 0, 1, 0, 0]) simulation = Simulation(sfg_accumulator, [data_in, reset])