import Vertex3D;

/**
 * A 4x4 single-precision matrix for homogeneous 3D transforms.
 *
 * Elements are stored row-major in a flat array of 16 floats: element
 * (row, col) lives at index 4*row + col. The modelling helpers
 * (translate, scale, rotate, lookAt, perspective, orthographic, compose)
 * all post-multiply, i.e. they replace this matrix M with M * T.
 */
public class Matrix3D {
    /** Row-major storage: element (row, col) is m[4*row + col]. */
    private float[] m;

    /**
     * Creates an identity matrix.
     * The no-argument constructor also allows for extension by subclasses.
     */
    public Matrix3D()
    {
        m = new float[16];
        loadIdentity();
    }

    // Disabled: viewport constructor, depends on the ZRaster type.
//    public Matrix3D(ZRaster r)
//    {
//        m = new float[16];
//        float w = r.width / 2;
//        float h = r.height / 2;
//        float d = ZRaster.MAXZ / 2;
//        m[ 0] = w;  m[ 1] = 0;  m[ 2] = 0;  m[ 3] = w;
//        m[ 4] = 0;  m[ 5] = h;  m[ 6] = 0;  m[ 7] = h;
//        m[ 8] = 0;  m[ 9] = 0;  m[10] = d;  m[11] = d;
//        m[12] = 0;  m[13] = 0;  m[14] = 0;  m[15] = 1;
//    }

    /**
     * Creates an independent copy of another matrix.
     *
     * @param copy the matrix whose 16 elements are duplicated
     */
    public Matrix3D(Matrix3D copy)
    {
        m = new float[16];
        // Copy from the backing array; passing the Matrix3D object itself to
        // arraycopy compiles (the parameter is Object) but throws
        // ArrayStoreException at runtime.
        System.arraycopy(copy.m, 0, m, 0, 16);
    }

    /**
     * Replaces this matrix with (this * m3d) * scale.
     *
     * Every element of the product is multiplied by scale, including the
     * bottom row and the bottom-right homogeneous term.
     *
     * @param m3d   right-hand operand of the product
     * @param scale factor applied to each term of the product
     */
    public void multiply(Matrix3D m3d, float scale) {
        // Accumulate into a scratch matrix so that m3d == this is safe.
        Matrix3D temp = new Matrix3D();
        for (int col = 0; col < 4; col++) {
            for (int row = 0; row < 4; row++) {
                temp.set(row, col, get(row, 0) * m3d.get(0, col) * scale +
                                   get(row, 1) * m3d.get(1, col) * scale +
                                   get(row, 2) * m3d.get(2, col) * scale +
                                   get(row, 3) * m3d.get(3, col) * scale);
            }
        }
        m = temp.m;
    }

    /**
     * Overwrites this matrix with (Ry * Rx * mag) followed by a translation
     * by (-x, -y, -z). The previous contents of this matrix are discarded.
     *
     * Rx is an identity matrix with its (1,1),(1,2),(2,1),(2,2) entries set
     * to c/d, -b/d, b/d, c/d; Ry is an identity matrix with its
     * (0,0),(0,2),(2,0),(2,2) entries set to d, -a, a, d.
     *
     * NOTE(review): presumably (a, b, c) is a unit axis direction and
     * d = sqrt(b*b + c*c) -- confirm against callers. d == 0 is not guarded
     * and produces infinite/NaN entries.
     *
     * @param mag scale passed to {@link #multiply(Matrix3D, float)}
     * @param a   x component used to build Ry
     * @param b   y component used to build Rx
     * @param c   z component used to build Rx
     * @param d   divisor for Rx and diagonal term of Ry
     * @param x   x coordinate of the point translated to the origin
     * @param y   y coordinate of the point translated to the origin
     * @param z   z coordinate of the point translated to the origin
     */
    public void rotate_to_z_axis(float mag,
                                 float a, float b, float c, float d,
                                 float x, float y, float z) {
        Matrix3D Rx = new Matrix3D();
        Matrix3D Ry = new Matrix3D();

        Rx.set(1, 1, c/d);
        Rx.set(1, 2, -b/d);
        Rx.set(2, 1, b/d);
        Rx.set(2, 2, c/d);

        Ry.set(0, 0, d);
        Ry.set(0, 2, -a);
        Ry.set(2, 0, a);
        Ry.set(2, 2, d);

        Ry.multiply(Rx, mag);
        Ry.translate(-x, -y, -z);

        m = Ry.m;

        // Disabled: closed-form version of the same matrix.
        //m[ 0] = d;   m[ 1] = -a*b/d;   m[ 2] = -a*c/d;   m[ 3] = -x*d + a*z;
        //m[ 4] = 0f;  m[ 5] = c/d;      m[ 6] = -b/d;     m[ 7] = -y;
        //m[ 8] = a;   m[ 9] = b;        m[10] = c;        m[11] = -a*x-d*z;
        //m[12] = 0f;  m[13] = 0f;       m[14] = 0f;       m[15] = 1f;
    }

    /*
        ... Methods for setting and getting matrix elements ...
    */

    /**
     * Stores a value at the given row and column.
     *
     * @param j   row index, 0..3
     * @param i   column index, 0..3
     * @param val value to store
     */
    public void set(int j, int i, float val)
    {
        m[4*j+i] = val;
    }

    /**
     * Reads the element at the given row and column.
     *
     * @param j row index, 0..3
     * @param i column index, 0..3
     * @return the stored element
     */
    public float get(int j, int i)
    {
        return m[4*j+i];
    }

    /**
     * Stores a value at a flat row-major index.
     *
     * @param i   flat index, 0..15
     * @param val value to store
     */
    protected void set(int i, float val)
    {
        m[i] = val;
    }

    /**
     * Reads the element at a flat row-major index.
     *
     * @param i flat index, 0..15
     * @return the stored element
     */
    protected float get(int i)
    {
        return m[i];
    }

    /**
     * Overwrites all 16 elements of this matrix with those of src.
     *
     * @param src the matrix to copy from
     */
    public final void copy(Matrix3D src)
    {
        // Must copy from src.m: handing the Matrix3D object itself to
        // arraycopy throws ArrayStoreException at runtime.
        System.arraycopy(src.m, 0, m, 0, 16);
    }

    /**
     * Transforms the first {@code vertices} entries of in[] into out[].
     * No perspective divide is performed; w is written to out[i].w.
     * It is safe for in[i] and out[i] to be the same vertex object.
     *
     * @param in       source vertices
     * @param out      destination vertices; entries must already exist
     * @param vertices number of leading entries to transform
     */
    public void transform(Vertex3D in[], Vertex3D out[], int vertices)
    {
        for (int i = 0; i < vertices; i++) {
            // Read every source coordinate before writing, so an in-place
            // transform (in[i] == out[i]) does not feed the freshly written
            // x into the y, z and w rows.
            Vertex3D src = in[i];
            float sx = src.x;
            float sy = src.y;
            float sz = src.z;
            float sw = src.w;

            Vertex3D dst = out[i];
            dst.x = m[0]*sx + m[1]*sy + m[2]*sz + m[3]*sw;
            dst.y = m[4]*sx + m[5]*sy + m[6]*sz + m[7]*sw;
            dst.z = m[8]*sx + m[9]*sy + m[10]*sz + m[11]*sw;
            dst.w = m[12]*sx + m[13]*sy + m[14]*sz + m[15]*sw;

            // Disabled: normal transformation.
            //if (in[i].hasNormal) {
            //    // First, compute normal transformation matrix
            //    //                  T
            //    //       -1 T      C
            //    // Q =( m  )  = --------
            //    //               det(m)
            //    Matrix3D q = new Matrix3D();
            //    float d = det3();
            //    q = compute_ct(d);
            //    // Transform normal using Q
            //    //out[i].nx = q.m[0]*in[i].nx + q.m[1]*in[i].ny + q.m[2]*in[i].nz;
            //    //out[i].ny = q.m[4]*in[i].nx + q.m[5]*in[i].ny + q.m[6]*in[i].nz;
            //    //out[i].nz = q.m[8]*in[i].nx + q.m[9]*in[i].ny + q.m[10]*in[i].nz;
            //}
        }
    }

