diff --git a/b_asic/graph_component.py b/b_asic/graph_component.py index efcbbd473f06877563e427bac7cae7e23c2f85f5..c926a1734b5c81642b27ace7db3e7dbc54c7f609 100644 --- a/b_asic/graph_component.py +++ b/b_asic/graph_component.py @@ -4,8 +4,9 @@ TODO: More info. """ from abc import ABC, abstractmethod +from collections import deque from copy import copy, deepcopy -from typing import NewType, Any, Optional, Dict +from typing import NewType, Any, Optional, Dict, Iterable, Generator Name = NewType("Name", str) @@ -75,10 +76,22 @@ class GraphComponent(ABC): """Get a new instance of this graph component type with the same name, id and parameters.""" raise NotImplementedError + @property + @abstractmethod + def neighbors(self) -> Iterable["GraphComponent"]: + """Get all components that are directly connected to this operation.""" + raise NotImplementedError + + @abstractmethod + def traverse(self) -> Generator["GraphComponent", None, None]: + """Get a generator that recursively iterates through all components that are connected to this operation, + as well as the ones that they are connected to. + """ + raise NotImplementedError + class AbstractGraphComponent(GraphComponent): """Abstract Graph Component class which is a component of a signal flow graph. - TODO: More info. """ @@ -123,4 +136,16 @@ class AbstractGraphComponent(GraphComponent): new_comp.graph_id = copy(self.graph_id) for name, value in self.params.items(): new_comp.set_param(copy(name), deepcopy(value)) # pylint: disable=no-member - return new_comp \ No newline at end of file + return new_comp + + def traverse(self) -> Generator[GraphComponent, None, None]: + # Breadth first search. + visited = {self} + fontier = deque([self]) + while fontier: + comp = fontier.popleft() + yield comp + for neighbor in comp.neighbors: + if neighbor not in visited: + visited.add(neighbor) + fontier.append(neighbor) \ No newline at end of file diff --git a/b_asic/operation.py b/b_asic/operation.py index b4b9d243bb161a985bc7c9603e7750e23706bd1f..0f664f6d9bc44ce7f2e4349e589b9b21933a4362 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -7,11 +7,11 @@ import collections from abc import abstractmethod from numbers import Number -from typing import List, Sequence, Iterable, MutableMapping, Optional, Any, Set, Generator, Union -from collections import deque +from typing import List, Sequence, Iterable, MutableMapping, Optional, Any, Set, Union from b_asic.graph_component import GraphComponent, AbstractGraphComponent, Name from b_asic.port import SignalSourceProvider, InputPort, OutputPort +from b_asic.signal import Signal def results_key(output_count: int, prefix: str, index: int): key = prefix @@ -92,6 +92,20 @@ class Operation(GraphComponent, SignalSourceProvider): """Get the output port at index i.""" raise NotImplementedError + @property + @abstractmethod + def input_signals(self) -> Iterable[Signal]: + """Get all the signals that are connected to this operation's input ports, + in no particular order.""" + raise NotImplementedError + + @property + @abstractmethod + def output_signals(self) -> Iterable[Signal]: + """Get all the signals that are connected to this operation's output ports, + in no particular order.""" + raise NotImplementedError + @abstractmethod def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[str, Optional[Number]]] = None, registers: Optional[MutableMapping[str, Number]] = None, prefix: str = "") -> Number: """Evaluate the output at the given index of this operation with the given input values. @@ -116,21 +130,6 @@ class Operation(GraphComponent, SignalSourceProvider): """ raise NotImplementedError - @property - @abstractmethod - def neighbors(self) -> Iterable["Operation"]: - """Return all operations that are connected by signals to this operation. - If no neighbors are found, this returns an empty list. - """ - raise NotImplementedError - - @abstractmethod - def traverse(self) -> Generator["Operation", None, None]: - """Get a generator that recursively iterates through all operations that are connected by signals to this operation, - as well as the ones that they are connected to. - """ - raise NotImplementedError - class AbstractOperation(Operation, AbstractGraphComponent): """Generic abstract operation class which most implementations will derive from. @@ -235,6 +234,22 @@ class AbstractOperation(Operation, AbstractGraphComponent): def output(self, i: int) -> OutputPort: return self._output_ports[i] + @property + def input_signals(self) -> Iterable[Signal]: + result = [] + for p in self.inputs: + for s in p.signals: + result.append(s) + return result + + @property + def output_signals(self) -> Iterable[Signal]: + result = [] + for p in self.outputs: + for s in p.signals: + result.append(s) + return result + def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[str, Optional[Number]]] = None, registers: Optional[MutableMapping[str, Number]] = None, prefix: str = "") -> Number: if index < 0 or index >= self.output_count: raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {index})") @@ -283,32 +298,12 @@ class AbstractOperation(Operation, AbstractGraphComponent): return [self] @property - def neighbors(self) -> Iterable[Operation]: - neighbors = [] - for port in self._input_ports: - for signal in port.signals: - neighbors.append(signal.source.operation) - for port in self._output_ports: - for signal in port.signals: - neighbors.append(signal.destination.operation) - return neighbors - - def traverse(self) -> Generator[Operation, None, None]: - # Breadth first search. - visited = {self} - queue = deque([self]) - while queue: - operation = queue.popleft() - yield operation - for n_operation in operation.neighbors: - if n_operation not in visited: - visited.add(n_operation) - queue.append(n_operation) + def neighbors(self) -> Iterable[GraphComponent]: + return list(self.input_signals) + list(self.output_signals) @property def source(self) -> OutputPort: if self.output_count != 1: diff = "more" if self.output_count > 1 else "less" - raise TypeError( - f"{self.__class__.__name__} cannot be used as an input source because it has {diff} than 1 output") - return self.output(0) + raise TypeError(f"{self.__class__.__name__} cannot be used as an input source because it has {diff} than 1 output") + return self.output(0) \ No newline at end of file diff --git a/b_asic/port.py b/b_asic/port.py index 103d076af2702e7e565067f7568bb6035d24a2c8..e8c007cbf077f9f40df0d53fc08001e6436f0093 100644 --- a/b_asic/port.py +++ b/b_asic/port.py @@ -108,12 +108,10 @@ class InputPort(AbstractPort): """ _source_signal: Optional[Signal] - _value_length: Optional[int] def __init__(self, operation: "Operation", index: int): super().__init__(operation, index) self._source_signal = None - self._value_length = None @property def signal_count(self) -> int: @@ -153,18 +151,6 @@ class InputPort(AbstractPort): # self._source_signal is set by the signal constructor. return Signal(source=src.source, destination=self, name=name) - @property - def value_length(self) -> Optional[int]: - """Get the number of bits that this port should truncate received values to.""" - return self._value_length - - @value_length.setter - def value_length(self, bits: Optional[int]) -> None: - """Set the number of bits that this port should truncate received values to.""" - assert bits is None or (isinstance( - bits, int) and bits >= 0), "Value length must be non-negative." - self._value_length = bits - class OutputPort(AbstractPort, SignalSourceProvider): """Output port. diff --git a/b_asic/signal.py b/b_asic/signal.py index 67e1d0f908ba57f5d355e77794993587343e63cf..c3e9183d6b3d15f13ac70baad4335f034d94446d 100644 --- a/b_asic/signal.py +++ b/b_asic/signal.py @@ -1,9 +1,9 @@ """@package docstring B-ASIC Signal Module. """ -from typing import Optional, TYPE_CHECKING +from typing import Optional, Iterable, TYPE_CHECKING -from b_asic.graph_component import AbstractGraphComponent, TypeName, Name +from b_asic.graph_component import GraphComponent, AbstractGraphComponent, TypeName, Name if TYPE_CHECKING: from b_asic.port import InputPort, OutputPort @@ -16,7 +16,7 @@ class Signal(AbstractGraphComponent): _destination: Optional["InputPort"] def __init__(self, source: Optional["OutputPort"] = None, \ - destination: Optional["InputPort"] = None, name: Name = ""): + destination: Optional["InputPort"] = None, bits: Optional[int] = None, name: Name = ""): super().__init__(name) self._source = None self._destination = None @@ -24,7 +24,16 @@ class Signal(AbstractGraphComponent): self.set_source(source) if destination is not None: self.set_destination(destination) + self.set_param("bits", bits) + @property + def type_name(self) -> TypeName: + return "s" + + @property + def neighbors(self) -> Iterable[GraphComponent]: + return [p.operation for p in [self.source, self.destination] if p is not None] + @property def source(self) -> Optional["OutputPort"]: """Return the source OutputPort of the signal.""" @@ -63,10 +72,6 @@ class Signal(AbstractGraphComponent): if self not in dest.signals: dest.add_signal(self) - @property - def type_name(self) -> TypeName: - return "s" - def remove_source(self) -> None: """Disconnect the source OutputPort of the signal. If the source port still is connected to this signal then also disconnect the source port.""" @@ -88,3 +93,16 @@ class Signal(AbstractGraphComponent): """Returns true if the signal is missing either a source or a destination, else false.""" return self._source is None or self._destination is None + + @property + def bits(self) -> Optional[int]: + """Get the number of bits that this operations using this signal as an input should truncate received values to. + None = unlimited.""" + return self.param("bits") + + @bits.setter + def bits(self, bits: Optional[int]) -> None: + """Set the number of bits that operations using this signal as an input should truncate received values to. + None = unlimited.""" + assert bits is None or (isinstance(bits, int) and bits >= 0), "Bits must be non-negative." + self.set_param("bits", bits) \ No newline at end of file diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 1e2f3125d39933aac3cc8d2bbd3786da6cb25190..b7919f64ca46aa46d787fb477b2d24e72528c217 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -30,7 +30,7 @@ class GraphIDGenerator: @property def id_number_offset(self) -> GraphIDNumber: """Get the graph id number offset of this generator.""" - return self._next_id_number.default_factory() + return self._next_id_number.default_factory() # pylint: disable=not-callable class SFG(AbstractOperation): @@ -71,11 +71,9 @@ class SFG(AbstractOperation): # Setup input signals. for input_index, sig in enumerate(input_signals): assert sig not in self._original_components_to_new, "Duplicate input signals sent to SFG construcctor." - new_input_op = self._add_component_copy_unconnected(Input()) new_sig = self._add_component_copy_unconnected(sig) new_sig.set_source(new_input_op.output(0)) - self._input_operations.append(new_input_op) self._original_input_signals_to_indexes[sig] = input_index @@ -83,12 +81,10 @@ class SFG(AbstractOperation): for input_index, input_op in enumerate(inputs, len(input_signals)): assert input_op not in self._original_components_to_new, "Duplicate input operations sent to SFG constructor." new_input_op = self._add_component_copy_unconnected(input_op) - for sig in input_op.output(0).signals: assert sig not in self._original_components_to_new, "Duplicate input signals connected to input ports sent to SFG construcctor." new_sig = self._add_component_copy_unconnected(sig) new_sig.set_source(new_input_op.output(0)) - self._original_input_signals_to_indexes[sig] = input_index self._input_operations.append(new_input_op) @@ -111,7 +107,6 @@ class SFG(AbstractOperation): # Setup output operations, starting from indexes after output signals. for output_ind, output_op in enumerate(outputs, len(output_signals)): assert output_op not in self._original_components_to_new, "Duplicate output operations sent to SFG constructor." - new_out = self._add_component_copy_unconnected(output_op) for sig in output_op.input(0).signals: if sig in self._original_components_to_new: @@ -161,7 +156,8 @@ class SFG(AbstractOperation): if not input_sources: input_sources = None return SFG(inputs = self._input_operations, outputs = self._output_operations, - id_number_offset = self._graph_id_generator.id_number_offset, name = name, input_sources = input_sources) + id_number_offset = self._graph_id_generator.id_number_offset, + name = name, input_sources = input_sources) @property def type_name(self) -> TypeName: @@ -238,7 +234,6 @@ class SFG(AbstractOperation): def _copy_structure_from_operation_dfs(self, start_op: Operation): op_stack = deque([start_op]) - while op_stack: original_op = op_stack.pop() # Add or get the new copy of the operation.. @@ -255,23 +250,20 @@ class SFG(AbstractOperation): raise ValueError("Unconnected input port in SFG") for original_signal in original_input_port.signals: - # Check if the signal is one of the SFG's input signals if original_signal in self._original_input_signals_to_indexes: # New signal already created during first step of constructor new_signal = self._original_components_to_new[original_signal] new_signal.set_destination(new_op.input(original_input_port.index)) - self._components_in_dfs_order.extend([new_signal, new_signal.source.operation]) # Check if the signal has not been added before elif original_signal not in self._original_components_to_new: if original_signal.source is None: raise ValueError("Dangling signal without source in SFG") - + new_signal = self._add_component_copy_unconnected(original_signal) new_signal.set_destination(new_op.input(original_input_port.index)) - self._components_in_dfs_order.append(new_signal) original_connected_op = original_signal.source.operation @@ -283,7 +275,6 @@ class SFG(AbstractOperation): # Create new operation, set signal source to it new_connected_op = self._add_component_copy_unconnected(original_connected_op) new_signal.set_source(new_connected_op.output(original_signal.source.index)) - self._components_in_dfs_order.append(new_connected_op) # Add connected operation to queue of operations to visit @@ -291,15 +282,12 @@ class SFG(AbstractOperation): # Connect output ports for original_output_port in original_op.outputs: - for original_signal in original_output_port.signals: # Check if the signal is one of the SFG's output signals. if original_signal in self._original_output_signals_to_indexes: - # New signal already created during first step of constructor. new_signal = self._original_components_to_new[original_signal] new_signal.set_source(new_op.output(original_output_port.index)) - self._components_in_dfs_order.extend([new_signal, new_signal.destination.operation]) # Check if signal has not been added before. @@ -309,7 +297,6 @@ class SFG(AbstractOperation): new_signal = self._add_component_copy_unconnected(original_signal) new_signal.set_source(new_op.output(original_output_port.index)) - self._components_in_dfs_order.append(new_signal) original_connected_op = original_signal.destination.operation @@ -321,7 +308,6 @@ class SFG(AbstractOperation): # Create new operation, set destination to it. new_connected_op = self._add_component_copy_unconnected(original_connected_op) new_signal.set_destination(new_connected_op.input(original_signal.destination.index)) - self._components_in_dfs_order.append(new_connected_op) # Add connected operation to the queue of operations to visist diff --git a/b_asic/special_operations.py b/b_asic/special_operations.py index 5a43ded109d59f8cf87eb08142776febde3868f3..a5a3e90f770a5dca8a5f1b52574bfce75f178bab 100644 --- a/b_asic/special_operations.py +++ b/b_asic/special_operations.py @@ -29,12 +29,12 @@ class Input(AbstractOperation): @property def value(self) -> Number: - """TODO: docstring""" + """Get the current value of this input.""" return self.param("value") @value.setter def value(self, value: Number): - """TODO: docstring""" + """Set the current value of this input.""" self.set_param("value", value) diff --git a/test/test_inputport.py b/test/test_inputport.py index 055eab283b1457a20f46187639304f1a0f78fb1e..85f892217c7e0f766417f6cc2e6d066d48d8a537 100644 --- a/test/test_inputport.py +++ b/test/test_inputport.py @@ -73,28 +73,3 @@ def test_add_signal_then_disconnect(inp_port, s_w_source): assert inp_port.signals == [] assert s_w_source.source.signals == [s_w_source] assert s_w_source.destination is None - -def test_set_value_length_pos_int(inp_port): - inp_port.value_length = 10 - assert inp_port.value_length == 10 - -def test_set_value_length_zero(inp_port): - inp_port.value_length = 0 - assert inp_port.value_length == 0 - -def test_set_value_length_neg_int(inp_port): - with pytest.raises(Exception): - inp_port.value_length = -10 - -def test_set_value_length_complex(inp_port): - with pytest.raises(Exception): - inp_port.value_length = (2+4j) - -def test_set_value_length_float(inp_port): - with pytest.raises(Exception): - inp_port.value_length = 3.2 - -def test_set_value_length_pos_then_none(inp_port): - inp_port.value_length = 10 - inp_port.value_length = None - assert inp_port.value_length is None diff --git a/test/test_operation.py b/test/test_operation.py index bb09be0b0caa968a9a75e751d935d9c9595b3495..4b258e00ab127d6489af81bf455bba411509439d 100644 --- a/test/test_operation.py +++ b/test/test_operation.py @@ -10,19 +10,16 @@ class TestTraverse: def test_traverse_tree(self, operation_tree): """Traverse a basic addition tree with two constants.""" - assert len(list(operation_tree.traverse())) == 3 + assert len(list(operation_tree.traverse())) == 5 def test_traverse_large_tree(self, large_operation_tree): """Traverse a larger tree.""" - assert len(list(large_operation_tree.traverse())) == 7 + assert len(list(large_operation_tree.traverse())) == 13 def test_traverse_type(self, large_operation_tree): - traverse = list(large_operation_tree.traverse()) - assert len( - list(filter(lambda type_: isinstance(type_, Addition), traverse))) == 3 - assert len( - list(filter(lambda type_: isinstance(type_, Constant), traverse))) == 4 + result = list(large_operation_tree.traverse()) + assert len(list(filter(lambda type_: isinstance(type_, Addition), result))) == 3 + assert len(list(filter(lambda type_: isinstance(type_, Constant), result))) == 4 - def test_traverse_loop(self, operation_tree): - # TODO: Construct a graph that contains a loop and make sure you can traverse it properly. - assert True + def test_traverse_loop(self, operation_graph_with_cycle): + assert len(list(operation_graph_with_cycle.traverse())) == 8 \ No newline at end of file diff --git a/test/test_signal.py b/test/test_signal.py index 94ec1d3d14c958bb5bdf41127dd971b229799dd9..cad16c9ba5b73b3d597c4e80aa666677d1909888 100644 --- a/test/test_signal.py +++ b/test/test_signal.py @@ -60,3 +60,28 @@ def test_signal_creation_and_disconnction_and_connection_changing(): assert in_port.signals == [s] assert s.source is out_port assert s.destination is in_port + +def test_signal_set_bits_pos_int(signal): + signal.bits = 10 + assert signal.bits == 10 + +def test_signal_set_bits_zero(signal): + signal.bits = 0 + assert signal.bits == 0 + +def test_signal_set_bits_neg_int(signal): + with pytest.raises(Exception): + signal.bits = -10 + +def test_signal_set_bits_complex(signal): + with pytest.raises(Exception): + signal.bits = (2+4j) + +def test_signal_set_bits_float(signal): + with pytest.raises(Exception): + signal.bits = 3.2 + +def test_signal_set_bits_pos_then_none(signal): + signal.bits = 10 + signal.bits = None + assert signal.bits is None \ No newline at end of file