diff --git a/b_asic/operation.py b/b_asic/operation.py index a0d0f48a1f7429ce0d393ad4e93ef24c84914f7b..ecc471371182017ade2d9244e59e70af211c9ac3 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -180,6 +180,13 @@ class Operation(GraphComponent, SignalSourceProvider): """ raise NotImplementedError + @abstractmethod + def depends(self, output_index: int, input_index: int) -> bool: + """Check if the output at the given output index depends on the input at the + given input index in order to be evaluated. + """ + raise NotImplementedError + class AbstractOperation(Operation, AbstractGraphComponent): """Generic abstract operation class which most implementations will derive from. @@ -340,6 +347,13 @@ class AbstractOperation(Operation, AbstractGraphComponent): pass return [self] + def depends(self, output_index: int, input_index: int) -> bool: + if output_index < 0 or output_index >= self.output_count: + raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {output_index})") + if input_index < 0 or input_index >= self.input_count: + raise IndexError(f"Input index out of range (expected 0-{self.input_count - 1}, got {input_index})") + return True # By default, assume each output depends on all inputs. + @property def neighbors(self) -> Iterable[GraphComponent]: return list(self.input_signals) + list(self.output_signals) diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index e8e7af01ab93fdba948d9ff7ec19078b3b71dee6..3395466805273d39fbcab5aa2809f0905d1d80da 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -3,7 +3,7 @@ B-ASIC Signal Flow Graph Module. TODO: More info. """ -from typing import List, Iterable, Sequence, Dict, Optional, DefaultDict, Set +from typing import List, Iterable, Sequence, Dict, Optional, DefaultDict, MutableSet from numbers import Number from collections import defaultdict, deque @@ -45,7 +45,7 @@ class SFG(AbstractOperation): _graph_id_generator: GraphIDGenerator _input_operations: List[Input] _output_operations: List[Output] - _original_components_to_new: Set[GraphComponent] + _original_components_to_new: MutableSet[GraphComponent] _original_input_signals_to_indices: Dict[Signal, int] _original_output_signals_to_indices: Dict[Signal, int] @@ -233,6 +233,13 @@ class SFG(AbstractOperation): def split(self) -> Iterable[Operation]: return self.operations + + def depends(self, output_index: int, input_index: int) -> bool: + if output_index < 0 or output_index >= self.output_count: + raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {output_index})") + if input_index < 0 or input_index >= self.input_count: + raise IndexError(f"Input index out of range (expected 0-{self.input_count - 1}, got {input_index})") + return self._source_depends(self._output_operations[output_index].input(0).signals[0].source, self._input_operations[input_index], set()) def copy_component(self, *args, **kwargs) -> GraphComponent: return super().copy_component(*args, **kwargs, inputs = self._input_operations, outputs = self._output_operations, @@ -385,3 +392,15 @@ class SFG(AbstractOperation): value = src.operation.evaluate_output(src.index, input_values, results, registers, src_prefix) results[key] = value return value + + def _source_depends(self, src: OutputPort, input_operation: Input, visited: MutableSet[Operation]) -> bool: + if src.operation is input_operation: + return True + + if src.operation not in visited: + visited.add(src.operation) + for i, port in enumerate(src.operation.inputs): + if src.operation.depends(src.index, i): + if self._source_depends(port.signals[0].source, input_operation, visited): + return True + return False \ No newline at end of file diff --git a/test/fixtures/signal_flow_graph.py b/test/fixtures/signal_flow_graph.py index 0a6c554d1340478dad25a11655f0542bf6fba1d1..e196f1f26323c28fcf0673e36cfc53783d371b63 100644 --- a/test/fixtures/signal_flow_graph.py +++ b/test/fixtures/signal_flow_graph.py @@ -29,6 +29,33 @@ def sfg_two_inputs_two_outputs(): out2 = Output(add2) return SFG(inputs = [in1, in2], outputs = [out1, out2]) +@pytest.fixture +def sfg_two_inputs_two_outputs_independent(): + """Valid SFG with two inputs and two outputs, where the first output only depends + on the first input and the second output only depends on the second input. + . . + in1-------------------->out1 + . . + . . + . c1--+ . + . | . + . v . + in2------+ add1---->out2 + . | ^ . + . | | . + . +------+ . + . . + out1 = in1 + out2 = in2 + 3 + """ + in1 = Input() + in2 = Input() + c1 = Constant(3) + add1 = in2 + c1 + out1 = Output(in1) + out2 = Output(add1) + return SFG(inputs = [in1, in2], outputs = [out1, out2]) + @pytest.fixture def sfg_nested(): """Valid SFG with two inputs and one output. diff --git a/test/test_depends.py b/test/test_depends.py new file mode 100644 index 0000000000000000000000000000000000000000..16efcaeab0049856aa796bd9e86cf0799070c77a --- /dev/null +++ b/test/test_depends.py @@ -0,0 +1,26 @@ +from b_asic import Addition, Butterfly + +class TestDepends: + def test_depends_addition(self): + add1 = Addition() + assert add1.depends(0, 0) == True + assert add1.depends(0, 1) == True + + def test_depends_butterfly(self): + bfly1 = Butterfly() + assert bfly1.depends(0, 0) == True + assert bfly1.depends(0, 1) == True + assert bfly1.depends(1, 0) == True + assert bfly1.depends(1, 1) == True + + def test_depends_sfg(self, sfg_two_inputs_two_outputs): + assert sfg_two_inputs_two_outputs.depends(0, 0) == True + assert sfg_two_inputs_two_outputs.depends(0, 1) == True + assert sfg_two_inputs_two_outputs.depends(1, 0) == True + assert sfg_two_inputs_two_outputs.depends(1, 1) == True + + def test_depends_sfg_independent(self, sfg_two_inputs_two_outputs_independent): + assert sfg_two_inputs_two_outputs_independent.depends(0, 0) == True + assert sfg_two_inputs_two_outputs_independent.depends(0, 1) == False + assert sfg_two_inputs_two_outputs_independent.depends(1, 0) == False + assert sfg_two_inputs_two_outputs_independent.depends(1, 1) == True \ No newline at end of file