    /**
     * Transforms a single vertex and performs the homogeneous divide.
     *
     * @param v the vertex to transform; it is not modified
     * @return a new vertex built from (x/w, y/w, z/w); a resulting w of zero
     *         is not guarded and yields infinite/NaN coordinates
     */
    public Vertex3D transform(Vertex3D v)
    {
        float x, y, z, w;
        x = m[0]*v.x + m[1]*v.y + m[2]*v.z + m[3]*v.w;
        y = m[4]*v.x + m[5]*v.y + m[6]*v.z + m[7]*v.w;
        z = m[8]*v.x + m[9]*v.y + m[10]*v.z + m[11]*v.w;
        w = m[12]*v.x + m[13]*v.y + m[14]*v.z + m[15]*v.w;

        w = 1 / w;    // one reciprocal, then three multiplies
        Vertex3D result = new Vertex3D(x*w, y*w, z*w);

        // Disabled: normal transformation.
        //if (v.hasNormal) {
        //        // First, compute normal transformation matrix
        //        //                  T
        //        //       -1 T      C
        //        // Q =( m  )  = --------
        //        //               det(m)
        //        Matrix3D q = new Matrix3D();
        //        float d = det3();
        //        q = compute_ct(d);
        //        // Transform normal using Q
        //        //result.nx = q.m[0]*v.nx + q.m[1]*v.ny + q.m[2]*v.nz;
        //        //result.ny = q.m[4]*v.nx + q.m[5]*v.ny + q.m[6]*v.nz;
        //        //result.nz = q.m[8]*v.nx + q.m[9]*v.ny + q.m[10]*v.nz;
        //}

        return result;
    }

    /**
     * Post-multiplies this matrix by s in place: this = this * s.
     * Passing this matrix as its own operand is supported.
     *
     * @param s right-hand operand
     */
    public final void compose(Matrix3D s)
    {
        // Snapshot the operand first: when s == this the rows are rewritten
        // while still being read, which would corrupt the product.
        float[] sv = new float[16];
        for (int k = 0; k < 16; k++) {
            sv[k] = s.get(k);
        }

        float t0, t1, t2, t3;
        for (int i = 0; i < 16; i += 4) {
            t0 = m[i  ];
            t1 = m[i+1];
            t2 = m[i+2];
            t3 = m[i+3];
            m[i  ] = t0*sv[0] + t1*sv[4] + t2*sv[ 8] + t3*sv[12];
            m[i+1] = t0*sv[1] + t1*sv[5] + t2*sv[ 9] + t3*sv[13];
            m[i+2] = t0*sv[2] + t1*sv[6] + t2*sv[10] + t3*sv[14];
            m[i+3] = t0*sv[3] + t1*sv[7] + t2*sv[11] + t3*sv[15];
        }
    }

    /** Resets this matrix to the identity. */
    public void loadIdentity()
    {
        for (int i = 0; i < 16; i++) {
            // i >> 2 is the row and i & 3 the column of flat index i.
            if ((i >> 2) == (i & 3))
                m[i] = 1;
            else
                m[i] = 0;
        }
    }

    /**
     * Post-multiplies this matrix by a translation by (tx, ty, tz).
     *
     * @param tx translation along x
     * @param ty translation along y
     * @param tz translation along z
     */
    public void translate(float tx, float ty, float tz)
    {
        m[ 3] += m[ 0]*tx + m[ 1]*ty + m[ 2]*tz;
        m[ 7] += m[ 4]*tx + m[ 5]*ty + m[ 6]*tz;
        m[11] += m[ 8]*tx + m[ 9]*ty + m[10]*tz;
        m[15] += m[12]*tx + m[13]*ty + m[14]*tz;
    }

    /**
     * Post-multiplies this matrix by a scale of (sx, sy, sz), i.e. scales
     * the first three columns.
     *
     * @param sx scale factor for column 0
     * @param sy scale factor for column 1
     * @param sz scale factor for column 2
     */
    public void scale(float sx, float sy, float sz)
    {
        m[ 0] *= sx; m[ 1] *= sy; m[ 2] *= sz;
        m[ 4] *= sx; m[ 5] *= sy; m[ 6] *= sz;
        m[ 8] *= sx; m[ 9] *= sy; m[10] *= sz;
        m[12] *= sx; m[13] *= sy; m[14] *= sz;
    }

