diff --git a/b_asic/graph_component.py b/b_asic/graph_component.py index c926a1734b5c81642b27ace7db3e7dbc54c7f609..0092e1b6ed8b5e3c3acd591910ce50c8aeb152c8 100644 --- a/b_asic/graph_component.py +++ b/b_asic/graph_component.py @@ -72,7 +72,7 @@ class GraphComponent(ABC): raise NotImplementedError @abstractmethod - def copy_component(self) -> "GraphComponent": + def copy_component(self, *args, **kwargs) -> "GraphComponent": """Get a new instance of this graph component type with the same name, id and parameters.""" raise NotImplementedError @@ -130,22 +130,22 @@ class AbstractGraphComponent(GraphComponent): def set_param(self, name: str, value: Any) -> None: self._parameters[name] = value - def copy_component(self) -> GraphComponent: - new_comp = self.__class__() - new_comp.name = copy(self.name) - new_comp.graph_id = copy(self.graph_id) + def copy_component(self, *args, **kwargs) -> GraphComponent: + new_component = self.__class__(*args, **kwargs) + new_component.name = copy(self.name) + new_component.graph_id = copy(self.graph_id) for name, value in self.params.items(): - new_comp.set_param(copy(name), deepcopy(value)) # pylint: disable=no-member - return new_comp + new_component.set_param(copy(name), deepcopy(value)) # pylint: disable=no-member + return new_component def traverse(self) -> Generator[GraphComponent, None, None]: # Breadth first search. visited = {self} fontier = deque([self]) while fontier: - comp = fontier.popleft() - yield comp - for neighbor in comp.neighbors: + component = fontier.popleft() + yield component + for neighbor in component.neighbors: if neighbor not in visited: visited.add(neighbor) fontier.append(neighbor) \ No newline at end of file diff --git a/b_asic/operation.py b/b_asic/operation.py index c010ebb012983ae4a19141a63e3a7687e97188ad..7ec719ffbb850490b33573dd3977814355fe62b8 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -142,22 +142,14 @@ class AbstractOperation(Operation, AbstractGraphComponent): def __init__(self, input_count: int, output_count: int, name: Name = "", input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None): super().__init__(name) - self._input_ports = [] - self._output_ports = [] - - # Allocate input ports. - for i in range(input_count): - self._input_ports.append(InputPort(self, i)) - - # Allocate output ports. - for i in range(output_count): - self._output_ports.append(OutputPort(self, i)) + self._input_ports = [InputPort(self, i) for i in range(input_count)] # Allocate input ports. + self._output_ports = [OutputPort(self, i) for i in range(output_count)] # Allocate output ports. # Connect given input sources, if any. if input_sources is not None: source_count = len(input_sources) if source_count != input_count: - raise ValueError(f"Operation expected {input_count} input sources but only got {source_count}") + raise ValueError(f"Wrong number of input sources supplied to Operation (expected {input_count}, got {source_count})") for i, src in enumerate(input_sources): if src is not None: self._input_ports[i].connect(src.source) @@ -169,21 +161,21 @@ class AbstractOperation(Operation, AbstractGraphComponent): return n & ((2 ** bits) - 1) @abstractmethod - def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ + def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ """Evaluate the operation and generate a list of output values given a list of input values. """ raise NotImplementedError - def _find_result(self, prefix: str, index: int, results: MutableMapping[str, Optional[Number]]) -> Optional[Number]: - key = results_key(self.output_count, prefix, index) + def _results_key(self, prefix: str, index: int) -> str: + return results_key(self.output_count, prefix, index) + + def _find_result(self, key: str, results: MutableMapping[str, Optional[Number]]) -> Optional[Number]: if key in results: value = results[key] if value is None: raise RuntimeError(f"Direct feedback loop detected when evaluating operation.") return value - - results[key] = None return None def _truncate_inputs(self, input_values: Sequence[Number]): @@ -282,9 +274,11 @@ class AbstractOperation(Operation, AbstractGraphComponent): if registers is None: registers = {} - result = self._find_result(prefix, index, results) + key = self._results_key(prefix, index) + result = self._find_result(key, results) if result is not None: return result + results[key] = None values = self.evaluate(*self._truncate_inputs(input_values)) if isinstance(values, collections.abc.Sequence): if len(values) != self.output_count: diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 78a037e2dbc5ff9fafe8bf146755e6405be1766b..d2df8ceeda8300ebed6b67b1c98bd291baf89c5f 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -176,9 +176,11 @@ class SFG(AbstractOperation): if registers is None: registers = {} - result = self._find_result(prefix, index, results) + key = self._results_key(prefix, index) + result = self._find_result(key, results) if result is not None: return result + results[key] = None # Set the values of our input operations to the given input values. for op, arg in zip(self._input_operations, self._truncate_inputs(input_values)): @@ -190,6 +192,10 @@ class SFG(AbstractOperation): def split(self) -> Iterable[Operation]: return self.operations + + def copy_component(self, *args, **kwargs) -> GraphComponent: + return super().copy_component(*args, **kwargs, inputs = self._input_operations, outputs = self._output_operations, + id_number_offset = self._graph_id_generator.id_number_offset, name = self.name) @property def id_number_offset(self) -> GraphIDNumber: @@ -321,11 +327,14 @@ class SFG(AbstractOperation): op_stack.append(original_connected_op) def _evaluate_source(self, src: OutputPort, results: MutableMapping[str, Number], registers: MutableMapping[str, Number], prefix: str) -> Number: - op_prefix = prefix - if op_prefix: - op_prefix += "." - op_prefix += src.operation.graph_id + src_prefix = prefix + if src_prefix: + src_prefix += "." + src_prefix += src.operation.graph_id + + # TODO: Handle registers. + input_values = [self._evaluate_source(input_port.signals[0].source, results, registers, prefix) for input_port in src.operation.inputs] - value = src.operation.evaluate_output(src.index, input_values, results, registers, op_prefix) - results[results_key(src.operation.output_count, op_prefix, src.index)] = value + value = src.operation.evaluate_output(src.index, input_values, results, registers, src_prefix) + results[results_key(src.operation.output_count, src_prefix, src.index)] = value return value diff --git a/b_asic/special_operations.py b/b_asic/special_operations.py index 140fa410cc63d5e553aea48a32dd344642b1dc1a..951bbde4648c780c9183e56730bbb8d02c4b5b21 100644 --- a/b_asic/special_operations.py +++ b/b_asic/special_operations.py @@ -55,12 +55,12 @@ class Output(AbstractOperation): class Register(AbstractOperation): - """Delay operation. + """Unit delay operation. TODO: More info. """ def __init__(self, initial_value: Number = 0, src0: Optional[SignalSourceProvider] = None, name: Name = ""): - super().__init__(input_count = 1, output_count = 0, name = name, input_sources = [src0]) + super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0]) self.set_param("initial_value", initial_value) @property @@ -79,11 +79,9 @@ class Register(AbstractOperation): results = {} if registers is None: registers = {} - + if prefix in results: return results[prefix] - if prefix in registers: - return registers[prefix] value = registers.get(prefix, self.param("initial_value")) registers[prefix] = self._truncate_inputs(input_values)[0] diff --git a/test/fixtures/operation_tree.py b/test/fixtures/operation_tree.py index 3ac35110c79b1fe3fc7af86e6e0ad975be1db396..fc8008fa4098ca488e23766f5ff7d05711300685 100644 --- a/test/fixtures/operation_tree.py +++ b/test/fixtures/operation_tree.py @@ -10,41 +10,50 @@ def operation(): @pytest.fixture def operation_tree(): """Valid addition operation connected with 2 constants. - 2>--+ + 2---+ | - 2+3=5> + v + add = 2 + 3 = 5 + ^ | - 3>--+ + 3---+ """ return Addition(Constant(2), Constant(3)) @pytest.fixture def large_operation_tree(): """Valid addition operation connected with a large operation tree with 2 other additions and 4 constants. - 2>--+ + 2---+ | - 2+3=5>--+ - | | - 3>--+ | - 5+9=14> - 4>--+ | - | | - 4+5=9>--+ + v + add---+ + ^ | + | | + 3---+ v + add = (2 + 3) + (4 + 5) = 14 + 4---+ ^ + | | + v | + add---+ + ^ | - 5>--+ + 5---+ """ return Addition(Addition(Constant(2), Constant(3)), Addition(Constant(4), Constant(5))) @pytest.fixture def operation_graph_with_cycle(): """Invalid addition operation connected with an operation graph containing a cycle. - +---+ - | | - ?+7=?>-------+ - | | - 7>--+ ?+6=?> - | - 6 + +-+ + | | + v | + add+---+ + ^ | + | v + 7 add = (? + 7) + 6 = ? + ^ + | + 6 """ add1 = Addition(None, Constant(7)) add1.input(0).connect(add1) diff --git a/test/fixtures/signal_flow_graph.py b/test/fixtures/signal_flow_graph.py index af41262a4a240c2c69671de089b70dd9f512ece2..c12b41aca09e47df3d4ae3c2b1edd096f3f44b81 100644 --- a/test/fixtures/signal_flow_graph.py +++ b/test/fixtures/signal_flow_graph.py @@ -6,16 +6,18 @@ from b_asic import SFG, Input, Output, Constant, Register @pytest.fixture def sfg_two_inputs_two_outputs(): """Valid SFG with two inputs and two outputs. - . . - in1>------+ +---------------out1> - . | | . - . in1+in2=add1>--+ . - . | | . - in2>------+ | . - | . add1+in2=add2>---out2> - | . | . - +------------------+ . - . . + . . + in1-------+ +--------->out1 + . | | . + . v | . + . add1+--+ . + . ^ | . + . | v . + in2+------+ add2---->out2 + | . ^ . + | . | . + +------------+ . + . . out1 = in1 + in2 out2 = in1 + 2 * in2 """ diff --git a/test/test_simulation.py b/test/test_simulation.py index 7adccda4b7162edac00de9d9f7d876305bd35715..6d3868a83fc62cacffa1fa70a6669bdded0bdfae 100644 --- a/test/test_simulation.py +++ b/test/test_simulation.py @@ -102,7 +102,7 @@ class TestSimulation: output2 = simulation.run() assert output1[0] == 11405 - assert output2[0] == 8109 + assert output2[0] == 4221 def test_simulate_with_register(self, sfg_accumulator): data_in = np.array([5, -2, 25, -6, 7, 0])