diff --git a/b_asic/operation.py b/b_asic/operation.py index ecc471371182017ade2d9244e59e70af211c9ac3..bb66e26b30a4a14b116800300d5a00d0855945f2 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -181,10 +181,8 @@ class Operation(GraphComponent, SignalSourceProvider): raise NotImplementedError @abstractmethod - def depends(self, output_index: int, input_index: int) -> bool: - """Check if the output at the given output index depends on the input at the - given input index in order to be evaluated. - """ + def inputs_required_for_output(self, output_index: int) -> Iterable[int]: + """Get the input indices of all inputs in this operation whose values are required in order to evalueate the output at the given output index.""" raise NotImplementedError @@ -347,12 +345,10 @@ class AbstractOperation(Operation, AbstractGraphComponent): pass return [self] - def depends(self, output_index: int, input_index: int) -> bool: + def inputs_required_for_output(self, output_index: int) -> Iterable[int]: if output_index < 0 or output_index >= self.output_count: raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {output_index})") - if input_index < 0 or input_index >= self.input_count: - raise IndexError(f"Input index out of range (expected 0-{self.input_count - 1}, got {input_index})") - return True # By default, assume each output depends on all inputs. + return [i for i in range(self.input_count)] # By default, assume each output depends on all inputs. @property def neighbors(self) -> Iterable[GraphComponent]: diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 27a7c3378a87b66c83d690417959a90110d286d8..dd2c02c247f51a7da1c4e4d1b0a05328126d6a0c 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -48,6 +48,7 @@ class SFG(AbstractOperation): _original_components_to_new: MutableSet[GraphComponent] _original_input_signals_to_indices: Dict[Signal, int] _original_output_signals_to_indices: Dict[Signal, int] + _dependency_map: Dict[int, MutableSet[int]] def __init__(self, input_signals: Optional[Sequence[Signal]] = None, output_signals: Optional[Sequence[Signal]] = None, \ inputs: Optional[Sequence[Input]] = None, outputs: Optional[Sequence[Output]] = None, \ @@ -71,6 +72,7 @@ class SFG(AbstractOperation): self._original_components_to_new = {} self._original_input_signals_to_indices = {} self._original_output_signals_to_indices = {} + self._dependency_map = {} # Setup input signals. if input_signals is not None: @@ -155,6 +157,9 @@ class SFG(AbstractOperation): raise ValueError(f"Output signal #{output_index} is missing source in SFG") if signal.source.operation not in self._original_components_to_new: self._add_operation_connected_tree_copy(signal.source.operation) + + # Find dependencies. + def __str__(self) -> str: """Get a string representation of this SFG.""" @@ -234,12 +239,10 @@ class SFG(AbstractOperation): def split(self) -> Iterable[Operation]: return self.operations - def depends(self, output_index: int, input_index: int) -> bool: + def inputs_required_for_output(self, output_index: int) -> Iterable[int]: if output_index < 0 or output_index >= self.output_count: raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {output_index})") - if input_index < 0 or input_index >= self.input_count: - raise IndexError(f"Input index out of range (expected 0-{self.input_count - 1}, got {input_index})") - return self._source_depends(self._output_operations[output_index].input(0).signals[0].source, self._input_operations[input_index], set()) + return self._inputs_required_for_source(self._output_operations[output_index].input(0).signals[0].source, set()) def copy_component(self, *args, **kwargs) -> GraphComponent: return super().copy_component(*args, **kwargs, inputs = self._input_operations, outputs = self._output_operations, @@ -431,14 +434,17 @@ class SFG(AbstractOperation): results[key] = value return value - def _source_depends(self, src: OutputPort, input_operation: Input, visited: MutableSet[Operation]) -> bool: - if src.operation is input_operation: - return True + def _inputs_required_for_source(self, src: OutputPort, visited: MutableSet[Operation]) -> Sequence[bool]: + if src.operation in visited: + return [] + visited.add(src.operation) - if src.operation not in visited: - visited.add(src.operation) - for i, port in enumerate(src.operation.inputs): - if src.operation.depends(src.index, i): - if self._source_depends(port.signals[0].source, input_operation, visited): - return True - return False + if isinstance(src.operation, Input): + for i, input_operation in enumerate(self._input_operations): + if input_operation is src.operation: + return [i] + + input_indices = [] + for i in src.operation.inputs_required_for_output(src.index): + input_indices.extend(self._inputs_required_for_source(src.operation.input(i).signals[0].source, visited)) + return input_indices diff --git a/test/test_depends.py b/test/test_depends.py index 16efcaeab0049856aa796bd9e86cf0799070c77a..bb8897d53453d5ac967371586f2a626e0d8c01ef 100644 --- a/test/test_depends.py +++ b/test/test_depends.py @@ -3,24 +3,36 @@ from b_asic import Addition, Butterfly class TestDepends: def test_depends_addition(self): add1 = Addition() - assert add1.depends(0, 0) == True - assert add1.depends(0, 1) == True + out1_dependencies = add1.inputs_required_for_output(0) + assert len(list(out1_dependencies)) == 2 + assert 0 in out1_dependencies + assert 1 in out1_dependencies def test_depends_butterfly(self): bfly1 = Butterfly() - assert bfly1.depends(0, 0) == True - assert bfly1.depends(0, 1) == True - assert bfly1.depends(1, 0) == True - assert bfly1.depends(1, 1) == True + out1_dependencies = bfly1.inputs_required_for_output(0) + out2_dependencies = bfly1.inputs_required_for_output(1) + assert len(list(out1_dependencies)) == 2 + assert 0 in out1_dependencies + assert 1 in out1_dependencies + assert len(list(out2_dependencies)) == 2 + assert 0 in out2_dependencies + assert 1 in out2_dependencies def test_depends_sfg(self, sfg_two_inputs_two_outputs): - assert sfg_two_inputs_two_outputs.depends(0, 0) == True - assert sfg_two_inputs_two_outputs.depends(0, 1) == True - assert sfg_two_inputs_two_outputs.depends(1, 0) == True - assert sfg_two_inputs_two_outputs.depends(1, 1) == True + out1_dependencies = sfg_two_inputs_two_outputs.inputs_required_for_output(0) + out2_dependencies = sfg_two_inputs_two_outputs.inputs_required_for_output(1) + assert len(list(out1_dependencies)) == 2 + assert 0 in out1_dependencies + assert 1 in out1_dependencies + assert len(list(out2_dependencies)) == 2 + assert 0 in out2_dependencies + assert 1 in out2_dependencies def test_depends_sfg_independent(self, sfg_two_inputs_two_outputs_independent): - assert sfg_two_inputs_two_outputs_independent.depends(0, 0) == True - assert sfg_two_inputs_two_outputs_independent.depends(0, 1) == False - assert sfg_two_inputs_two_outputs_independent.depends(1, 0) == False - assert sfg_two_inputs_two_outputs_independent.depends(1, 1) == True \ No newline at end of file + out1_dependencies = sfg_two_inputs_two_outputs_independent.inputs_required_for_output(0) + out2_dependencies = sfg_two_inputs_two_outputs_independent.inputs_required_for_output(1) + assert len(list(out1_dependencies)) == 1 + assert 0 in out1_dependencies + assert len(list(out2_dependencies)) == 1 + assert 1 in out2_dependencies \ No newline at end of file