    /**
     * Post-multiplies this matrix by a rotation of {@code angle} radians
     * about the axis (ax, ay, az). The axis is normalised internally.
     * The matrix is left unmodified when the angle is zero or the axis has
     * zero length.
     *
     * @param ax    x component of the rotation axis
     * @param ay    y component of the rotation axis
     * @param az    z component of the rotation axis
     * @param angle rotation angle in radians
     */
    public void rotate(float ax, float ay, float az, float angle)
    {
        float t0, t1, t2;

        if (angle == 0) return;          // return with m unmodified

        t0 = ax*ax + ay*ay + az*az;
        if (t0 == 0) return;             // degenerate axis

        float cosx = (float) Math.cos(angle);
        float sinx = (float) Math.sin(angle);
        t0 = 1f / ((float) Math.sqrt(t0));
        ax *= t0;
        ay *= t0;
        az *= t0;
        t0 = 1f - cosx;

        // Axis-angle rotation terms, rIJ being row I, column J (1-based).
        float r11 = ax*ax*t0 + cosx;
        float r22 = ay*ay*t0 + cosx;
        float r33 = az*az*t0 + cosx;

        t1 = ax*ay*t0;
        t2 = az*sinx;
        float r12 = t1 - t2;
        float r21 = t1 + t2;

        t1 = ax*az*t0;
        t2 = ay*sinx;
        float r13 = t1 + t2;
        float r31 = t1 - t2;

        t1 = ay*az*t0;
        t2 = ax*sinx;
        float r23 = t1 - t2;
        float r32 = t1 + t2;

        // Only the first three columns of each row change.
        for (int i = 0; i < 16; i += 4) {
            t0 = m[i];
            t1 = m[i+1];
            t2 = m[i+2];
            m[i  ] = t0*r11 + t1*r21 + t2*r31;
            m[i+1] = t0*r12 + t1*r22 + t2*r32;
            m[i+2] = t0*r13 + t1*r23 + t2*r33;
        }
    }

    /**
     * Post-multiplies this matrix by a viewing transform built from an eye
     * position, a look-at point and an up hint.
     *
     * The matrix is left unmodified when eye and at coincide, or when the
     * cross product of the line of sight and up is zero.
     *
     * @param eyex x of the eye position
     * @param eyey y of the eye position
     * @param eyez z of the eye position
     * @param atx  x of the point looked at
     * @param aty  y of the point looked at
     * @param atz  z of the point looked at
     * @param upx  x of the up hint
     * @param upy  y of the up hint
     * @param upz  z of the up hint
     */
    public void lookAt(float eyex, float eyey, float eyez,
                       float atx,  float aty,  float atz,
                       float upx,  float upy,  float upz)
    {
        float t0, t1, t2;

        /*
            .... a unit vector along the line of sight ....
        */
        atx -= eyex;
        aty -= eyey;
        atz -= eyez;

        t0 = atx*atx + aty*aty + atz*atz;
        if (t0 == 0) return;                // at and eye at same point
        t0 = (float) (1 / Math.sqrt(t0));
        atx *= t0;
        aty *= t0;
        atz *= t0;

        /*
            .... a unit vector to the right ....
        */
        float rightx, righty, rightz;
        rightx = aty*upz - atz*upy;
        righty = atz*upx - atx*upz;
        rightz = atx*upy - aty*upx;
        t0 = rightx*rightx + righty*righty + rightz*rightz;
        if (t0 == 0) return;                // up is parallel to at (or zero)
        t0 = (float) (1 / Math.sqrt(t0));
        rightx *= t0;
        righty *= t0;
        rightz *= t0;

        /*
            .... a unit up vector ....
        */
        upx = righty*atz - rightz*aty;
        upy = rightz*atx - rightx*atz;
        upz = rightx*aty - righty*atx;

        /*
            .... find camera translation ....
        */
        float tx, ty, tz;
        tx = rightx*eyex + righty*eyey + rightz*eyez;
        ty = upx*eyex + upy*eyey + upz*eyez;
        tz = atx*eyex + aty*eyey + atz*eyez;

        /*
            .... do transform ....
        */
        for (int i = 0; i < 16; i += 4) {
            t0 = m[i];
            t1 = m[i+1];
            t2 = m[i+2];
            m[i  ] = t0*rightx + t1*upx - t2*atx;
            m[i+1] = t0*righty + t1*upy - t2*aty;
            m[i+2] = t0*rightz + t1*upz - t2*atz;
            m[i+3] -= t0*tx + t1*ty - t2*tz;
        }
    }

    /**
     * Post-multiplies this matrix by a perspective projection defined by
     * the given frustum bounds.
     *
     * NOTE(review): the vertical term uses 1/(bottom - top), which looks like
     * a y-down raster convention -- confirm. Degenerate bounds
     * (right == left, bottom == top, far == near) are not guarded.
     *
     * @param left   left bound of the frustum
     * @param right  right bound of the frustum
     * @param bottom bottom bound of the frustum
     * @param top    top bound of the frustum
     * @param near   near plane distance
     * @param far    far plane distance
     */
    public void perspective(float left, float right,
                            float bottom, float top,
                            float near, float far)
    {
        float t0, t1, t2;

        t0 = 1f / (right - left);
        t1 = 1f / (bottom - top);
        t2 = 1f / (far - near);

        float m13 = -t0*(right + left);
        float m23 = -t1*(bottom + top);
        float m33 = t2*(far + near);

        near *= 2;
        float m11 = t0*near;
        float m22 = t1*near;
        float m34 = -t2*far*near;

        for (int i = 0; i < 16; i += 4) {
            t0 = m[i];
            t1 = m[i+1];
            t2 = m[i+2];
            m[i  ] = t0*m11;
            m[i+1] = t1*m22;
            // Column 2 must be computed before column 3 is overwritten,
            // because it still needs the old m[i+3].
            m[i+2] = t0*m13 + t1*m23 + t2*m33 + m[i+3];
            m[i+3] = t2*m34;
        }
    }

    /**
     * Post-multiplies this matrix by an orthographic projection defined by
     * the given box bounds.
     *
     * NOTE(review): the vertical term uses 1/(bottom - top), which looks like
     * a y-down raster convention -- confirm. Degenerate bounds are not
     * guarded.
     *
     * @param left   left bound of the view box
     * @param right  right bound of the view box
     * @param bottom bottom bound of the view box
     * @param top    top bound of the view box
     * @param near   near plane distance
     * @param far    far plane distance
     */
    public void orthographic(float left, float right,
                             float bottom, float top,
                             float near, float far)
    {
        float t0, t1, t2;

        t0 = 1f / (right - left);
        t1 = 1f / (bottom - top);
        t2 = 1f / (far - near);

        float m11 = 2*t0;
        float m22 = 2*t1;
        float m33 = 2*t2;
        float m14 = -t0*(right + left);
        float m24 = -t1*(bottom + top);
        float m34 = -t2*(far + near);

        for (int i = 0; i < 16; i += 4) {
            t0 = m[i];
            t1 = m[i+1];
            t2 = m[i+2];
            m[i  ] = t0*m11;
            m[i+1] = t1*m22;
            m[i+2] = t2*m33;
            m[i+3] = t0*m14 + t1*m24 + t2*m34 + m[i+3];
        }
    }

    /**
     * Pretty prints the current transformation matrix to a single string:
     * a border line, four right-aligned rows and a closing border line,
     * separated by newlines (no trailing newline).
     *
     * @return the formatted matrix
     */
    @Override
    public String toString() {
        String[] lines = formatLines();
        StringBuilder text = new StringBuilder();
        for (int i = 0; i < lines.length; i++) {
            if (i > 0) {
                text.append('\n');
            }
            text.append(lines[i]);
        }
        return text.toString();
    }

    /**
     * Pretty prints the current transformation matrix to an array of
     * strings, one string per line.
     *
     * @return six lines: top border, four matrix rows, bottom border
     */
    public String[] toString2() {
        return formatLines();
    }

    /**
     * Builds the six display lines shared by toString() and toString2().
     * Each column is right-aligned to its widest element.
     */
    private String[] formatLines() {
        String[][] cell = new String[4][4];
        int[] width = new int[4];
        for (int row = 0; row < 4; row++) {
            for (int col = 0; col < 4; col++) {
                cell[row][col] = Float.toString(get(row, col));
                width[col] = Math.max(width[col], cell[row][col].length());
            }
        }

        // The border spans the four columns plus the three separating blanks.
        int gap = width[0] + width[1] + width[2] + width[3] + 3;
        StringBuilder border = new StringBuilder("+-");
        for (int i = 0; i < gap; i++) {
            border.append(' ');
        }
        border.append("-+");

        String[] lines = new String[6];
        lines[0] = border.toString();
        lines[5] = lines[0];

        for (int row = 0; row < 4; row++) {
            StringBuilder line = new StringBuilder("| ");
            for (int col = 0; col < 4; col++) {
                for (int i = cell[row][col].length(); i < width[col]; i++) {
                    line.append(' ');
                }
                line.append(cell[row][col]).append(' ');
            }
            line.append('|');
            lines[1 + row] = line.toString();
        }
        return lines;
    }

    /**
     * Returns the transposed cofactor matrix of the upper-left 3x3 block,
     * embedded in an otherwise identity 4x4 matrix.
     *
     * <pre>
     *      +-                                               -+
     *  T   | a22*a33-a23*a32 a13*a32-a12*a33 a12*a23-a13*a22 |
     * C  = | a23*a31-a21*a33 a11*a33-a13*a31 a13*a21-a11*a23 |
     *      | a21*a32-a22*a31 a12*a31-a11*a32 a11*a22-a12*a21 |
     *      +-                                               -+
     * </pre>
     *
     * @return a new matrix; this matrix is not modified
     */
    public Matrix3D compute_ct() {
        return cofactorTranspose3();
    }

    /**
     * Returns the transposed cofactor matrix of the upper-left 3x3 block
     * with each of its nine elements multiplied by 1/div, embedded in an
     * otherwise identity 4x4 matrix. Passing det3() as div yields the
     * inverse of the 3x3 block.
     *
     * @param div divisor applied to the 3x3 block; zero is not guarded
     * @return a new matrix; this matrix is not modified
     */
    public Matrix3D compute_ct(float div) {
        Matrix3D ct = cofactorTranspose3();
        float mul = 1 / div;
        for (int row = 0; row < 3; row++) {
            for (int col = 0; col < 3; col++) {
                ct.set(row, col, mul * ct.get(row, col));
            }
        }
        return ct;
    }

    /** Computes the unscaled transposed cofactor matrix of the 3x3 block. */
    private Matrix3D cofactorTranspose3() {
        Matrix3D ct = new Matrix3D();
        ct.set(0,0,get(1,1)*get(2,2)-get(1,2)*get(2,1));
        ct.set(0,1,get(0,2)*get(2,1)-get(0,1)*get(2,2));
        ct.set(0,2,get(0,1)*get(1,2)-get(0,2)*get(1,1));
        ct.set(1,0,get(1,2)*get(2,0)-get(1,0)*get(2,2));
        ct.set(1,1,get(0,0)*get(2,2)-get(0,2)*get(2,0));
        ct.set(1,2,get(0,2)*get(1,0)-get(0,0)*get(1,2));
        ct.set(2,0,get(1,0)*get(2,1)-get(1,1)*get(2,0));
        ct.set(2,1,get(0,1)*get(2,0)-get(0,0)*get(2,1));
        ct.set(2,2,get(0,0)*get(1,1)-get(0,1)*get(1,0));
        return ct;
    }

    /**
     * Determinant of the upper-left 3x3 block.
     *
     * <pre>
     *         | a11 a12 a13 |
     * returns | a21 a22 a23 |
     *         | a31 a32 a33 |
     * </pre>
     *
     * @return the determinant, expanded along the first row
     */
    public float det3() {
        float det = get(0,0)*det2(get(1,1),get(1,2),get(2,1),get(2,2)) -
                    get(0,1)*det2(get(1,0),get(1,2),get(2,0),get(2,2)) +
                    get(0,2)*det2(get(1,0),get(1,1),get(2,0),get(2,1));
        return det;
    }

    /**
     * Determinant of a 2x2 matrix.
     *
     * <pre>
     * returns | a11 a12 | = a11*a22-a12*a21
     *         | a21 a22 |
     * </pre>
     *
     * @param a11 top-left element
     * @param a12 top-right element
     * @param a21 bottom-left element
     * @param a22 bottom-right element
     * @return a11*a22 - a12*a21
     */
    public float det2(float a11, float a12,
                      float a21, float a22) {
        return (a11*a22-a12*a21);
    }
}
