Fourier Transform


Contents:
- Continuous Fourier Transform
- Properties of Fourier Transform
- Convolution Theorem
- Discrete Fourier Transform
- Fast Fourier Transform
- Applications in Machine Learning
- Interactive Demos

Continuous Fourier Transform

The Fourier transform extends the ideas of Fourier series from periodic functions to non-periodic functions defined on \(\mathbb{R}\). While Fourier series decompose periodic signals into discrete frequency components, the Fourier transform decomposes general signals into a continuous spectrum of frequencies.

Motivation from Fourier Series: Consider a function \(f_L\) that is periodic with period \(2L\). It has the complex Fourier series: \[ f_L(x) = \sum_{n=-\infty}^{\infty} c_n e^{-i\frac{n\pi x}{L}}, \quad c_n = \frac{1}{2L}\int_{-L}^{L} f_L(x)e^{i\frac{n\pi x}{L}} \, dx \] As \(L \to \infty\), the discrete frequencies \(\omega_n = \frac{n\pi}{L}\) become densely packed with spacing \(\Delta\omega = \frac{\pi}{L} \to 0\), and the sum approaches an integral. This limiting process yields the Fourier transform.
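
Concretely, writing \(\hat{f}_L(\omega_n) = \int_{-L}^{L} f_L(x)e^{i\omega_n x}\,dx = 2L\,c_n\) and using \(\frac{1}{2L} = \frac{\Delta\omega}{2\pi}\), the series becomes \[ f_L(x) = \frac{1}{2\pi}\sum_{n=-\infty}^{\infty} \hat{f}_L(\omega_n)\,e^{-i\omega_n x}\,\Delta\omega, \] a Riemann sum over frequency that formally converges to the inversion integral below as \(L \to \infty\).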

For a function \(f: \mathbb{R} \to \mathbb{C}\), we define: \[ \boxed{ \hat{f}(\xi) = \mathcal{F}\{f\}(\xi) = \int_{-\infty}^{\infty} f(x)e^{i x\xi} \, dx } \] where \(\xi \in \mathbb{R}\) is the frequency variable.

The inverse Fourier transform recovers \(f\) from \(\hat{f}\): \[ \boxed{ f(x) = \mathcal{F}^{-1}\{\hat{f}\}(x) = \frac{1}{2\pi}\int_{-\infty}^{\infty} \hat{f}(\xi)e^{-i x\xi} \, d\xi } \]
Note on Conventions: As discussed in Part 14, we use the mathematical/PDE convention with positive sign in the forward transform and negative in the inverse. Engineering and physics often use the opposite convention: \(\hat{f}(\omega) = \int_{-\infty}^{\infty} f(t)e^{-i\omega t} \, dt\). Both are mathematically equivalent—just be consistent within your work.

Function Spaces: The Fourier transform is well-defined for different function classes:
- \(L^1(\mathbb{R})\) (absolutely integrable functions): the defining integral converges absolutely, and \(\hat{f}\) is bounded and continuous.
- \(L^2(\mathbb{R})\) (square-integrable functions): the transform extends by density to a unitary map, the setting of Plancherel's theorem below.
- The Schwartz space \(\mathcal{S}(\mathbb{R})\) of smooth, rapidly decaying functions: the Fourier transform maps \(\mathcal{S}\) bijectively onto itself, and by duality extends to tempered distributions.

Example: Gaussian Function Consider the Gaussian function: \[ f(x) = e^{-\frac{x^2}{2\sigma^2}} \] Its Fourier transform (using the mathematical convention) is: \[ \begin{align*} \hat{f}(\xi) &= \int_{-\infty}^{\infty} e^{-\frac{x^2}{2\sigma^2}}e^{i\xi x} \, dx \\\\ &= \int_{-\infty}^{\infty} e^{-\frac{x^2}{2\sigma^2} + i\xi x} \, dx \end{align*} \] Complete the square in the exponent: \[ -\frac{x^2}{2\sigma^2} + i\xi x = -\frac{1}{2\sigma^2}(x^2 - 2i\sigma^2\xi x) = -\frac{1}{2\sigma^2}((x - i\sigma^2\xi)^2 + \sigma^4\xi^2) \] Therefore: \[ \hat{f}(\xi) = e^{-\frac{\sigma^2\xi^2}{2}} \int_{-\infty}^{\infty} e^{-\frac{(x-i\sigma^2\xi)^2}{2\sigma^2}} \, dx \] By contour integration (the integrand is analytic and decays exponentially), shifting the path doesn't change the value: \[ \int_{-\infty}^{\infty} e^{-\frac{(x-i\sigma^2\xi)^2}{2\sigma^2}} \, dx = \int_{-\infty}^{\infty} e^{-\frac{u^2}{2\sigma^2}} \, du = \sigma\sqrt{2\pi} \] Thus: \[ \boxed{\hat{f}(\xi) = \sigma\sqrt{2\pi} \cdot e^{-\frac{\sigma^2\xi^2}{2}}} \] This shows that the Gaussian is essentially an eigenfunction of the Fourier transform: its transform is also Gaussian, with inverse width (narrow in space ↔ wide in frequency). This property makes Gaussians fundamental in uncertainty principles, quantum mechanics, and signal processing.
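
A quick numerical check of this transform pair (a minimal sketch; the value of \(\sigma\), the truncation window, and the grid size below are arbitrary choices). Note that np.fft uses the opposite sign convention, but since the Gaussian is even, both conventions give the same transform; the phase factor compensates for the sampling grid starting at \(-L\) rather than \(0\):

                            import numpy as np

                            sigma, L_half, N = 1.5, 50.0, 4096
                            x = np.linspace(-L_half, L_half, N, endpoint=False)
                            dx = x[1] - x[0]
                            f = np.exp(-x**2 / (2 * sigma**2))

                            # Riemann-sum approximation of the continuous transform via the DFT;
                            # the phase factor accounts for the grid starting at -L_half instead of 0
                            xi = 2 * np.pi * np.fft.fftfreq(N, d=dx)
                            f_hat = dx * np.exp(1j * xi * L_half) * np.fft.fft(f)

                            expected = sigma * np.sqrt(2 * np.pi) * np.exp(-sigma**2 * xi**2 / 2)
                            print(np.max(np.abs(f_hat.real - expected)))   # ~1e-14: matches the boxed formula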

Properties of Fourier Transform

The Fourier transform has several important properties that make it a powerful tool for analysis. We state these using the mathematical convention \(\hat{f}(\xi) = \int_{-\infty}^{\infty} f(x)e^{ix\xi} \, dx\):

1. Linearity: \[ \mathcal{F}\{\alpha f + \beta g\} = \alpha \mathcal{F}\{f\} + \beta \mathcal{F}\{g\} \] for constants \(\alpha, \beta \in \mathbb{C}\).

2. Translation (Shift) Property: \[ \mathcal{F}\{f(x - a)\}(\xi) = e^{ia\xi}\hat{f}(\xi) \] A shift in space corresponds to a pure phase factor in frequency (with the engineering convention the phase is \(e^{-ia\xi}\)).

3. Modulation Property: \[ \mathcal{F}\{e^{-i\xi_0 x}f(x)\}(\xi) = \hat{f}(\xi - \xi_0) \] Multiplication by a complex exponential shifts the frequency spectrum (in this convention, modulating by \(e^{-i\xi_0 x}\) centers the spectrum at \(\xi_0\)).

4. Scaling Property: \[ \mathcal{F}\{f(ax)\}(\xi) = \frac{1}{|a|}\hat{f}\left(\frac{\xi}{a}\right), \quad a \neq 0 \] This embodies the uncertainty principle: compressing a signal in time stretches its frequency spectrum, and vice versa. One cannot localize a signal arbitrarily well in both time and frequency simultaneously.

5. Differentiation Property: \[ \mathcal{F}\{f'(x)\}(\xi) = -i\xi \hat{f}(\xi) \] More generally, for the \(n\)-th derivative: \[ \mathcal{F}\{f^{(n)}(x)\}(\xi) = (-i\xi)^n \hat{f}(\xi) \] This transforms differentiation into algebraic multiplication, which is why Fourier methods are powerful for solving differential equations. Note the sign difference from the engineering convention; a numerical check follows this list of properties.

6. Multiplication by \(x^n\) Property: \[ \mathcal{F}\{x^n f(x)\}(\xi) = (-i)^n \frac{d^n}{d\xi^n}\hat{f}(\xi) \] Multiplication by powers of \(x\) becomes differentiation in the frequency domain.

7. Plancherel's Theorem (Parseval for Non-periodic Functions):
For \(f, g \in L^2(\mathbb{R})\): \[ \int_{-\infty}^{\infty} f(x)\overline{g(x)} \, dx = \frac{1}{2\pi}\int_{-\infty}^{\infty} \hat{f}(\xi)\overline{\hat{g}(\xi)} \, d\xi \] In particular, setting \(f = g\): \[ \boxed{\|f\|_{L^2}^2 = \int_{-\infty}^{\infty} |f(x)|^2 \, dx = \frac{1}{2\pi}\int_{-\infty}^{\infty} |\hat{f}(\xi)|^2 \, d\xi = \frac{1}{2\pi}\|\hat{f}\|_{L^2}^2} \] This generalizes Parseval's identity and shows that the Fourier transform preserves energy (up to the factor \(2\pi\)). The map \(f \mapsto \frac{1}{\sqrt{2\pi}}\hat{f}\) is a unitary operator on \(L^2(\mathbb{R})\).

8. Fourier Inversion Theorem:
For suitable functions (e.g., \(f \in L^1(\mathbb{R})\) with \(\hat{f} \in L^1(\mathbb{R})\)): \[ \mathcal{F}\{\mathcal{F}\{f\}(\xi)\}(x) = 2\pi f(-x) \] Applying the Fourier transform twice (up to normalization) gives a reflection of the original function, demonstrating deep symmetry between spatial and frequency domains.

9. Riemann-Lebesgue Lemma:
For \(f \in L^1(\mathbb{R})\): \[ \lim_{|\xi| \to \infty} \hat{f}(\xi) = 0 \] The Fourier transform of an integrable function vanishes at infinity. This has important implications for the smoothness-decay duality: smoother functions have faster-decaying transforms.
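
As promised under property 5, here is a minimal numerical check using spectral differentiation of a periodic signal (the grid size and test signal are illustrative choices). Caution: np.fft uses the engineering convention \(e^{-ix\xi}\) in the forward transform, so differentiation corresponds to multiplication by \(+ik\) rather than \(-i\xi\):

                            import numpy as np

                            # Property 5 in action: spectral differentiation of a periodic signal
                            N = 256
                            x = np.linspace(0, 2 * np.pi, N, endpoint=False)
                            f = np.sin(3 * x) + 0.5 * np.cos(5 * x)
                            exact = 3 * np.cos(3 * x) - 2.5 * np.sin(5 * x)

                            k = 2 * np.pi * np.fft.fftfreq(N, d=2 * np.pi / N)   # integer wavenumbers
                            deriv = np.fft.ifft(1j * k * np.fft.fft(f)).real
                            print(np.max(np.abs(deriv - exact)))                 # ~1e-13 (spectral accuracy)

                            # Property 8 (discrete analog): applying the DFT twice returns
                            # N times the index-reversed signal, mirroring F{F{f}}(x) = 2*pi*f(-x)
                            print(np.allclose(np.fft.fft(np.fft.fft(f)), N * f[(-np.arange(N)) % N]))  # True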

Convolution Theorem

The convolution of two functions \(f, g: \mathbb{R} \to \mathbb{C}\) is defined as: \[ (f * g)(x) = \int_{-\infty}^{\infty} f(y)g(x-y) \, dy = \int_{-\infty}^{\infty} f(x-y)g(y) \, dy \] (The second equality shows convolution is commutative.)

The convolution theorem states that the Fourier transform converts convolution into pointwise multiplication: \[ \boxed{\mathcal{F}\{f * g\}(\xi) = \hat{f}(\xi) \cdot \hat{g}(\xi)} \] and conversely, the Fourier transform of a product is a (scaled) convolution: \[ \boxed{\mathcal{F}\{f \cdot g\}(\xi) = \frac{1}{2\pi}(\hat{f} * \hat{g})(\xi)} \]
This is one of the most important results in applied mathematics. It allows us to:
- compute convolutions in \(O(N \log N)\) time via the FFT instead of \(O(N^2)\) directly;
- analyze linear time-invariant systems, whose output is the convolution of the input with an impulse response;
- design filters by specifying a frequency response and multiplying spectra.

Applications in Computer Science and ML: fast convolution underpins large-kernel convolutions in neural networks, polynomial and big-integer multiplication, image filtering, and the FFT-based sequence models discussed later in this part.

Proof of Convolution Theorem: Using the mathematical convention with \(\hat{f}(\xi) = \int f(x)e^{ix\xi} \, dx\): \[ \begin{align*} \mathcal{F}\{f * g\}(\xi) &= \int_{-\infty}^{\infty} (f * g)(x)e^{ix\xi} \, dx \\\\ &= \int_{-\infty}^{\infty} \left(\int_{-\infty}^{\infty} f(y)g(x-y) \, dy\right) e^{ix\xi} \, dx \\\\ &= \int_{-\infty}^{\infty} f(y) \left(\int_{-\infty}^{\infty} g(x-y)e^{ix\xi} \, dx\right) dy \end{align*} \] Substituting \(u = x - y\) (so \(x = u + y\) and \(dx = du\)): \[ \begin{align*} &= \int_{-\infty}^{\infty} f(y) \left(\int_{-\infty}^{\infty} g(u)e^{i(u+y)\xi} \, du\right) dy \\\\ &= \int_{-\infty}^{\infty} f(y)e^{iy\xi} \left(\int_{-\infty}^{\infty} g(u)e^{iu\xi} \, du\right) dy \\\\ &= \left(\int_{-\infty}^{\infty} f(y)e^{iy\xi} \, dy\right) \cdot \left(\int_{-\infty}^{\infty} g(u)e^{iu\xi} \, du\right) \\\\ &= \hat{f}(\xi) \cdot \hat{g}(\xi) \end{align*} \] The interchange of integration order is justified by Fubini's theorem when \(f, g \in L^1(\mathbb{R})\).
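
Numerically, the discrete analog is immediate: a linear convolution equals the inverse DFT of the product of zero-padded DFTs. A minimal check (the array lengths below are arbitrary):

                            import numpy as np

                            rng = np.random.default_rng(0)
                            f = rng.standard_normal(100)
                            g = rng.standard_normal(50)

                            direct = np.convolve(f, g)                  # O(mn) direct linear convolution
                            n = len(f) + len(g) - 1                     # pad to avoid circular wraparound
                            via_fft = np.fft.ifft(np.fft.fft(f, n) * np.fft.fft(g, n)).real

                            print(np.max(np.abs(direct - via_fft)))     # ~1e-13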

Discrete Fourier Transform (DFT)

In computational applications, we work with discrete, finite sequences rather than continuous functions. The Discrete Fourier Transform (DFT) is the discrete analog of the continuous Fourier transform.

Definition: For a sequence \(\{x_0, x_1, \ldots, x_{N-1}\}\) of \(N\) complex numbers, the DFT is: \[ X_k = \sum_{n=0}^{N-1} x_n e^{-\frac{2\pi i}{N}kn}, \quad k = 0, 1, \ldots, N-1 \] The inverse DFT (IDFT) is: \[ x_n = \frac{1}{N}\sum_{k=0}^{N-1} X_k e^{\frac{2\pi i}{N}kn}, \quad n = 0, 1, \ldots, N-1 \]
Connection to Continuous Transform: The DFT can be viewed as:
- a Riemann-sum approximation of the continuous Fourier transform for a signal sampled at \(N\) points and truncated to a finite window;
- the Fourier series coefficients of the periodic extension of the sampled data;
- samples of the discrete-time Fourier transform (DTFT) at the \(N\) equally spaced frequencies \(\xi_k = \frac{2\pi k}{N}\).

Matrix Formulation: Define \(\omega = e^{-\frac{2\pi i}{N}}\), a primitive \(N\)-th root of unity (so \(\omega^N = 1\)). The DFT matrix is: \[ W = \begin{bmatrix} 1 & 1 & 1 & \cdots & 1 \\ 1 & \omega & \omega^2 & \cdots & \omega^{N-1} \\ 1 & \omega^2 & \omega^4 & \cdots & \omega^{2(N-1)} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & \omega^{N-1} & \omega^{2(N-1)} & \cdots & \omega^{(N-1)^2} \end{bmatrix} \] where \(W_{kn} = \omega^{kn}\). Then: \[ \mathbf{X} = W\mathbf{x}, \quad \mathbf{x} = \frac{1}{N}W^*\mathbf{X} \] where \(W^*\) is the conjugate transpose.
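
As a sanity check, the matrix form can be built directly and compared against a library FFT (a small sketch; \(N = 8\) is arbitrary):

                            import numpy as np

                            N = 8
                            n = np.arange(N)
                            W = np.exp(-2j * np.pi * np.outer(n, n) / N)    # W[k, n] = omega^{kn}

                            x = np.random.default_rng(1).standard_normal(N)
                            print(np.allclose(W @ x, np.fft.fft(x)))             # DFT as matrix-vector product
                            print(np.allclose(np.conj(W.T) @ (W @ x) / N, x))    # inverse: x = W* X / N
                            print(np.allclose(np.conj(W.T) @ W, N * np.eye(N)))  # W / sqrt(N) is unitary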

Properties of the DFT Matrix:
- \(W\) is symmetric (\(W^T = W\)) but not Hermitian; \(\frac{1}{\sqrt{N}}W\) is unitary, which yields the discrete Plancherel identity \(\sum_{n}|x_n|^2 = \frac{1}{N}\sum_k |X_k|^2\).
- The normalized matrix \(F = \frac{1}{\sqrt{N}}W\) satisfies \(F^4 = I\), so its eigenvalues lie in \(\{1, -1, i, -i\}\).
- For real input, the output has Hermitian symmetry: \(X_{N-k} = \overline{X_k}\).

Computational Complexity: direct evaluation of \(\mathbf{X} = W\mathbf{x}\) requires \(O(N^2)\) operations; the FFT (next section) exploits the algebraic structure of \(W\) to reduce this to \(O(N \log N)\).

Practical Considerations for CS/ML:
- For real-valued signals, use real-input transforms (np.fft.rfft, torch.fft.rfft) to roughly halve memory and compute.
- Normalization conventions differ between libraries; NumPy and PyTorch expose norm='backward' (the default), 'ortho', and 'forward'.
- Zero-pad inputs when using the DFT for linear convolution, since the DFT natively computes circular convolution.

Fast Fourier Transform (FFT)

The Fast Fourier Transform (FFT) is not a different transform but an efficient algorithm for computing the DFT, reducing complexity from \(O(N^2)\) to \(O(N \log N)\).

Cooley-Tukey Algorithm (Radix-2 FFT):
The key insight is to exploit the structure of the DFT matrix using divide-and-conquer. For \(N = 2^m\), split the input into even and odd indices: \[ \begin{align*} X_k &= \sum_{n=0}^{N-1} x_n \omega^{kn} \\\\ &= \sum_{j=0}^{N/2-1} x_{2j} \omega^{k(2j)} + \sum_{j=0}^{N/2-1} x_{2j+1} \omega^{k(2j+1)} \\\\ &= \sum_{j=0}^{N/2-1} x_{2j} (\omega^2)^{kj} + \omega^k \sum_{j=0}^{N/2-1} x_{2j+1} (\omega^2)^{kj} \end{align*} \] Note that \(\omega^2 = e^{-\frac{4\pi i}{N}} = e^{-\frac{2\pi i}{N/2}}\) is a primitive \((N/2)\)-th root of unity.

Let \(E_k\) be the DFT of the even-indexed subsequence and \(O_k\) the DFT of odd-indexed subsequence. Using the periodicity of the DFT, for \(k = 0, 1, \ldots, N/2-1\): \[ X_k = E_k + \omega^k O_k \] \[ X_{k+N/2} = E_k - \omega^k O_k \] (using the fact that \(\omega^{k+N/2} = -\omega^k\) since \(\omega^{N/2} = e^{-i\pi} = -1\)).

This gives the recurrence \(T(N) = 2T(N/2) + O(N)\), yielding \(T(N) = O(N \log N)\) by the Master theorem.
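
The recurrence translates directly into a short recursive implementation (a teaching sketch, not optimized; production code should use a library FFT):

                            import numpy as np

                            def fft_radix2(x):
                                """Recursive radix-2 Cooley-Tukey FFT; len(x) must be a power of 2."""
                                N = len(x)
                                if N == 1:
                                    return x.astype(complex)
                                E = fft_radix2(x[0::2])                          # DFT of even-indexed samples
                                O = fft_radix2(x[1::2])                          # DFT of odd-indexed samples
                                w = np.exp(-2j * np.pi * np.arange(N // 2) / N)  # twiddle factors omega^k
                                return np.concatenate([E + w * O,                # X_k         = E_k + w^k O_k
                                                       E - w * O])               # X_{k + N/2} = E_k - w^k O_k

                            x = np.random.default_rng(2).standard_normal(16)
                            print(np.allclose(fft_radix2(x), np.fft.fft(x)))     # True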

Other FFT Algorithms:
- Mixed-radix and split-radix variants handle composite \(N\) and further reduce the operation count.
- Rader's algorithm handles prime \(N\) via a length-\((N-1)\) cyclic convolution.
- Bluestein's (chirp-z) algorithm computes the DFT for arbitrary \(N\) as a convolution, still in \(O(N \log N)\).
- Real-input FFTs exploit Hermitian symmetry to roughly halve the work.

Implementation Considerations:
- Production code should use optimized libraries (FFTW, cuFFT, np.fft, torch.fft) rather than hand-rolled transforms.
- Performance is best when \(N\) factors into small primes; padding to such lengths is often worthwhile.
- The FFT is numerically stable: round-off error grows only logarithmically with \(N\).
Python Example:
                            import numpy as np
                            import matplotlib.pyplot as plt

                            # Generate a signal with multiple frequency components
                            fs = 1000  # Sampling frequency (Hz)
                            t = np.linspace(0, 1, fs, endpoint=False)
                            signal = (np.sin(2*np.pi*50*t) +           # 50 Hz component
                                    0.5*np.sin(2*np.pi*120*t) +      # 120 Hz component
                                    0.3*np.cos(2*np.pi*200*t))       # 200 Hz component

                            # Add some Gaussian noise
                            np.random.seed(42)  # For reproducibility
                            signal += 0.3 * np.random.randn(len(t))

                            # Compute FFT
                            fft_vals = np.fft.fft(signal)
                            fft_freq = np.fft.fftfreq(len(signal), d=1/fs)

                            # Only plot positive frequencies (due to Hermitian symmetry for real signals)
                            pos_mask = fft_freq >= 0
                            plt.figure(figsize=(12, 4))

                            plt.subplot(1, 2, 1)
                            plt.plot(t[:100], signal[:100])
                            plt.xlabel('Time (s)')
                            plt.ylabel('Amplitude')
                            plt.title('Signal (first 100 samples)')
                            plt.grid(True, alpha=0.3)

                            plt.subplot(1, 2, 2)
                            plt.plot(fft_freq[pos_mask], np.abs(fft_vals[pos_mask])/len(signal)*2)
                            plt.xlabel('Frequency (Hz)')
                            plt.ylabel('Magnitude')
                            plt.title('Frequency Spectrum (Single-sided)')
                            plt.xlim(0, 300)
                            plt.grid(True, alpha=0.3)
                            plt.tight_layout()
                            plt.show()

                            # Verify Parseval's theorem
                            energy_time = np.sum(np.abs(signal)**2)
                            energy_freq = np.sum(np.abs(fft_vals)**2) / len(signal)
                            print(f"Energy in time domain: {energy_time:.2f}")
                            print(f"Energy in frequency domain: {energy_freq:.2f}")
                            print(f"Relative error: {abs(energy_time - energy_freq)/energy_time:.2e}")
                        

Applications in Machine Learning

Fourier methods are fundamental to modern machine learning, particularly in areas requiring efficient computation, signal processing, and function approximation. Here we focus on the most significant and current applications.

1. Fourier Neural Operators (FNO) for Scientific Computing:
Introduced by Li et al. (ICLR 2021), FNOs have reshaped neural PDE solving and are actively used in weather prediction (FourCastNet by NVIDIA, 2022), fluid dynamics, and climate modeling. They can achieve speedups of up to roughly 1000× over traditional numerical solvers while maintaining competitive accuracy.

Key Innovation: FNOs learn integral operators directly in Fourier space: \[ (\mathcal{K}\varphi)(x) = \int k(x, y)\varphi(y) \, dy \quad \xrightarrow{\text{Fourier}} \quad \widehat{(\mathcal{K}\varphi)}(\xi) = \hat{k}(\xi) \cdot \hat{\varphi}(\xi) \] By parameterizing \(\hat{k}(\xi)\) with neural networks in frequency domain, FNOs efficiently learn resolution-invariant operators.

                            import torch
                            import torch.nn as nn
                            import torch.fft

                            class SpectralConv2d(nn.Module):
                                """2D Fourier layer used in FNO. Efficient for learning PDEs."""
                                def __init__(self, in_channels, out_channels, modes1, modes2):
                                    super().__init__()
                                    self.in_channels = in_channels
                                    self.out_channels = out_channels
                                    self.modes1 = modes1  # Number of Fourier modes to keep
                                    self.modes2 = modes2
                                    
                                    # Complex weights for Fourier coefficients
                                    # Scaled random initialization (Glorot-style fan scaling)
                                    scale = (1 / (in_channels * out_channels))**(1/2)
                                    self.weights1 = nn.Parameter(scale * torch.randn(
                                        in_channels, out_channels, modes1, modes2, dtype=torch.cfloat))
                                    self.weights2 = nn.Parameter(scale * torch.randn(
                                        in_channels, out_channels, modes1, modes2, dtype=torch.cfloat))
                                
                                def forward(self, x):
                                    batch, channels, height, width = x.shape
                                    
                                    # Compute 2D FFT (real-to-complex for efficiency)
                                    x_ft = torch.fft.rfft2(x, norm='ortho')
                                    
                                    # Initialize output tensor in Fourier space
                                    out_ft = torch.zeros(batch, self.out_channels, height, width // 2 + 1,
                                                    dtype=torch.cfloat, device=x.device)
                                    
                                    # Multiply low frequencies (top-left corner in frequency space)
                                    out_ft[:, :, :self.modes1, :self.modes2] = torch.einsum(
                                        "bixy,ioxy->boxy",
                                        x_ft[:, :, :self.modes1, :self.modes2],
                                        self.weights1)
                                    
                                    # Multiply high frequencies (bottom-left corner due to periodicity)
                                    out_ft[:, :, -self.modes1:, :self.modes2] = torch.einsum(
                                        "bixy,ioxy->boxy",
                                        x_ft[:, :, -self.modes1:, :self.modes2],
                                        self.weights2)
                                    
                                    # Inverse FFT to return to physical space
                                    x = torch.fft.irfft2(out_ft, s=(height, width), norm='ortho')
                                    return x

                            # Example: FNO architecture for solving 2D PDEs
                            class FNO2d(nn.Module):
                                """Fourier Neural Operator for 2D PDEs (e.g., Navier-Stokes)"""
                                def __init__(self, modes=12, width=64, in_channels=3, out_channels=1):
                                    super().__init__()
                                    self.modes = modes
                                    self.width = width
                                    
                                    # Lifting layer (point-wise)
                                    self.fc0 = nn.Linear(in_channels, width)
                                    
                                    # Fourier layers
                                    self.conv0 = SpectralConv2d(width, width, modes, modes)
                                    self.conv1 = SpectralConv2d(width, width, modes, modes)
                                    self.conv2 = SpectralConv2d(width, width, modes, modes)
                                    self.conv3 = SpectralConv2d(width, width, modes, modes)
                                    
                                    # Regular convolutions for local features
                                    self.w0 = nn.Conv2d(width, width, 1)
                                    self.w1 = nn.Conv2d(width, width, 1)
                                    self.w2 = nn.Conv2d(width, width, 1)
                                    self.w3 = nn.Conv2d(width, width, 1)
                                    
                                    # Projection layer
                                    self.fc1 = nn.Linear(width, 128)
                                    self.fc2 = nn.Linear(128, out_channels)
                                    
                                def forward(self, x):
                                    # x shape: (batch, height, width, channels)
                                    # Lift input channels to the hidden width, then put channels first
                                    x = self.fc0(x).permute(0, 3, 1, 2)  # -> (batch, width, H, W)
                                    
                                    # Fourier layers with residual connections
                                    x1 = self.conv0(x)
                                    x2 = self.w0(x)
                                    x = torch.nn.functional.gelu(x1 + x2)
                                    
                                    x1 = self.conv1(x)
                                    x2 = self.w1(x)
                                    x = torch.nn.functional.gelu(x1 + x2)
                                    
                                    x1 = self.conv2(x)
                                    x2 = self.w2(x)
                                    x = torch.nn.functional.gelu(x1 + x2)
                                    
                                    x1 = self.conv3(x)
                                    x2 = self.w3(x)
                                    x = x1 + x2
                                    
                                    # Project to output
                                    x = x.permute(0, 2, 3, 1)  # -> (batch, height, width, channels)
                                    x = self.fc1(x)
                                    x = torch.nn.functional.gelu(x)
                                    x = self.fc2(x)
                                    return x
                        

2. State Space Models and Long-Range Sequence Modeling:
Recent breakthroughs like S4 (Structured State Spaces) by Gu et al. (ICLR 2022) use FFT-based long convolutions to model sequences with very long contexts efficiently; successors such as Mamba (Gu & Dao, 2023) build on the same state-space framework but replace the convolutional mode with a recurrent selective scan. In convolutional form, these models avoid the \(O(N^2)\) cost of attention, computing each layer in \(O(N \log N)\).

Key Insight: Long convolutions via FFT enable efficient computation of state space models: \[ y = K * u \quad \text{where} \quad K \in \mathbb{R}^L \text{ is a learned kernel} \]
                            import math

                            import torch
                            import torch.nn as nn

                            def fft_conv(u, k):
                                """Causal depthwise convolution u * k via FFT, the core trick
                                in S4-style long-convolution sequence models.

                                Args:
                                    u: Input sequence of shape (batch, length, channels)
                                    k: Per-channel kernel of shape (channels, length)
                                Returns:
                                    y: Output of shape (batch, length, channels)
                                """
                                L = u.shape[1]
                                # Zero-pad to 2L so circular convolution equals linear convolution
                                k_f = torch.fft.rfft(k, n=2 * L, dim=-1)      # (channels, L + 1)
                                u_f = torch.fft.rfft(u, n=2 * L, dim=1)       # (batch, L + 1, channels)
                                # Pointwise multiply in the frequency domain (convolution theorem)
                                y_f = u_f * k_f.transpose(0, 1).unsqueeze(0)
                                # Back to the time domain; keep the first L (causal) outputs
                                return torch.fft.irfft(y_f, n=2 * L, dim=1)[:, :L, :]

                            class S4DLayer(nn.Module):
                                """Simplified diagonal state space layer in the spirit of S4D.
                                The full S4 uses a structured (HiPPO-derived) state matrix; a
                                diagonal A keeps the kernel computation simple while preserving
                                the key idea: materialize K[t] = C exp(t*dt*A) B in closed form,
                                then apply it as one long convolution via the FFT."""
                                def __init__(self, d_model, d_state=64):
                                    super().__init__()
                                    self.d_model, self.d_state = d_model, d_state
                                    # S4D-Lin initialization: eigenvalues -1/2 + i*pi*n
                                    self.log_A_real = nn.Parameter(
                                        torch.full((d_model, d_state), math.log(0.5)))
                                    self.A_imag = nn.Parameter(
                                        math.pi * torch.arange(d_state).float().expand(d_model, d_state).clone())
                                    # C absorbs B (only their product enters the kernel)
                                    self.C = nn.Parameter(torch.randn(d_model, d_state, 2) * 0.5 ** 0.5)
                                    self.D = nn.Parameter(torch.randn(d_model))  # skip connection
                                    # Per-channel step size, log-uniform in [1e-3, 1e-1]
                                    self.log_dt = nn.Parameter(
                                        math.log(0.001) + (math.log(0.1) - math.log(0.001)) * torch.rand(d_model))

                                def kernel(self, L):
                                    """Materialize the length-L convolution kernel for every channel."""
                                    dt = torch.exp(self.log_dt).unsqueeze(-1)            # (d_model, 1)
                                    A = -torch.exp(self.log_A_real) + 1j * self.A_imag   # (d_model, d_state)
                                    C = torch.view_as_complex(self.C)                    # (d_model, d_state)
                                    dtA = dt * A
                                    # ZOH-style discretization: K[c, t] = sum_n C'[c, n] exp(dtA[c, n] t)
                                    C_disc = C * (torch.exp(dtA) - 1.0) / A
                                    t = torch.arange(L, device=dtA.device)
                                    K = torch.einsum('cn,cnt->ct', C_disc, torch.exp(dtA.unsqueeze(-1) * t))
                                    # Factor 2 accounts for the implicit complex-conjugate modes
                                    return 2 * K.real                                    # (d_model, L)

                                def forward(self, u):
                                    """u: (batch, length, d_model) -> y: (batch, length, d_model)"""
                                    K = self.kernel(u.shape[1])
                                    y = fft_conv(u, K)       # long convolution in O(L log L)
                                    return y + self.D * u    # skip connection
                        

3. Audio and Speech Processing Foundation Models:
Modern speech models rely on Fourier analysis in their front ends: Whisper (OpenAI, 2022) consumes log-mel spectrograms computed via the short-time Fourier transform, while Wav2Vec 2.0 (Meta, 2020) instead learns features directly from the raw waveform. The extractor below follows Whisper's published configuration:
                            import torch
                            import torchaudio
                            import torchaudio.transforms as T

                            class WhisperFeatureExtractor:
                                """Feature extraction pipeline used in OpenAI's Whisper model"""
                                def __init__(self, 
                                            sample_rate=16000,
                                            n_fft=400,
                                            hop_length=160,
                                            n_mels=80,
                                            chunk_length=30):
                                    self.sample_rate = sample_rate
                                    self.n_fft = n_fft
                                    self.hop_length = hop_length
                                    self.chunk_length = chunk_length
                                    self.n_samples = chunk_length * sample_rate
                                    
                                    # Mel-spectrogram transform (matches Whisper's configuration)
                                    self.mel_filters = T.MelSpectrogram(
                                        sample_rate=sample_rate,
                                        n_fft=n_fft,
                                        hop_length=hop_length,
                                        n_mels=n_mels,
                                        f_min=0,
                                        f_max=8000,
                                        mel_scale='slaney',  # Slaney-style mel filters, as in Whisper (librosa default)
                                        norm='slaney',
                                        window_fn=torch.hann_window
                                    )
                                    
                                def __call__(self, waveform):
                                    """
                                    Convert waveform to log-mel spectrogram features
                                    
                                    Args:
                                        waveform: Audio tensor of shape (channels, samples)
                                    Returns:
                                        features: Log-mel spectrogram of shape (n_mels, time_frames)
                                    """
                                    # Ensure mono audio
                                    if waveform.shape[0] > 1:
                                        waveform = waveform.mean(dim=0, keepdim=True)
                                    
                                    # Pad or trim to 30 seconds
                                    if waveform.shape[1] < self.n_samples:
                                        waveform = torch.nn.functional.pad(
                                            waveform, (0, self.n_samples - waveform.shape[1])
                                        )
                                    else:
                                        waveform = waveform[:, :self.n_samples]
                                    
                                    # Compute mel-spectrogram
                                    mel_spec = self.mel_filters(waveform)
                                    
                                    # Convert to log scale (matching Whisper's normalization)
                                    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
                                    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
                                    log_spec = (log_spec + 4.0) / 4.0
                                    
                                    return log_spec

                            # Example usage
                            feature_extractor = WhisperFeatureExtractor()
                            # audio, sr = torchaudio.load('speech.wav')
                            # if sr != 16000:
                            #     resampler = T.Resample(sr, 16000)
                            #     audio = resampler(audio)
                            # features = feature_extractor(audio)
                            # Now feed 'features' to Whisper encoder
                        

4. Vision Transformers and Frequency-Aware Architectures:
Recent vision models leverage frequency domain insights for improved efficiency and performance:
                            import torch
                            import torch.nn as nn

                            class FNetBlock(nn.Module):
                                """FNet block: FFT mixing replaces self-attention (Lee-Thorp et al., 2021)"""
                                def __init__(self, d_model, d_ff=None, dropout=0.1):
                                    super().__init__()
                                    d_ff = d_ff or 4 * d_model
                                    
                                    self.mixing_layer_norm = nn.LayerNorm(d_model)
                                    self.feed_forward = nn.Sequential(
                                        nn.LayerNorm(d_model),
                                        nn.Linear(d_model, d_ff),
                                        nn.GELU(),
                                        nn.Dropout(dropout),
                                        nn.Linear(d_ff, d_model),
                                        nn.Dropout(dropout)
                                    )
                                    
                                def forward(self, x):
                                    """
                                    Args:
                                        x: Input tensor of shape (batch, seq_len, d_model)
                                    """
                                    # FFT mixing on sequence and feature dimensions
                                    residual = x
                                    x = self.mixing_layer_norm(x)
                                    
                                    # Apply 2D FFT (real-valued FFT for efficiency)
                                    # FFT on sequence dimension
                                    x_complex = torch.fft.fft(x, dim=1, norm='ortho')
                                    # FFT on feature dimension  
                                    x_complex = torch.fft.fft(x_complex, dim=2, norm='ortho')
                                    # Keep only real part (as in the paper)
                                    x = x_complex.real
                                    
                                    x = residual + x
                                    
                                    # Feed-forward network
                                    x = x + self.feed_forward(x)
                                    
                                    return x

                            class FocalFrequencyLoss(nn.Module):
                                """Focal Frequency Loss for image reconstruction (Jiang et al., 2021)"""
                                def __init__(self, loss_weight=1.0, alpha=1.0, patch_factor=1):
                                    super().__init__()
                                    self.loss_weight = loss_weight
                                    self.alpha = alpha
                                    self.patch_factor = patch_factor
                                    
                                def forward(self, pred, target):
                                    """
                                    Args:
                                        pred: Predicted image (B, C, H, W)
                                        target: Target image (B, C, H, W)
                                    """
                                    # Compute 2D FFT and shift DC to the center so that distance
                                    # from the center corresponds to frequency magnitude
                                    pred_fft = torch.fft.fftshift(torch.fft.fft2(pred, norm='ortho'), dim=(-2, -1))
                                    target_fft = torch.fft.fftshift(torch.fft.fft2(target, norm='ortho'), dim=(-2, -1))
                                    
                                    # Compute frequency distance
                                    freq_distance = torch.abs(pred_fft - target_fft)
                                    
                                    # Frequency weights emphasizing high frequencies (vectorized)
                                    B, C, H, W = pred.shape
                                    h_half, w_half = H // 2, W // 2
                                    ys = torch.arange(H, device=pred.device).float().view(1, 1, H, 1) - h_half
                                    xs = torch.arange(W, device=pred.device).float().view(1, 1, 1, W) - w_half
                                    dist = torch.sqrt(ys**2 + xs**2)
                                    weight_matrix = 1 + self.alpha * dist / h_half   # broadcasts over (B, C, H, W)
                                    
                                    # Weighted frequency loss
                                    loss = torch.mean(weight_matrix * freq_distance)
                                    
                                    return self.loss_weight * loss
                        

5. Neural Fields and Positional Encodings:
Fourier features enable neural networks to learn high-frequency functions:
                            import torch
                            import torch.nn as nn
                            import numpy as np

                            class FourierFeatures(nn.Module):
                                """Fourier feature mapping for positional encoding (Tancik et al., 2020)"""
                                def __init__(self, input_dim, mapping_size=256, scale=10.0):
                                    super().__init__()
                                    self.input_dim = input_dim
                                    self.output_dim = 2 * mapping_size
                                    
                                    # Random Fourier feature mapping
                                    # B ~ N(0, scale^2)
                                    self.B = nn.Parameter(
                                        torch.randn(input_dim, mapping_size) * scale,
                                        requires_grad=False
                                    )
                                    
                                def forward(self, x):
                                    """
                                    Args:
                                        x: Coordinates of shape (..., input_dim)
                                    Returns:
                                        Fourier features of shape (..., 2*mapping_size)
                                    """
                                    x_proj = 2 * np.pi * x @ self.B
                                    return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)

                            class NeRF(nn.Module):
                                """Simplified NeRF with Fourier positional encoding"""
                                def __init__(self, 
                                            pos_dim=3,
                                            view_dim=3,
                                            pos_encoding_freq=10,
                                            view_encoding_freq=4,
                                            hidden_dim=256):
                                    super().__init__()
                                    
                                    # Positional encoding dimensions
                                    pos_enc_dim = pos_dim * (2 * pos_encoding_freq + 1)
                                    view_enc_dim = view_dim * (2 * view_encoding_freq + 1)
                                    
                                    # Position encoding frequencies (following NeRF paper)
                                    self.pos_freq_bands = 2.**torch.linspace(0, pos_encoding_freq-1, pos_encoding_freq)
                                    self.view_freq_bands = 2.**torch.linspace(0, view_encoding_freq-1, view_encoding_freq)
                                    
                                    # MLP architecture (simplified)
                                    self.mlp1 = nn.Sequential(
                                        nn.Linear(pos_enc_dim, hidden_dim),
                                        nn.ReLU(),
                                        nn.Linear(hidden_dim, hidden_dim),
                                        nn.ReLU(),
                                        nn.Linear(hidden_dim, hidden_dim),
                                        nn.ReLU(),
                                        nn.Linear(hidden_dim, hidden_dim),
                                        nn.ReLU(),
                                    )
                                    
                                    self.density_head = nn.Linear(hidden_dim, 1)
                                    
                                    self.mlp2 = nn.Sequential(
                                        nn.Linear(hidden_dim + view_enc_dim, hidden_dim // 2),
                                        nn.ReLU(),
                                    )
                                    
                                    self.rgb_head = nn.Sequential(
                                        nn.Linear(hidden_dim // 2, 3),
                                        nn.Sigmoid()
                                    )
                                    
                                def positional_encoding(self, x, freq_bands):
                                    """Apply sinusoidal positional encoding"""
                                    out = [x]
                                    for freq in freq_bands:
                                        out.append(torch.sin(freq * x))
                                        out.append(torch.cos(freq * x))
                                    return torch.cat(out, dim=-1)
                                
                                def forward(self, pos, view_dir):
                                    """
                                    Args:
                                        pos: 3D positions (batch, 3)
                                        view_dir: Viewing directions (batch, 3)
                                    Returns:
                                        rgb: Colors (batch, 3)
                                        density: Volume density (batch, 1)
                                    """
                                    # Encode position with high-frequency features
                                    pos_enc = self.positional_encoding(pos, self.pos_freq_bands)
                                    
                                    # First part of network (position only)
                                    h = self.mlp1(pos_enc)
                                    density = self.density_head(h)
                                    
                                    # Encode viewing direction
                                    view_enc = self.positional_encoding(view_dir, self.view_freq_bands)
                                    
                                    # Second part (position features + viewing direction)
                                    h = self.mlp2(torch.cat([h, view_enc], dim=-1))
                                    rgb = self.rgb_head(h)
                                    
                                    return rgb, torch.relu(density)

                            # Rotary Position Embedding (RoPE) used in LLaMA, Mistral, etc.
                            class RotaryPositionalEmbedding(nn.Module):
                                """RoPE: Fourier-based positional encoding for transformers (Su et al., 2021)"""
                                def __init__(self, dim, max_seq_len=2048, base=10000):
                                    super().__init__()
                                    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
                                    self.register_buffer('inv_freq', inv_freq)
                                    
                                    # Precompute sin/cos for efficiency
                                    t = torch.arange(max_seq_len).float()
                                    freqs = torch.einsum('i,j->ij', t, self.inv_freq)
                                    self.register_buffer('cos_cached', torch.cos(freqs))
                                    self.register_buffer('sin_cached', torch.sin(freqs))
                                    
                                def forward(self, q, k, seq_len):
                                    """
                                    Apply rotary embeddings to query and key tensors
                                    
                                    Args:
                                        q, k: Tensors of shape (batch, seq_len, n_heads, head_dim)
                                        seq_len: Current sequence length
                                    """
                                    cos = self.cos_cached[:seq_len, :].unsqueeze(0).unsqueeze(2)
                                    sin = self.sin_cached[:seq_len, :].unsqueeze(0).unsqueeze(2)
                                    
                                    # Apply rotation using complex number multiplication
                                    q_rot = self.apply_rotation(q, cos, sin)
                                    k_rot = self.apply_rotation(k, cos, sin)
                                    
                                    return q_rot, k_rot
                                
                                def apply_rotation(self, x, cos, sin):
                                    """Apply the rotation to x"""
                                    # Split last dimension into pairs
                                    x1, x2 = x[..., ::2], x[..., 1::2]
                                    
                                    # Apply rotation matrix
                                    # [cos, -sin] [x1]
                                    # [sin,  cos] [x2]
                                    x_rot = torch.zeros_like(x)
                                    x_rot[..., ::2] = x1 * cos - x2 * sin
                                    x_rot[..., 1::2] = x1 * sin + x2 * cos
                                    
                                    return x_rot
                        


These applications demonstrate that Fourier methods are not just theoretical tools but are actively driving state-of-the-art performance in modern ML systems. From enabling million-token context windows to solving PDEs at unprecedented speeds, Fourier transforms remain essential for pushing the boundaries of what neural networks can achieve.

Interactive Demos

Now that we've covered the theory, let's see the Fourier Transform in action. We will start with a simple 1D signal to understand the core principles, then move to a 2D image to see a practical application (and a great example of the Convolution Theorem).

1D Signal & Frequency Spectrum

This demo shows how several sine waves (+ noise) are combined in the Time Domain to create a complex signal. Simultaneously, it computes the Fourier Transform (FFT) of that combined signal to show which frequencies are dominant in the Frequency Domain (the spectrum). Move the sliders to see how the two domains are linked in real-time. This is the fundamental concept of Fourier Analysis.

2D Image & Frequency Domain Filtering

An image is just a 2D signal. This demo computes the 2D FFT of an image to show its Frequency Spectrum. The center of the spectrum represents low frequencies (blurry parts of the image), while the edges represent high frequencies (edges and details). As you learned in the "Convolution Theorem" section, multiplying in the frequency domain is equivalent to convolution in the spatial domain. Try applying a Low-pass filter (which only keeps the center) to see the image blur, or a High-pass filter (which only keeps the edges) to extract details.
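
For readers who want to reproduce the 2D demo offline, here is a minimal sketch using a synthetic image (no image file is assumed; the white square and the radius-10 cutoff are arbitrary choices):

                            import numpy as np

                            img = np.zeros((128, 128))
                            img[32:96, 32:96] = 1.0                          # white square: sharp edges

                            spectrum = np.fft.fftshift(np.fft.fft2(img))     # move DC to the center

                            # Ideal low-pass: keep a disk of radius 10 around the center
                            yy, xx = np.mgrid[:128, :128]
                            mask = (yy - 64)**2 + (xx - 64)**2 <= 10**2
                            low = np.fft.ifft2(np.fft.ifftshift(spectrum * mask)).real    # blurred image
                            high = np.fft.ifft2(np.fft.ifftshift(spectrum * ~mask)).real  # edges / details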