diff --git a/b_asic/sfg_generators.py b/b_asic/sfg_generators.py index c2095ec37602db22de2cbeb8a8ddf32ec5e18710..c9435e232caae783593576e7937ccc822138ec92 100644 --- a/b_asic/sfg_generators.py +++ b/b_asic/sfg_generators.py @@ -13,7 +13,6 @@ from b_asic.core_operations import ( Name, SymmetricTwoportAdaptor, ) -from b_asic.port import InputPort, OutputPort from b_asic.signal import Signal from b_asic.signal_flow_graph import SFG from b_asic.special_operations import Delay, Input, Output @@ -51,7 +50,7 @@ def wdf_allpass( ------- Signal flow graph """ - np_coefficients = np.squeeze(np.asarray(coefficients)) + np_coefficients = np.atleast_1d(np.squeeze(np.asarray(coefficients))) order = len(np_coefficients) if not order: raise ValueError("Coefficients cannot be empty") @@ -143,7 +142,7 @@ def direct_form_fir( -------- transposed_direct_form_fir """ - np_coefficients = np.squeeze(np.asarray(coefficients)) + np_coefficients = np.atleast_1d(np.squeeze(np.asarray(coefficients))) taps = len(np_coefficients) if not taps: raise ValueError("Coefficients cannot be empty") @@ -211,7 +210,7 @@ def transposed_direct_form_fir( -------- direct_form_fir """ - np_coefficients = np.squeeze(np.asarray(coefficients)) + np_coefficients = np.atleast_1d(np.squeeze(np.asarray(coefficients))) taps = len(np_coefficients) if not taps: raise ValueError("Coefficients cannot be empty") diff --git a/test/test_sfg_generators.py b/test/test_sfg_generators.py index 8e679b8b48b2abe858359b392cb803e3c259aac2..78ca06f4ef5e778a5ceb22faec390c3b1ef703e8 100644 --- a/test/test_sfg_generators.py +++ b/test/test_sfg_generators.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from b_asic.core_operations import ( Addition, @@ -16,6 +17,7 @@ from b_asic.special_operations import Delay def test_wdf_allpass(): + # Third-order sfg = wdf_allpass([0.3, 0.5, 0.7]) assert ( len( @@ -28,6 +30,9 @@ def test_wdf_allpass(): == 3 ) + assert len([comp for comp in sfg.components if isinstance(comp, Delay)]) == 3 + + # Fourth-order sfg = wdf_allpass([0.3, 0.5, 0.7, 0.9]) assert ( len( @@ -40,6 +45,34 @@ def test_wdf_allpass(): == 4 ) + assert len([comp for comp in sfg.components if isinstance(comp, Delay)]) == 4 + + # First-order + sfg = wdf_allpass([0.3]) + assert ( + len( + [ + comp + for comp in sfg.components + if isinstance(comp, SymmetricTwoportAdaptor) + ] + ) + == 1 + ) + + # First-order with scalar input (happens to work) + sfg = wdf_allpass(0.3) + assert ( + len( + [ + comp + for comp in sfg.components + if isinstance(comp, SymmetricTwoportAdaptor) + ] + ) + == 1 + ) + def test_direct_form_fir(): impulse_response = [0.3, 0.5, 0.7] @@ -75,6 +108,36 @@ def test_direct_form_fir(): impulse_response.append(0.0) assert np.allclose(sim.results['0'], impulse_response) + impulse_response = [0.3] + sfg = direct_form_fir(impulse_response) + assert ( + len( + [ + comp + for comp in sfg.components + if isinstance(comp, ConstantMultiplication) + ] + ) + == 1 + ) + assert len([comp for comp in sfg.components if isinstance(comp, Addition)]) == 0 + assert len([comp for comp in sfg.components if isinstance(comp, Delay)]) == 0 + + impulse_response = 0.3 + sfg = direct_form_fir(impulse_response) + assert ( + len( + [ + comp + for comp in sfg.components + if isinstance(comp, ConstantMultiplication) + ] + ) + == 1 + ) + assert len([comp for comp in sfg.components if isinstance(comp, Addition)]) == 0 + assert len([comp for comp in sfg.components if isinstance(comp, Delay)]) == 0 + def test_transposed_direct_form_fir(): impulse_response = [0.3, 0.5, 0.7] @@ -109,3 +172,42 @@ def test_transposed_direct_form_fir(): sim.run_for(6) impulse_response.append(0.0) assert np.allclose(sim.results['0'], impulse_response) + + impulse_response = [0.3] + sfg = transposed_direct_form_fir(impulse_response) + assert ( + len( + [ + comp + for comp in sfg.components + if isinstance(comp, ConstantMultiplication) + ] + ) + == 1 + ) + assert len([comp for comp in sfg.components if isinstance(comp, Addition)]) == 0 + assert len([comp for comp in sfg.components if isinstance(comp, Delay)]) == 0 + + impulse_response = 0.3 + sfg = transposed_direct_form_fir(impulse_response) + assert ( + len( + [ + comp + for comp in sfg.components + if isinstance(comp, ConstantMultiplication) + ] + ) + == 1 + ) + assert len([comp for comp in sfg.components if isinstance(comp, Addition)]) == 0 + assert len([comp for comp in sfg.components if isinstance(comp, Delay)]) == 0 + + +def test_sfg_generator_errors(): + sfg_gens = [wdf_allpass, transposed_direct_form_fir, direct_form_fir] + for gen in sfg_gens: + with pytest.raises(ValueError, match="Coefficients cannot be empty"): + gen([]) + with pytest.raises(TypeError, match="coefficients must be a 1D-array"): + gen([[1, 2], [1, 3]])