Skip to content
Snippets Groups Projects
Commit 5f47ec2c authored by Simon Bjurek's avatar Simon Bjurek
Browse files

Add SFG generator for DIF FFT

parent ac8d2c32
No related branches found
No related tags found
1 merge request!462Add SFG generator for DIF FFT
Pipeline #155562 passed
......@@ -118,3 +118,5 @@ docs_sphinx/_build/
docs_sphinx/examples
result_images/
.coverage
Digraph.gv
Digraph.gv.pdf
......@@ -4,12 +4,13 @@ B-ASIC signal flow graph generators.
This module contains a number of functions generating SFGs for specific functions.
"""
from typing import Dict, Optional, Sequence, Union
from typing import TYPE_CHECKING, Dict, Optional, Sequence, Union
import numpy as np
from b_asic.core_operations import (
Addition,
Butterfly,
ConstantMultiplication,
Name,
SymmetricTwoportAdaptor,
......@@ -18,6 +19,9 @@ from b_asic.signal import Signal
from b_asic.signal_flow_graph import SFG
from b_asic.special_operations import Delay, Input, Output
if TYPE_CHECKING:
from b_asic.port import OutputPort
def wdf_allpass(
coefficients: Sequence[float],
......@@ -371,3 +375,121 @@ def direct_form_2_iir(
output = Output()
output <<= add
return SFG([input_op], [output], name=Name(name))
def radix_2_dif_fft(points: int) -> SFG:
"""Generates a radix-2 decimation-in-frequency FFT structure.
Parameters
----------
points : int
Number of points for the FFT, needs to be a positive power of 2.
Returns
-------
SFG
Signal Flow Graph
"""
if points < 0:
raise ValueError("Points must be positive number.")
if points & (points - 1) != 0:
raise ValueError("Points must be a power of two.")
inputs = []
for i in range(points):
inputs.append(Input(name=f"Input: {i}"))
ports = inputs
number_of_stages = int(np.log2(points))
twiddles = _generate_twiddles(points, number_of_stages)
for stage in range(number_of_stages):
ports = _construct_dif_fft_stage(
ports, number_of_stages, stage, twiddles[stage]
)
ports = _get_bit_reversed_ports(ports)
outputs = []
for i, port in enumerate(ports):
outputs.append(Output(port, name=f"Output: {i}"))
return SFG(inputs=inputs, outputs=outputs)
def _construct_dif_fft_stage(
ports_from_previous_stage: list["OutputPort"],
number_of_stages: int,
stage: int,
twiddles: list[np.complex128],
):
ports = ports_from_previous_stage.copy()
number_of_butterflies = len(ports) // 2
number_of_groups = 2**stage
group_size = number_of_butterflies // number_of_groups
for group_index in range(number_of_groups):
for bf_index in range(group_size):
input1_index = group_index * 2 * group_size + bf_index
input2_index = input1_index + group_size
input1 = ports[input1_index]
input2 = ports[input2_index]
butterfly = Butterfly(input1, input2)
output1, output2 = butterfly.outputs
twiddle_factor = twiddles[bf_index]
if twiddle_factor != 1:
name = _get_formatted_complex_number(twiddle_factor, 2)
twiddle_mul = ConstantMultiplication(
twiddles[bf_index], output2, name=name
)
output2 = twiddle_mul.output(0)
ports[input1_index] = output1
ports[input2_index] = output2
return ports
def _get_formatted_complex_number(number: np.complex128, digits: int) -> str:
real_str = str(np.round(number.real, digits))
imag_str = str(np.round(number.imag, digits))
if number.imag == 0:
return real_str
elif number.imag > 0:
return f"{real_str} + j{imag_str}"
else:
return f"{real_str} - j{str(-np.round(number.imag, digits))}"
def _get_bit_reversed_number(number: int, number_of_bits: int) -> int:
reversed_number = 0
for i in range(number_of_bits):
# mask out the current bit
shift_num = number
current_bit = (shift_num >> i) & 1
# compute the position of the current bit in the reversed string
reversed_pos = number_of_bits - 1 - i
# place the current bit in that position
reversed_number |= current_bit << reversed_pos
return reversed_number
def _get_bit_reversed_ports(ports: list["OutputPort"]) -> list["OutputPort"]:
num_of_ports = len(ports)
bits = int(np.log2(num_of_ports))
return [ports[_get_bit_reversed_number(i, bits)] for i in range(num_of_ports)]
def _generate_twiddles(points: int, number_of_stages: int) -> list[np.complex128]:
twiddles = []
for stage in range(1, number_of_stages + 1):
stage_twiddles = []
for k in range(points // 2 ** (stage)):
a = 2 ** (stage - 1)
twiddle = np.exp(-1j * 2 * np.pi * a * k / points)
stage_twiddles.append(twiddle)
twiddles.append(stage_twiddles)
return twiddles
......@@ -3,15 +3,17 @@ import pytest
from b_asic.core_operations import (
Addition,
Butterfly,
ConstantMultiplication,
SymmetricTwoportAdaptor,
)
from b_asic.sfg_generators import (
direct_form_fir,
radix_2_dif_fft,
transposed_direct_form_fir,
wdf_allpass,
)
from b_asic.signal_generator import Impulse
from b_asic.signal_generator import Constant, Impulse
from b_asic.simulation import Simulation
from b_asic.special_operations import Delay
......@@ -234,3 +236,135 @@ def test_sfg_generator_errors():
gen([])
with pytest.raises(TypeError, match="coefficients must be a 1D-array"):
gen([[1, 2], [1, 3]])
def test_radix_2_dif_fft_4_points_constant_input():
sfg = radix_2_dif_fft(points=4)
assert len(sfg.inputs) == 4
assert len(sfg.outputs) == 4
bfs = sfg.find_by_type_name(Butterfly.type_name())
assert len(bfs) == 4
muls = sfg.find_by_type_name(ConstantMultiplication.type_name())
assert len(muls) == 1
# simulate when the input signal is a constant 1
input_samples = [Impulse() for _ in range(4)]
sim = Simulation(sfg, input_samples)
sim.run_for(1)
# ensure that the result is an impulse at time 0 with weight 4
res = sim.results
for i in range(4):
exp_res = 4 if i == 0 else 0
assert np.allclose(res[str(i)], exp_res)
def test_radix_2_dif_fft_8_points_impulse_input():
sfg = radix_2_dif_fft(points=8)
assert len(sfg.inputs) == 8
assert len(sfg.outputs) == 8
bfs = sfg.find_by_type_name(Butterfly.type_name())
assert len(bfs) == 12
muls = sfg.find_by_type_name(ConstantMultiplication.type_name())
assert len(muls) == 5
# simulate when the input signal is an impulse at time 0
input_samples = [Impulse(), 0, 0, 0, 0, 0, 0, 0]
sim = Simulation(sfg, input_samples)
sim.run_for(1)
# ensure that the result is a constant 1
res = sim.results
for i in range(8):
assert np.allclose(res[str(i)], 1)
def test_radix_2_dif_fft_8_points_sinus_input():
POINTS = 8
sfg = radix_2_dif_fft(points=POINTS)
assert len(sfg.inputs) == POINTS
assert len(sfg.outputs) == POINTS
n = np.linspace(0, 2 * np.pi, POINTS)
waveform = np.sin(n)
input_samples = [Constant(waveform[i]) for i in range(POINTS)]
sim = Simulation(sfg, input_samples)
sim.run_for(1)
exp_res = abs(np.fft.fft(waveform))
res = sim.results
for i in range(POINTS):
a = abs(res[str(i)])
b = exp_res[i]
assert np.isclose(a, b)
def test_radix_2_dif_fft_16_points_sinus_input():
POINTS = 16
sfg = radix_2_dif_fft(points=POINTS)
assert len(sfg.inputs) == POINTS
assert len(sfg.outputs) == POINTS
bfs = sfg.find_by_type_name(Butterfly.type_name())
assert len(bfs) == 8 * 4
muls = sfg.find_by_type_name(ConstantMultiplication.type_name())
assert len(muls) == 17
n = np.linspace(0, 2 * np.pi, POINTS)
waveform = np.sin(n)
input_samples = [Constant(waveform[i]) for i in range(POINTS)]
sim = Simulation(sfg, input_samples)
sim.run_for(1)
exp_res = np.fft.fft(waveform)
res = sim.results
for i in range(POINTS):
a = res[str(i)]
b = exp_res[i]
assert np.isclose(a, b)
def test_radix_2_dif_fft_256_points_sinus_input():
POINTS = 256
sfg = radix_2_dif_fft(points=POINTS)
assert len(sfg.inputs) == POINTS
assert len(sfg.outputs) == POINTS
n = np.linspace(0, 2 * np.pi, POINTS)
waveform = np.sin(n)
input_samples = [Constant(waveform[i]) for i in range(POINTS)]
sim = Simulation(sfg, input_samples)
sim.run_for(1)
exp_res = np.fft.fft(waveform)
res = sim.results
for i in range(POINTS):
a = res[str(i)]
b = exp_res[i]
assert np.isclose(a, b)
def test_radix_2_dif_fft_negative_number_of_points():
POINTS = -8
with pytest.raises(ValueError, match="Points must be positive number."):
radix_2_dif_fft(points=POINTS)
def test_radix_2_dif_fft_number_of_points_not_power_of_2():
POINTS = 5
with pytest.raises(ValueError, match="Points must be a power of two."):
radix_2_dif_fft(points=POINTS)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment