Skip to content
Snippets Groups Projects
Commit ba89f480 authored by Jacob Wahlman's avatar Jacob Wahlman :ok_hand:
Browse files

Merged operation id system with traversing and signal

parents 0133ee94 664b8044
No related branches found
No related tags found
1 merge request!2Integrated ID system, traversing and som signal tests
Pipeline #10016 passed
...@@ -36,7 +36,7 @@ class BasicOperation(Operation): ...@@ -36,7 +36,7 @@ class BasicOperation(Operation):
Evaluate the operation and generate a list of output values given a list of input values. Evaluate the operation and generate a list of output values given a list of input values.
""" """
pass pass
def inputs(self) -> List[InputPort]: def inputs(self) -> List[InputPort]:
return self._input_ports.copy() return self._input_ports.copy()
...@@ -97,4 +97,13 @@ class BasicOperation(Operation): ...@@ -97,4 +97,13 @@ class BasicOperation(Operation):
return results return results
return [self] return [self]
@property
def neighbours(self) -> List[Operation]:
neighbours: List[Operation] = []
for port in self._output_ports + self._input_ports:
for signal in port.signals():
neighbours += [signal.source.operation, signal.destination.operation]
return neighbours
# TODO: More stuff. # TODO: More stuff.
...@@ -30,8 +30,8 @@ class Constant(BasicOperation): ...@@ -30,8 +30,8 @@ class Constant(BasicOperation):
""" """
Construct a Constant. Construct a Constant.
""" """
super().__init__(identifier) super().__init__()
self._output_ports = [OutputPort()] # TODO: Generate appropriate ID for ports. self._output_ports = [OutputPort(1)] # TODO: Generate appropriate ID for ports.
self._parameters["value"] = value self._parameters["value"] = value
def evaluate(self, inputs: list) -> list: def evaluate(self, inputs: list) -> list:
...@@ -50,7 +50,7 @@ class Addition(BasicOperation): ...@@ -50,7 +50,7 @@ class Addition(BasicOperation):
""" """
Construct an Addition. Construct an Addition.
""" """
super().__init__(self) super().__init__()
self._input_ports = [InputPort(1), InputPort(1)] # TODO: Generate appropriate ID for ports. self._input_ports = [InputPort(1), InputPort(1)] # TODO: Generate appropriate ID for ports.
self._output_ports = [OutputPort(1)] # TODO: Generate appropriate ID for ports. self._output_ports = [OutputPort(1)] # TODO: Generate appropriate ID for ports.
...@@ -59,7 +59,7 @@ class Addition(BasicOperation): ...@@ -59,7 +59,7 @@ class Addition(BasicOperation):
def get_op_name(self) -> GraphIDType: def get_op_name(self) -> GraphIDType:
return "add" return "add"
class ConstantMultiplication(BasicOperation): class ConstantMultiplication(BasicOperation):
""" """
...@@ -71,7 +71,7 @@ class ConstantMultiplication(BasicOperation): ...@@ -71,7 +71,7 @@ class ConstantMultiplication(BasicOperation):
""" """
Construct a ConstantMultiplication. Construct a ConstantMultiplication.
""" """
super().__init__(identifier) super().__init__()
self._input_ports = [InputPort(1)] # TODO: Generate appropriate ID for ports. self._input_ports = [InputPort(1)] # TODO: Generate appropriate ID for ports.
self._output_ports = [OutputPort(1)] # TODO: Generate appropriate ID for ports. self._output_ports = [OutputPort(1)] # TODO: Generate appropriate ID for ports.
self._parameters["coefficient"] = coefficient self._parameters["coefficient"] = coefficient
......
...@@ -12,7 +12,7 @@ GraphIDNumber = NewType("GraphIDNumber", int) ...@@ -12,7 +12,7 @@ GraphIDNumber = NewType("GraphIDNumber", int)
class GraphIDGenerator: class GraphIDGenerator:
""" """
A class that generates Graph IDs for objects. A class that generates Graph IDs for objects.
""" """
_next_id_number: DefaultDict[GraphIDType, GraphIDNumber] _next_id_number: DefaultDict[GraphIDType, GraphIDNumber]
......
...@@ -88,7 +88,7 @@ class Operation(ABC): ...@@ -88,7 +88,7 @@ class Operation(ABC):
""" """
Simulate the circuit until its iteration count matches that of the simulation state, Simulate the circuit until its iteration count matches that of the simulation state,
then return the resulting output vector. then return the resulting output vector.
""" """
pass pass
@abstractmethod @abstractmethod
...@@ -104,5 +104,12 @@ class Operation(ABC): ...@@ -104,5 +104,12 @@ class Operation(ABC):
"""Returns a string representing the operation name of the operation.""" """Returns a string representing the operation name of the operation."""
pass pass
@abstractmethod
def neighbours(self) -> "List[Operation]":
"""
Return all operations that are connected by signals to this operation.
If no neighbours are found this returns an empty list
"""
# TODO: More stuff. # TODO: More stuff.
"""
B-ASIC Operation Tree Traversing Module.
TODO:
- Get a first operation or? an entire operation tree
- For each start point, follow it to the next operation from it's out port.
- If we are searching for a specific operation end.
- If we are searching for a specific type of operation add the operation to a list and continue.
- When we no more out ports can be traversed return results and end.
"""
from typing import List, Optional
from collections import deque
from b_asic.operation import Operation
class Traverse:
"""Traverse operation tree.
TODO:
- More info.
- Check if a datastructure other than list suits better as return value.
- Implement the type check for operation.
"""
def __init__(self, operation: Operation):
"""Construct a TraverseTree."""
self._initial_operation = operation
def _breadth_first_search(self, start: Operation) -> List[Operation]:
"""Use breadth first search to traverse the operation tree."""
visited: List[Operation] = [start]
queue = deque([start])
while queue:
operation = queue.popleft()
for n_operation in operation.neighbours:
if n_operation not in visited:
visited.append(n_operation)
queue.append(n_operation)
return visited
def traverse(self, type_: Optional[Operation] = None) -> List[Operation]:
"""Traverse the the operation tree and return operation where type matches.
If the type is None then return the entire tree.
Keyword arguments:
type_-- the operation type to search for (default None)
"""
operations: List[Operation] = self._breadth_first_search(self._initial_operation)
if type_ is not None:
operations = [oper for oper in operations if isinstance(oper, type_)]
return operations
...@@ -6,15 +6,15 @@ Use a fixture for initializing objects and pass them as argument to a test funct ...@@ -6,15 +6,15 @@ Use a fixture for initializing objects and pass them as argument to a test funct
""" """
@pytest.fixture @pytest.fixture
def signal(): def signal():
source = SignalSource(Addition(0), 1) source = SignalSource(Addition(), 1)
dest = SignalDestination(Addition(1), 2) dest = SignalDestination(Addition(), 2)
return Signal(0, source, dest) return Signal(source, dest)
@pytest.fixture @pytest.fixture
def signals(): def signals():
ret = [] ret = []
for i in range(0,3): for _ in range(0,3):
source = SignalSource(Addition(0), 1) source = SignalSource(Addition(), 1)
dest = SignalDestination(Addition(1), 2) dest = SignalDestination(Addition(), 2)
ret.append(Signal(i, source, dest)) ret.append(Signal(source, dest))
return ret return ret
\ No newline at end of file
...@@ -15,8 +15,8 @@ def test_connect_one_signal_to_port(signal): ...@@ -15,8 +15,8 @@ def test_connect_one_signal_to_port(signal):
def test_change_port_signal(): def test_change_port_signal():
source = SignalSource(Addition, 1) source = SignalSource(Addition, 1)
dest = SignalDestination(Addition,2) dest = SignalDestination(Addition,2)
signal1 = Signal(1, source, dest) signal1 = Signal(source, dest)
signal2 = Signal(2, source, dest) signal2 = Signal(source, dest)
port = InputPort(0) port = InputPort(0)
port.connect(signal1) port.connect(signal1)
......
"""
TODO:
- Rewrite to more clean code, not so repetitive
- Update when signals and id's has been merged.
"""
from b_asic.core_operations import Constant, Addition
from b_asic.signal import Signal, SignalSource, SignalDestination
from b_asic.port import InputPort, OutputPort
from b_asic.traverse_tree import Traverse
import pytest
@pytest.fixture
def operation():
return Constant(2)
def create_operation(_type, dest_oper, index, **kwargs):
oper = _type(**kwargs)
oper_signal_source = SignalSource(oper, 0)
oper_signal_dest = SignalDestination(dest_oper, index)
oper_signal = Signal(oper_signal_source, oper_signal_dest)
oper._output_ports[0].connect(oper_signal)
dest_oper._input_ports[index].connect(oper_signal)
return oper
@pytest.fixture
def operation_tree():
add_oper = Addition()
const_oper = create_operation(Constant, add_oper, 0, value=2)
const_oper_2 = create_operation(Constant, add_oper, 1, value=3)
return add_oper
@pytest.fixture
def large_operation_tree():
add_oper = Addition()
add_oper_2 = Addition()
const_oper = create_operation(Constant, add_oper, 0, value=2)
const_oper_2 = create_operation(Constant, add_oper, 1, value=3)
const_oper_3 = create_operation(Constant, add_oper_2, 0, value=4)
const_oper_4 = create_operation(Constant, add_oper_2, 1, value=5)
add_oper_3 = Addition()
add_oper_signal_source = SignalSource(add_oper, 0)
add_oper_signal_dest = SignalDestination(add_oper_3, 0)
add_oper_signal = Signal(add_oper_signal_source, add_oper_signal_dest)
add_oper._output_ports[0].connect(add_oper_signal)
add_oper_3._input_ports[0].connect(add_oper_signal)
add_oper_2_signal_source = SignalSource(add_oper_2, 0)
add_oper_2_signal_dest = SignalDestination(add_oper_3, 1)
add_oper_2_signal = Signal(add_oper_2_signal_source, add_oper_2_signal_dest)
add_oper_2._output_ports[0].connect(add_oper_2_signal)
add_oper_3._input_ports[1].connect(add_oper_2_signal)
return const_oper
def test_traverse_single_tree(operation):
traverse = Traverse(operation)
assert traverse.traverse() == [operation]
def test_traverse_tree(operation_tree):
traverse = Traverse(operation_tree)
assert len(traverse.traverse()) == 3
def test_traverse_large_tree(large_operation_tree):
traverse = Traverse(large_operation_tree)
assert len(traverse.traverse()) == 7
def test_traverse_type(large_operation_tree):
traverse = Traverse(large_operation_tree)
assert len(traverse.traverse(Addition)) == 3
assert len(traverse.traverse(Constant)) == 4
def test_traverse_loop(operation_tree):
add_oper_signal_source = SignalSource(operation_tree, 0)
add_oper_signal_dest = SignalDestination(operation_tree, 0)
add_oper_signal = Signal(add_oper_signal_source, add_oper_signal_dest)
operation_tree._output_ports[0].connect(add_oper_signal)
operation_tree._input_ports[0].connect(add_oper_signal)
traverse = Traverse(operation_tree)
assert len(traverse.traverse()) == 2
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment