// FFT3

 

/**
 * In-place radix-2 butterfly on two complex points u and t, using the
 * complex constant W = w_re + i*w_im (a root of unity):
 *
 *   u' = u + W*t
 *   t' = u - W*t
 *
 * Both results are computed from the ORIGINAL u.
 *
 * Each firing pops u_re, u_im, t_re, t_im and pushes
 * u'_re, u'_im, t'_re, t'_im, in that order.
 */
float->float filter Butterfly(float w_re, float w_im) {

    init {
    }

    work push 4 pop 4 {
        /* upper point (u) arrives first, then lower point (t); real before imaginary */
        float top_re = pop();
        float top_im = pop();
        float bot_re = pop();
        float bot_im = pop();

        /* complex product W * t */
        float prod_re = w_re * bot_re - w_im * bot_im;
        float prod_im = w_re * bot_im + w_im * bot_re;

        /* u' = u + W*t */
        push(top_re + prod_re);
        push(top_im + prod_im);

        /* t' = u - W*t (top_* is still the untouched input u) */
        push(top_re - prod_re);
        push(top_im - prod_im);
    }
}

 

 

 

/**
 * A butterfly group of an FFT stage: numbflies butterflies that all use
 * the same root of unity W = (w_re, w_im). In the FFT's butterfly graph
 * these are the butterflies clustered together within one stage.
 *
 * The splitter deals out one complex point (2 floats) per branch in
 * round-robin order, and each Butterfly consumes two points per firing,
 * so branch k pairs point k with point k + numbflies of every block of
 * 2*numbflies points. The joiner collects results one point per branch
 * in the same round-robin order.
 */
float->float splitjoin ButterflyGroup(float w_re, float w_im, int numbflies) {
    split roundrobin(2);
    for (int k = 0; k < numbflies; k++) {
        add Butterfly(w_re, w_im);
    }
    join roundrobin(2);
}

 

/**
 * ComputeStage is a set of butterfly groups and implements one
 * particular FFT stage, selected by D (the number of butterflies in
 * each group of that stage).
 *
 * D     - butterflies per group
 * N     - number of complex points in the transform
 * W_re3 - real parts of the root-of-unity lookup table
 * W_im3 - imaginary parts of the root-of-unity lookup table
 */
float->float splitjoin ComputeStage(int D, int N, float[N/2] W_re3, float[N/2] W_im3) {
    /* a group of D butterflies spans 2*D complex points */
    int pointsPerGroup = 2 * D;
    /* so the N points of this stage fall into this many groups */
    int groupCount = N / pointsPerGroup;

    /* every complex point is two floats (re, im) */
    split roundrobin(2 * pointsPerGroup);
    for (int g = 0; g < groupCount; g++) {
        /*
         * Each butterfly group uses a single root of unity. Conceptually
         * that is the entry at index bitrev(g, logn-1): g is bit-reversed
         * as a (logn-1)-bit number because one table of n-th roots is
         * shared by all stages, with the cancellation lemma scaling the
         * n-th roots to n/2-th, n/4-th, ... roots.
         *
         * Plain g is used as the index here because the lookup arrays
         * are expected to be stored already bit-reversal permuted.
         */
        add ButterflyGroup(W_re3[g], W_im3[g], D);
    }
    join roundrobin(2 * pointsPerGroup);
}

   

 

/**
 * Same as ComputeStage, but only for the last FFT stage, where there are
 * N/2 butterfly groups with ONLY one butterfly each. The butterfly group
 * therefore degenerates from a splitjoin to a single Butterfly filter.
 *
 * D     - butterflies per group; the visible caller always passes 1, and
 *         adding a bare Butterfly per group is only meaningful for D = 1
 * N     - number of complex points in the transform
 * W_re5 - real parts of the root-of-unity lookup table
 * W_im5 - imaginary parts of the root-of-unity lookup table
 */
float->float splitjoin LastComputeStage(int D, int N, float[N/2] W_re5, float[N/2] W_im5) {
    /* the length of a butterfly group in terms of points */
    int grplen = D + D;
    /* the number of butterfly groups in this stage */
    int numgrps = N/grplen;

    /* each complex point is two floats (re, im) */
    split roundrobin(2*grplen);

    /* for each butterfly group in this stage */
    for (int g = 0; g < numgrps; g++) {
        /*
         * The table is indexed with the plain group number g; this relies
         * on the lookup arrays being stored bit-reversal permuted.
         * A butterfly group of the last stage is simply one butterfly.
         */
        add Butterfly(W_re5[g], W_im5[g]);
    }

    join roundrobin(2*grplen);
}

 

 

/**
 * Permutes each block of N complex points into bit-reversed order:
 * output point i is input point bitrev(i, logN).
 *
 * N    - number of complex points per block
 * logN - number of index bits to reverse
 *
 * One firing handles a whole block, so the push, pop and peek rates are
 * all 2*N floats (real and imaginary part of every point).
 */
float->float filter BitReverse(int N, int logN) {

    init {}

    work push N*2 pop N*2 peek N*2 {
        /* emit the points in permuted order without consuming them yet */
        for (int dst = 0; dst < N; dst++) {
            int src = bitrev(dst, logN);
            /* real and imaginary part of one point travel together */
            push(peek(2*src));
            push(peek(2*src+1));
        }

        /* now discard the whole input block: 2 floats for each of N points */
        for (int k = 0; k < 2*N; k++) {
            pop();
        }
    }

    /**
     * Returns the lowest numbits bits of inp in reversed bit order.
     */
    int bitrev(int inp, int numbits) {
        int rest = inp;
        int rev = 0;
        for (int b = 0; b < numbits; b++) {
            /* shift the result left and append the current lowest bit */
            rev = (rev * 2) | (rest & 1);
            rest = rest / 2;
        }
        return rev;
    }
}

 

 

/**
 * The top-level stream construct of the FFT kernel: a pipeline of
 * butterfly stages followed by a bit-reversal permutation.
 *
 * N     - number of complex points per transform
 * logN  - number of index bits, passed on to BitReverse
 * W_re2 - real parts of the root-of-unity lookup table
 * W_im2 - imaginary parts of the root-of-unity lookup table
 */
float->float pipeline FFT3Kernel(int N, int logN, float[N/2] W_re2, float[N/2] W_im2) {

    /* the first stage - one butterfly group of N/2 butterflies, all using table entry 0 */
    add ButterflyGroup(W_re2[0], W_im2[0], N/2);

    /* the middle stages - the group size starts at N/4 and is halved while it stays above 1 */
    for (int bflies = N/4; bflies > 1; bflies = bflies/2) {
        add ComputeStage(bflies, N, W_re2, W_im2);
    }

    /* the last stage - N/2 butterfly groups with 1 butterfly each */
    add LastComputeStage(1, N, W_re2, W_im2);

    /* the bit-reversal phase */
    add BitReverse(N, logN);
}

 

 

/**
 * Test source: cycles endlessly through a fixed block of N complex
 * points, sending out the real part followed by the imaginary part of
 * one point per firing. The block is all zeros except for a real 1.0 at
 * index 1.
 */
void->float filter FloatSource(int N) {

    float[N] A_re;
    float[N] A_im;
    int idx;    /* index of the next point to send */

    init {
        idx = 0;
        for (int k = 0; k < N; k++) {
            A_re[k] = 0.0;
            A_im[k] = 0.0;
        }
        /* single non-zero sample */
        A_re[1] = 1.0;
    }

    work push 2 {
        push(A_re[idx]);
        push(A_im[idx]);

        /* advance and wrap around at the end of the block */
        idx = idx + 1;
        if (idx >= N) {
            idx = 0;
        }
    }
}

 

 

/**
 * Sink: prints the real and then the imaginary part of each output point.
 */
float->void filter FloatPrinter {

    init {}

    work pop 2 {
        float re = pop();
        float im = pop();
        print(re);
        print(im);
    }
}

 

/**
 * Top-level program: builds the root-of-unity lookup tables for a
 * 32-point transform and connects source -> FFT kernel -> printer.
 */
void->void pipeline FFT3 {

    int N = 32;
    int logN = 5;
    float[N/2] W_re1;
    float[N/2] W_im1;

    /*
     * Fill the tables with cos/sin of 2*pi*i/N for i = 0 .. N/2-1, storing
     * entry i at the (logN-1)-bit bit-reversal of i, so the tables are
     * bit-reversal permuted.
     */
    for (int i = 0; i < (N/2); i++) {
        /* bitrev(i, logN-1), written out inline to help constant propagation */
        int slot = 0;
        int rest = i;
        for (int bit = 0; bit < (logN-1); bit++) {
            slot = (slot * 2) | (rest & 1);
            rest = rest / 2;
        }

        W_re1[slot] = cos((i*2.0*pi)/N);
        W_im1[slot] = sin((i*2.0*pi)/N);
    }

    add FloatSource(N);
    add FFT3Kernel(N, logN, W_re1, W_im1);
    add FloatPrinter();
}