Skip to content
Snippets Groups Projects
Commit bc7aba48 authored by Ivar Härnqvist's avatar Ivar Härnqvist
Browse files

change interface for finding dependencies

parent 6c00ee33
No related branches found
No related tags found
1 merge request!31Resolve "Specify internal input/output dependencies of an Operation"
Pipeline #14069 passed
...@@ -181,10 +181,8 @@ class Operation(GraphComponent, SignalSourceProvider): ...@@ -181,10 +181,8 @@ class Operation(GraphComponent, SignalSourceProvider):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def depends(self, output_index: int, input_index: int) -> bool: def inputs_required_for_output(self, output_index: int) -> Iterable[int]:
"""Check if the output at the given output index depends on the input at the """Get the input indices of all inputs in this operation whose values are required in order to evalueate the output at the given output index."""
given input index in order to be evaluated.
"""
raise NotImplementedError raise NotImplementedError
...@@ -347,12 +345,10 @@ class AbstractOperation(Operation, AbstractGraphComponent): ...@@ -347,12 +345,10 @@ class AbstractOperation(Operation, AbstractGraphComponent):
pass pass
return [self] return [self]
def depends(self, output_index: int, input_index: int) -> bool: def inputs_required_for_output(self, output_index: int) -> Iterable[int]:
if output_index < 0 or output_index >= self.output_count: if output_index < 0 or output_index >= self.output_count:
raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {output_index})") raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {output_index})")
if input_index < 0 or input_index >= self.input_count: return [i for i in range(self.input_count)] # By default, assume each output depends on all inputs.
raise IndexError(f"Input index out of range (expected 0-{self.input_count - 1}, got {input_index})")
return True # By default, assume each output depends on all inputs.
@property @property
def neighbors(self) -> Iterable[GraphComponent]: def neighbors(self) -> Iterable[GraphComponent]:
......
...@@ -48,6 +48,7 @@ class SFG(AbstractOperation): ...@@ -48,6 +48,7 @@ class SFG(AbstractOperation):
_original_components_to_new: MutableSet[GraphComponent] _original_components_to_new: MutableSet[GraphComponent]
_original_input_signals_to_indices: Dict[Signal, int] _original_input_signals_to_indices: Dict[Signal, int]
_original_output_signals_to_indices: Dict[Signal, int] _original_output_signals_to_indices: Dict[Signal, int]
_dependency_map: Dict[int, MutableSet[int]]
def __init__(self, input_signals: Optional[Sequence[Signal]] = None, output_signals: Optional[Sequence[Signal]] = None, \ def __init__(self, input_signals: Optional[Sequence[Signal]] = None, output_signals: Optional[Sequence[Signal]] = None, \
inputs: Optional[Sequence[Input]] = None, outputs: Optional[Sequence[Output]] = None, \ inputs: Optional[Sequence[Input]] = None, outputs: Optional[Sequence[Output]] = None, \
...@@ -71,6 +72,7 @@ class SFG(AbstractOperation): ...@@ -71,6 +72,7 @@ class SFG(AbstractOperation):
self._original_components_to_new = {} self._original_components_to_new = {}
self._original_input_signals_to_indices = {} self._original_input_signals_to_indices = {}
self._original_output_signals_to_indices = {} self._original_output_signals_to_indices = {}
self._dependency_map = {}
# Setup input signals. # Setup input signals.
if input_signals is not None: if input_signals is not None:
...@@ -155,6 +157,9 @@ class SFG(AbstractOperation): ...@@ -155,6 +157,9 @@ class SFG(AbstractOperation):
raise ValueError(f"Output signal #{output_index} is missing source in SFG") raise ValueError(f"Output signal #{output_index} is missing source in SFG")
if signal.source.operation not in self._original_components_to_new: if signal.source.operation not in self._original_components_to_new:
self._add_operation_connected_tree_copy(signal.source.operation) self._add_operation_connected_tree_copy(signal.source.operation)
# Find dependencies.
def __str__(self) -> str: def __str__(self) -> str:
"""Get a string representation of this SFG.""" """Get a string representation of this SFG."""
...@@ -234,12 +239,10 @@ class SFG(AbstractOperation): ...@@ -234,12 +239,10 @@ class SFG(AbstractOperation):
def split(self) -> Iterable[Operation]: def split(self) -> Iterable[Operation]:
return self.operations return self.operations
def depends(self, output_index: int, input_index: int) -> bool: def inputs_required_for_output(self, output_index: int) -> Iterable[int]:
if output_index < 0 or output_index >= self.output_count: if output_index < 0 or output_index >= self.output_count:
raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {output_index})") raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {output_index})")
if input_index < 0 or input_index >= self.input_count: return self._inputs_required_for_source(self._output_operations[output_index].input(0).signals[0].source, set())
raise IndexError(f"Input index out of range (expected 0-{self.input_count - 1}, got {input_index})")
return self._source_depends(self._output_operations[output_index].input(0).signals[0].source, self._input_operations[input_index], set())
def copy_component(self, *args, **kwargs) -> GraphComponent: def copy_component(self, *args, **kwargs) -> GraphComponent:
return super().copy_component(*args, **kwargs, inputs = self._input_operations, outputs = self._output_operations, return super().copy_component(*args, **kwargs, inputs = self._input_operations, outputs = self._output_operations,
...@@ -431,14 +434,17 @@ class SFG(AbstractOperation): ...@@ -431,14 +434,17 @@ class SFG(AbstractOperation):
results[key] = value results[key] = value
return value return value
def _source_depends(self, src: OutputPort, input_operation: Input, visited: MutableSet[Operation]) -> bool: def _inputs_required_for_source(self, src: OutputPort, visited: MutableSet[Operation]) -> Sequence[bool]:
if src.operation is input_operation: if src.operation in visited:
return True return []
visited.add(src.operation)
if src.operation not in visited: if isinstance(src.operation, Input):
visited.add(src.operation) for i, input_operation in enumerate(self._input_operations):
for i, port in enumerate(src.operation.inputs): if input_operation is src.operation:
if src.operation.depends(src.index, i): return [i]
if self._source_depends(port.signals[0].source, input_operation, visited):
return True input_indices = []
return False for i in src.operation.inputs_required_for_output(src.index):
input_indices.extend(self._inputs_required_for_source(src.operation.input(i).signals[0].source, visited))
return input_indices
...@@ -3,24 +3,36 @@ from b_asic import Addition, Butterfly ...@@ -3,24 +3,36 @@ from b_asic import Addition, Butterfly
class TestDepends: class TestDepends:
def test_depends_addition(self): def test_depends_addition(self):
add1 = Addition() add1 = Addition()
assert add1.depends(0, 0) == True out1_dependencies = add1.inputs_required_for_output(0)
assert add1.depends(0, 1) == True assert len(list(out1_dependencies)) == 2
assert 0 in out1_dependencies
assert 1 in out1_dependencies
def test_depends_butterfly(self): def test_depends_butterfly(self):
bfly1 = Butterfly() bfly1 = Butterfly()
assert bfly1.depends(0, 0) == True out1_dependencies = bfly1.inputs_required_for_output(0)
assert bfly1.depends(0, 1) == True out2_dependencies = bfly1.inputs_required_for_output(1)
assert bfly1.depends(1, 0) == True assert len(list(out1_dependencies)) == 2
assert bfly1.depends(1, 1) == True assert 0 in out1_dependencies
assert 1 in out1_dependencies
assert len(list(out2_dependencies)) == 2
assert 0 in out2_dependencies
assert 1 in out2_dependencies
def test_depends_sfg(self, sfg_two_inputs_two_outputs): def test_depends_sfg(self, sfg_two_inputs_two_outputs):
assert sfg_two_inputs_two_outputs.depends(0, 0) == True out1_dependencies = sfg_two_inputs_two_outputs.inputs_required_for_output(0)
assert sfg_two_inputs_two_outputs.depends(0, 1) == True out2_dependencies = sfg_two_inputs_two_outputs.inputs_required_for_output(1)
assert sfg_two_inputs_two_outputs.depends(1, 0) == True assert len(list(out1_dependencies)) == 2
assert sfg_two_inputs_two_outputs.depends(1, 1) == True assert 0 in out1_dependencies
assert 1 in out1_dependencies
assert len(list(out2_dependencies)) == 2
assert 0 in out2_dependencies
assert 1 in out2_dependencies
def test_depends_sfg_independent(self, sfg_two_inputs_two_outputs_independent): def test_depends_sfg_independent(self, sfg_two_inputs_two_outputs_independent):
assert sfg_two_inputs_two_outputs_independent.depends(0, 0) == True out1_dependencies = sfg_two_inputs_two_outputs_independent.inputs_required_for_output(0)
assert sfg_two_inputs_two_outputs_independent.depends(0, 1) == False out2_dependencies = sfg_two_inputs_two_outputs_independent.inputs_required_for_output(1)
assert sfg_two_inputs_two_outputs_independent.depends(1, 0) == False assert len(list(out1_dependencies)) == 1
assert sfg_two_inputs_two_outputs_independent.depends(1, 1) == True assert 0 in out1_dependencies
\ No newline at end of file assert len(list(out2_dependencies)) == 1
assert 1 in out2_dependencies
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment