diff --git a/b_asic/operation.py b/b_asic/operation.py index 557210bae107fcd5c59e3ae5e28c8ab5e7058fb9..dd5a93d0c3a8dd8818d50863e7169988b693943c 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -477,7 +477,7 @@ class Operation(GraphComponent, SignalSourceProvider): @abstractmethod def is_constant(self) -> bool: """ - Return True if the output of the operation is constant. + Return True if the output(s) of the operation is(are) constant. """ raise NotImplementedError @@ -948,7 +948,7 @@ class AbstractOperation(Operation, AbstractGraphComponent): for i, input_port in enumerate(self.inputs): value = input_values[i] if bits_override is None and input_port.signal_count >= 1: - input_port.signals[0].bits + bits_override = input_port.signals[0].bits if bits_override is not None: if isinstance(value, complex): raise TypeError( diff --git a/b_asic/signal.py b/b_asic/signal.py index 1a05940a9209aa0142278ddb6c011a03aadb0143..5f4e0948282096eae5f43659ebdf223e5578e9f1 100644 --- a/b_asic/signal.py +++ b/b_asic/signal.py @@ -193,3 +193,12 @@ class Signal(AbstractGraphComponent): if bits < 0: raise ValueError("Bits cannot be negative") self.set_param("bits", bits) + + @property + def is_constant(self) -> bool: + """ + Return True if the value of the signal (source) is constant. + """ + if self.source is None: + raise ValueError("Signal source not set") + return self.source.operation.is_constant diff --git a/test/test_signal.py b/test/test_signal.py index 33fd69f6f45530928b84a8f3d9b3685367340478..ca2c57f6b81a8a3d6ff38b730ebeadefe94ea3fe 100644 --- a/test/test_signal.py +++ b/test/test_signal.py @@ -4,9 +4,10 @@ B-ASIC test suit for the signal module which consists of the Signal class. import pytest -from b_asic.core_operations import Addition, Butterfly, ConstantMultiplication +from b_asic.core_operations import Addition, Butterfly, Constant, ConstantMultiplication from b_asic.port import InputPort, OutputPort from b_asic.signal import Signal +from b_asic.special_operations import Input def test_signal_creation_and_disconnection_and_connection_changing(): @@ -105,6 +106,16 @@ def test_create_from_single_input_single_output(): assert signal.source.operation.name == "Zig" +def test_signal_is_constant(): + c = Constant(0.5, name="Foo") + signal = Signal(c) + assert signal.is_constant + + i = Input() + signal = Signal(i) + assert not signal.is_constant + + def test_signal_errors(): cm1 = ConstantMultiplication(0.5, name="Foo") add1 = Addition(name="Zig") @@ -146,3 +157,7 @@ def test_signal_errors(): ), ): signal.set_source(bf) + + signal = Signal() + with pytest.raises(ValueError, match="Signal source not set"): + signal.is_constant