Skip to content
Snippets Groups Projects
Commit 6a7a2849 authored by Ivar Härnqvist's avatar Ivar Härnqvist
Browse files

fix copy_component for SFGs

parent 97d73a49
No related branches found
No related tags found
4 merge requests!31Resolve "Specify internal input/output dependencies of an Operation",!25Resolve "System tests iteration 1",!24Resolve "System tests iteration 1",!23Resolve "Simulate SFG"
Pipeline #12840 failed
......@@ -72,7 +72,7 @@ class GraphComponent(ABC):
raise NotImplementedError
@abstractmethod
def copy_component(self) -> "GraphComponent":
def copy_component(self, *args, **kwargs) -> "GraphComponent":
"""Get a new instance of this graph component type with the same name, id and parameters."""
raise NotImplementedError
......@@ -130,22 +130,22 @@ class AbstractGraphComponent(GraphComponent):
def set_param(self, name: str, value: Any) -> None:
self._parameters[name] = value
def copy_component(self) -> GraphComponent:
new_comp = self.__class__()
new_comp.name = copy(self.name)
new_comp.graph_id = copy(self.graph_id)
def copy_component(self, *args, **kwargs) -> GraphComponent:
new_component = self.__class__(*args, **kwargs)
new_component.name = copy(self.name)
new_component.graph_id = copy(self.graph_id)
for name, value in self.params.items():
new_comp.set_param(copy(name), deepcopy(value)) # pylint: disable=no-member
return new_comp
new_component.set_param(copy(name), deepcopy(value)) # pylint: disable=no-member
return new_component
def traverse(self) -> Generator[GraphComponent, None, None]:
# Breadth first search.
visited = {self}
fontier = deque([self])
while fontier:
comp = fontier.popleft()
yield comp
for neighbor in comp.neighbors:
component = fontier.popleft()
yield component
for neighbor in component.neighbors:
if neighbor not in visited:
visited.add(neighbor)
fontier.append(neighbor)
\ No newline at end of file
......@@ -142,22 +142,14 @@ class AbstractOperation(Operation, AbstractGraphComponent):
def __init__(self, input_count: int, output_count: int, name: Name = "", input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None):
super().__init__(name)
self._input_ports = []
self._output_ports = []
# Allocate input ports.
for i in range(input_count):
self._input_ports.append(InputPort(self, i))
# Allocate output ports.
for i in range(output_count):
self._output_ports.append(OutputPort(self, i))
self._input_ports = [InputPort(self, i) for i in range(input_count)] # Allocate input ports.
self._output_ports = [OutputPort(self, i) for i in range(output_count)] # Allocate output ports.
# Connect given input sources, if any.
if input_sources is not None:
source_count = len(input_sources)
if source_count != input_count:
raise ValueError(f"Operation expected {input_count} input sources but only got {source_count}")
raise ValueError(f"Wrong number of input sources supplied to Operation (expected {input_count}, got {source_count})")
for i, src in enumerate(input_sources):
if src is not None:
self._input_ports[i].connect(src.source)
......@@ -169,21 +161,21 @@ class AbstractOperation(Operation, AbstractGraphComponent):
return n & ((2 ** bits) - 1)
@abstractmethod
def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ
def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ
"""Evaluate the operation and generate a list of output values given a
list of input values.
"""
raise NotImplementedError
def _find_result(self, prefix: str, index: int, results: MutableMapping[str, Optional[Number]]) -> Optional[Number]:
key = results_key(self.output_count, prefix, index)
def _results_key(self, prefix: str, index: int) -> str:
return results_key(self.output_count, prefix, index)
def _find_result(self, key: str, results: MutableMapping[str, Optional[Number]]) -> Optional[Number]:
if key in results:
value = results[key]
if value is None:
raise RuntimeError(f"Direct feedback loop detected when evaluating operation.")
return value
results[key] = None
return None
def _truncate_inputs(self, input_values: Sequence[Number]):
......@@ -282,9 +274,11 @@ class AbstractOperation(Operation, AbstractGraphComponent):
if registers is None:
registers = {}
result = self._find_result(prefix, index, results)
key = self._results_key(prefix, index)
result = self._find_result(key, results)
if result is not None:
return result
results[key] = None
values = self.evaluate(*self._truncate_inputs(input_values))
if isinstance(values, collections.abc.Sequence):
if len(values) != self.output_count:
......
......@@ -176,9 +176,11 @@ class SFG(AbstractOperation):
if registers is None:
registers = {}
result = self._find_result(prefix, index, results)
key = self._results_key(prefix, index)
result = self._find_result(key, results)
if result is not None:
return result
results[key] = None
# Set the values of our input operations to the given input values.
for op, arg in zip(self._input_operations, self._truncate_inputs(input_values)):
......@@ -190,6 +192,10 @@ class SFG(AbstractOperation):
def split(self) -> Iterable[Operation]:
return self.operations
def copy_component(self, *args, **kwargs) -> GraphComponent:
return super().copy_component(*args, **kwargs, inputs = self._input_operations, outputs = self._output_operations,
id_number_offset = self._graph_id_generator.id_number_offset, name = self.name)
@property
def id_number_offset(self) -> GraphIDNumber:
......@@ -321,11 +327,14 @@ class SFG(AbstractOperation):
op_stack.append(original_connected_op)
def _evaluate_source(self, src: OutputPort, results: MutableMapping[str, Number], registers: MutableMapping[str, Number], prefix: str) -> Number:
op_prefix = prefix
if op_prefix:
op_prefix += "."
op_prefix += src.operation.graph_id
src_prefix = prefix
if src_prefix:
src_prefix += "."
src_prefix += src.operation.graph_id
# TODO: Handle registers.
input_values = [self._evaluate_source(input_port.signals[0].source, results, registers, prefix) for input_port in src.operation.inputs]
value = src.operation.evaluate_output(src.index, input_values, results, registers, op_prefix)
results[results_key(src.operation.output_count, op_prefix, src.index)] = value
value = src.operation.evaluate_output(src.index, input_values, results, registers, src_prefix)
results[results_key(src.operation.output_count, src_prefix, src.index)] = value
return value
......@@ -55,12 +55,12 @@ class Output(AbstractOperation):
class Register(AbstractOperation):
"""Delay operation.
"""Unit delay operation.
TODO: More info.
"""
def __init__(self, initial_value: Number = 0, src0: Optional[SignalSourceProvider] = None, name: Name = ""):
super().__init__(input_count = 1, output_count = 0, name = name, input_sources = [src0])
super().__init__(input_count = 1, output_count = 1, name = name, input_sources = [src0])
self.set_param("initial_value", initial_value)
@property
......@@ -79,11 +79,9 @@ class Register(AbstractOperation):
results = {}
if registers is None:
registers = {}
if prefix in results:
return results[prefix]
if prefix in registers:
return registers[prefix]
value = registers.get(prefix, self.param("initial_value"))
registers[prefix] = self._truncate_inputs(input_values)[0]
......
......@@ -10,41 +10,50 @@ def operation():
@pytest.fixture
def operation_tree():
"""Valid addition operation connected with 2 constants.
2>--+
2---+
|
2+3=5>
v
add = 2 + 3 = 5
^
|
3>--+
3---+
"""
return Addition(Constant(2), Constant(3))
@pytest.fixture
def large_operation_tree():
"""Valid addition operation connected with a large operation tree with 2 other additions and 4 constants.
2>--+
2---+
|
2+3=5>--+
| |
3>--+ |
5+9=14>
4>--+ |
| |
4+5=9>--+
v
add---+
^ |
| |
3---+ v
add = (2 + 3) + (4 + 5) = 14
4---+ ^
| |
v |
add---+
^
|
5>--+
5---+
"""
return Addition(Addition(Constant(2), Constant(3)), Addition(Constant(4), Constant(5)))
@pytest.fixture
def operation_graph_with_cycle():
"""Invalid addition operation connected with an operation graph containing a cycle.
+---+
| |
?+7=?>-------+
| |
7>--+ ?+6=?>
|
6
+-+
| |
v |
add+---+
^ |
| v
7 add = (? + 7) + 6 = ?
^
|
6
"""
add1 = Addition(None, Constant(7))
add1.input(0).connect(add1)
......
......@@ -6,16 +6,18 @@ from b_asic import SFG, Input, Output, Constant, Register
@pytest.fixture
def sfg_two_inputs_two_outputs():
"""Valid SFG with two inputs and two outputs.
. .
in1>------+ +---------------out1>
. | | .
. in1+in2=add1>--+ .
. | | .
in2>------+ | .
| . add1+in2=add2>---out2>
| . | .
+------------------+ .
. .
. .
in1-------+ +--------->out1
. | | .
. v | .
. add1+--+ .
. ^ | .
. | v .
in2+------+ add2---->out2
| . ^ .
| . | .
+------------+ .
. .
out1 = in1 + in2
out2 = in1 + 2 * in2
"""
......
......@@ -102,7 +102,7 @@ class TestSimulation:
output2 = simulation.run()
assert output1[0] == 11405
assert output2[0] == 8109
assert output2[0] == 4221
def test_simulate_with_register(self, sfg_accumulator):
data_in = np.array([5, -2, 25, -6, 7, 0])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment