Thursday, January 17, 2013

Scientific computing: Fast Fourier transform

Thing practiced: programming

Tools used: Visual Studio 2012 on a Windows machine

(Quick background: The Fourier transform lets us express an irregular signal as a combination of pure sine waves. The numerical version of this is the discrete Fourier transform (DFT), which is ubiquitous today – it and its close relatives underlie JPEGs and MP3s, among other things – thanks to the efficient fast Fourier transform (FFT) family of algorithms.)

Let’s try to implement the FFT in F#. First, here’s the naive DFT algorithm (a direct translation to code of the definition):

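open System
open System.Numerics

// Naive DFT: each output term is a sum over all N input terms.
// The exponent sign is positive here; many references define the
// forward DFT with a negative exponent instead.
let slow_dft (xs : Complex[]) =
    let N = Array.length xs
    [| for k in 0 .. N - 1 ->
           let mutable x_k = Complex.Zero
           for n = 0 to N - 1 do
               let a = 2.0 * Math.PI / float N *
                       float k * float n
               x_k <- x_k + xs.[n] *
                      Complex((cos a), (sin a))
           x_k |]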
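For reference, the definition being translated can be written as below. The symbols are chosen to match the code (x_n for xs.[n], X_k for x_k), and the exponent sign follows the code; many texts put a minus sign in the exponent instead, so treat this as one convention rather than the only one.

```latex
X_k = \sum_{n=0}^{N-1} x_n \, e^{2\pi i \, k n / N}, \qquad k = 0, 1, \dots, N-1
```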

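This gives an exact answer, but with poor performance: it has O(N²) time complexity. To improve performance, we first narrow the problem by assuming that the number of input terms is an integral power of 2, which lets us use the radix-2 decimation-in-time (DIT) divide-and-conquer algorithm. This splits the DFT into two DFTs, one of the even-numbered terms and one of the odd-numbered terms (with x_n the inputs, X_k the outputs, and the same exponent sign as the code above):

$$X_k = \sum_{m=0}^{N/2-1} x_{2m} \, e^{2\pi i (2m) k / N} + \sum_{m=0}^{N/2-1} x_{2m+1} \, e^{2\pi i (2m+1) k / N}$$

which gets rearranged as:

$$X_k = E_k + e^{2\pi i k / N} \, O_k, \qquad X_{k+N/2} = E_k - e^{2\pi i k / N} \, O_k, \qquad 0 \le k < N/2$$

where E_k and O_k are the length-N/2 DFTs of the even-numbered and odd-numbered terms.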

This article explains further. Code:

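// In-place radix-2 DIT FFT. Assumes the length of xs is a power of 2.
// Overwrites xs with its transform and returns unit.
let radix2_fft (xs : Complex[]) =
    // Reorder arr into bit-reversed index order, so that the butterflies
    // below can merge adjacent sub-transforms in place.
    let bit_rev arr =
        let N = Array.length arr
        let mutable j = 0
        let swap (a : _[]) i j =
            let tmp = a.[j]
            a.[j] <- a.[i]
            a.[i] <- tmp
        // Advance j as a bit-reversed counter: flip bits from the top
        // down, stopping at the first bit that flips from 0 to 1.
        let rec helper m j =
            let m = m / 2
            let j = j ^^^ m
            if j &&& m = 0 then helper m j else j
        for i = 0 to N - 2 do
            if i < j then swap arr i j
            j <- helper N j
    // One group of butterflies: combines the pairs (i, i + j) that share
    // the twiddle factor cos t + i sin t, with t = pi * m / j.
    let radix2_helper (xs : Complex[]) N j m =
        let t = Math.PI * float m / float j
        let a = Complex((cos t), (sin t))
        let mutable i = m
        while i < N do
            let x_i = xs.[i]
            let t = a * xs.[i+j]
            xs.[i] <- x_i + t
            xs.[i+j] <- x_i - t
            i <- i + 2 * j
    let N = Array.length xs
    bit_rev xs
    // j is the half-size of the sub-transforms being merged at each stage.
    let mutable j = 1
    while j < N do
        for m = 0 to j - 1 do
            radix2_helper xs N j m
        j <- 2 * j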
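The even/odd split can also be sketched recursively, which may be easier to read next to the description of the algorithm than the in-place version above. This is only a minimal sketch: the name recursive_fft is made up for illustration, it allocates new arrays at every level instead of working in place, it assumes a power-of-2 length, and it keeps the same positive exponent sign as the rest of the code here.

```fsharp
open System
open System.Numerics

// Hypothetical name; returns a new array rather than transforming in place.
// Assumes xs.Length is a power of 2.
let rec recursive_fft (xs : Complex[]) : Complex[] =
    let n = xs.Length
    if n <= 1 then Array.copy xs
    else
        // Transform the even-indexed and odd-indexed terms separately.
        let evens = recursive_fft [| for i in 0 .. 2 .. n - 1 -> xs.[i] |]
        let odds = recursive_fft [| for i in 1 .. 2 .. n - 1 -> xs.[i] |]
        let out : Complex[] = Array.zeroCreate n
        for k in 0 .. n / 2 - 1 do
            // Twiddle factor for output k (same sign convention as slow_dft).
            let a = 2.0 * Math.PI * float k / float n
            let t = Complex(cos a, sin a) * odds.[k]
            out.[k] <- evens.[k] + t
            out.[k + n / 2] <- evens.[k] - t
        out
```

Each level does O(N) work and there are log2 N levels, which is where the N log N comes from.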

The radix-2 algorithm is fast, O(N log N), but only handles signals whose lengths are integral powers of 2. To use it for a signal of arbitrary length, we need to increase the length to a larger power of 2. We can’t simply zero-pad the signal itself, since that changes the transform being computed, but if we express the transform as a convolution and then zero-pad the convolution, we’ll get an exact answer (an approach usually known as Bluestein’s algorithm). (My hour’s up, but if you’re interested the convolution algorithm is explained here.)
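Here is a rough sketch of how the convolution idea could look, written as Bluestein's algorithm; it is not necessarily the formulation in the linked explanation. All names (fft, next_pow2, bluestein_dft) are made up for illustration. The small recursive fft helper is included only to keep the sketch self-contained, and takes a sign argument so that it can also run the unnormalized inverse. The exponent sign again follows the code in this post.

```fsharp
open System
open System.Numerics

// Compact recursive radix-2 FFT (length must be a power of 2).
// sign = 1.0 matches this post's convention; sign = -1.0 gives the
// unnormalized inverse.
let rec fft (sign : float) (xs : Complex[]) : Complex[] =
    let n = xs.Length
    if n <= 1 then Array.copy xs
    else
        let evens = fft sign [| for i in 0 .. 2 .. n - 1 -> xs.[i] |]
        let odds = fft sign [| for i in 1 .. 2 .. n - 1 -> xs.[i] |]
        let out : Complex[] = Array.zeroCreate n
        for k in 0 .. n / 2 - 1 do
            let a = sign * 2.0 * Math.PI * float k / float n
            let t = Complex(cos a, sin a) * odds.[k]
            out.[k] <- evens.[k] + t
            out.[k + n / 2] <- evens.[k] - t
        out

// Smallest power of 2 (starting from p) that is >= target.
let rec next_pow2 p target =
    if p >= target then p else next_pow2 (2 * p) target

// Hypothetical name: DFT of arbitrary length via a zero-padded convolution.
let bluestein_dft (xs : Complex[]) : Complex[] =
    let n = xs.Length
    // Chirp w_k = exp(i * pi * k^2 / n); k^2 is reduced mod 2n for accuracy.
    let w =
        [| for k in 0 .. n - 1 ->
             let k2 = (int64 k * int64 k) % int64 (2 * n)
             let phase = Math.PI * float k2 / float n
             Complex(cos phase, sin phase) |]
    // The convolution needs room for 2n - 1 terms, rounded up to a power of 2.
    let m = next_pow2 1 (2 * n - 1)
    let a = Array.create m Complex.Zero
    let b = Array.create m Complex.Zero
    for k in 0 .. n - 1 do
        a.[k] <- xs.[k] * w.[k]
        b.[k] <- Complex.Conjugate(w.[k])
    for k in 1 .. n - 1 do
        b.[m - k] <- Complex.Conjugate(w.[k])   // negative indices wrap around
    // Circular convolution of a and b via the convolution theorem.
    let fa = fft 1.0 a
    let fb = fft 1.0 b
    let prod = Array.map2 (fun (p : Complex) (q : Complex) -> p * q) fa fb
    let c = fft (-1.0) prod
    let scale = Complex(float m, 0.0)
    [| for k in 0 .. n - 1 -> w.[k] * c.[k] / scale |]
```

The trick is the identity 2nk = n² + k² − (k − n)², which turns the DFT sum into a convolution with a "chirp" sequence. For a length-5 input, for example, the padded length works out to 16.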