import Vertex3D;

/**
 * A 4x4 transformation matrix stored row-major: element (row, col) lives at
 * {@code m[4*row + col]}. Vertices are treated as column vectors, so
 * {@link #transform(Vertex3D)} computes M * v.
 *
 * The mutators ({@link #compose}, {@link #translate}, {@link #scale},
 * {@link #rotate}, {@link #lookAt}, {@link #perspective},
 * {@link #orthographic}) all post-multiply, i.e. M = M * X. The transform
 * issued last in code is therefore the first one applied to a vertex.
 */
public class Matrix3D {
    /** The 16 matrix elements in row-major order. */
    private float[] m;

    /** Creates an identity matrix. (A no-argument constructor allows for extension.) */
    public Matrix3D()
    {
        m = new float[16];
        loadIdentity();
    }

    /**
     * Creates a viewport matrix for the given raster. After the homogeneous
     * divide it maps x in [-1, 1] to [0, 2*(width/2)], y in [-1, 1] to
     * [0, 2*(height/2)] and z in [-1, 1] to [0, 2*(ZRaster.MAXZ/2)].
     *
     * @param r the raster whose dimensions define the viewport
     */
    public Matrix3D(ZRaster r)
    {
        m = new float[16];
        // NOTE(review): if width/height/MAXZ are ints, "/ 2" truncates for odd
        // sizes before conversion to float -- confirm that this is intended.
        float w = r.width / 2;
        float h = r.height / 2;
        float d = ZRaster.MAXZ / 2;
        m[ 0] = w;  m[ 1] = 0;  m[ 2] = 0;  m[ 3] = w;
        m[ 4] = 0;  m[ 5] = h;  m[ 6] = 0;  m[ 7] = h;
        m[ 8] = 0;  m[ 9] = 0;  m[10] = d;  m[11] = d;
        m[12] = 0;  m[13] = 0;  m[14] = 0;  m[15] = 1;
    }

    /**
     * Creates an independent copy of another matrix.
     *
     * @param copy the matrix whose elements are copied
     */
    public Matrix3D(Matrix3D copy)
    {
        m = new float[16];
        // Copy the backing array; passing the Matrix3D object itself to
        // arraycopy would throw ArrayStoreException.
        System.arraycopy(copy.m, 0, m, 0, 16);
    }


    /*
        ... Methods for setting and getting matrix elements ...
    */

    /**
     * Sets element (row j, column i).
     *
     * @param j   row index, 0..3
     * @param i   column index, 0..3
     * @param val new value
     */
    public void set(int j, int i, float val)
    {
        m[4*j+i] = val;
    }

    /**
     * Returns element (row j, column i).
     *
     * @param j row index, 0..3
     * @param i column index, 0..3
     * @return the element value
     */
    public float get(int j, int i)
    {
        return m[4*j+i];
    }

    /** Sets the element at flat row-major index i (0..15). */
    protected void set(int i, float val)
    {
        m[i] = val;
    }

    /** Returns the element at flat row-major index i (0..15). */
    protected float get(int i)
    {
        return m[i];
    }


    /**
     * Overwrites this matrix with the elements of {@code src}.
     *
     * @param src the matrix to copy from
     */
    public final void copy(Matrix3D src)
    {
        // Copy the backing array, not the Matrix3D object itself.
        System.arraycopy(src.m, 0, m, 0, 16);
    }



    /**
     * Transforms the first {@code vertices} entries of {@code in} into the
     * corresponding, already-allocated entries of {@code out}.
     *
     * Positions are multiplied by this matrix with no homogeneous divide.
     * When {@code in[i].hasNormal} is set, the normal is multiplied by the
     * inverse transpose of this matrix. Colors are copied and
     * {@code out[i].hasColor} is cleared. {@code in} and {@code out} may be
     * the same array.
     *
     * @param in       source vertices
     * @param out      destination vertices; must hold at least {@code vertices} non-null entries
     * @param vertices number of vertices to transform
     */
    public void transform(Vertex3D[] in, Vertex3D[] out, int vertices)
    {
        // Inverse transpose used for normals. It is computed at most once per
        // call, and only if some vertex actually carries a normal.
        Matrix3D n = null;

        for (int i = 0; i < vertices; i++) {
            Vertex3D src = in[i];
            Vertex3D dst = out[i];

            // Read every input component before writing any output, so the
            // result is correct when src and dst are the same object.
            float x = src.x;
            float y = src.y;
            float z = src.z;
            float w = src.w;

            float tx = m[0]*x  + m[1]*y  + m[2]*z  + m[3]*w;
            float ty = m[4]*x  + m[5]*y  + m[6]*z  + m[7]*w;
            float tz = m[8]*x  + m[9]*y  + m[10]*z + m[11]*w;
            float tw = m[12]*x + m[13]*y + m[14]*z + m[15]*w;

            dst.x = tx;
            dst.y = ty;
            dst.z = tz;
            dst.w = tw;

            if (src.hasNormal) {
                if (n == null)
                    n = invtrans44(this);
                float nx = src.nx;
                float ny = src.ny;
                float nz = src.nz;
                dst.setNormal(n.m[0]*nx + n.m[1]*ny + n.m[2]*nz  + n.m[3],
                              n.m[4]*nx + n.m[5]*ny + n.m[6]*nz  + n.m[7],
                              n.m[8]*nx + n.m[9]*ny + n.m[10]*nz + n.m[11]);
            }
            dst.r = src.r;
            dst.g = src.g;
            dst.b = src.b;
            dst.hasColor = false;
        }
    }

    /**
     * Returns a new vertex equal to M * v followed by the homogeneous divide.
     *
     * If {@code v.hasNormal} is set, the returned vertex receives v's normal
     * multiplied by the inverse transpose of this matrix. This is computed the
     * same way as the array overload and is not divided by a homogeneous w.
     * Colors are copied and {@code hasColor} is cleared on the result.
     * {@code v} itself is not modified.
     *
     * @param v the vertex to transform
     * @return a newly created transformed vertex
     */
    public Vertex3D transform(Vertex3D v)
    {
        float x = m[0]*v.x  + m[1]*v.y  + m[2]*v.z  + m[3]*v.w;
        float y = m[4]*v.x  + m[5]*v.y  + m[6]*v.z  + m[7]*v.w;
        float z = m[8]*v.x  + m[9]*v.y  + m[10]*v.z + m[11]*v.w;
        float w = m[12]*v.x + m[13]*v.y + m[14]*v.z + m[15]*v.w;

        float invW = 1 / w;
        Vertex3D result = new Vertex3D(x*invW, y*invW, z*invW);

        if (v.hasNormal) {
            Matrix3D n = invtrans44(this);
            // The transformed normal belongs on the result, not on the input
            // vertex (writing it to v would corrupt the caller's data).
            result.setNormal(n.m[0]*v.nx + n.m[1]*v.ny + n.m[2]*v.nz  + n.m[3],
                             n.m[4]*v.nx + n.m[5]*v.ny + n.m[6]*v.nz  + n.m[7],
                             n.m[8]*v.nx + n.m[9]*v.ny + n.m[10]*v.nz + n.m[11]);
        }
        result.hasColor = false;
        result.r = v.r;
        result.g = v.g;
        result.b = v.b;
        return result;
    }

    /**
     * Returns the inverse transpose of {@code a}, computed as the cofactor
     * matrix divided by the determinant. A singular matrix (det == 0) yields
     * non-finite elements; no exception is thrown.
     *
     * @param a the matrix to invert and transpose
     * @return a new matrix holding (a^-1)^T
     */
    public static Matrix3D invtrans44(Matrix3D a) {
        float a11 = a.get(0,0), a12 = a.get(0,1), a13 = a.get(0,2), a14 = a.get(0,3);
        float a21 = a.get(1,0), a22 = a.get(1,1), a23 = a.get(1,2), a24 = a.get(1,3);
        float a31 = a.get(2,0), a32 = a.get(2,1), a33 = a.get(2,2), a34 = a.get(2,3);
        float a41 = a.get(3,0), a42 = a.get(3,1), a43 = a.get(3,2), a44 = a.get(3,3);

        // Cofactor cRC = (-1)^(R+C) * minor obtained by deleting row R, column C.
        float c11 =  det33(a22, a23, a24, a32, a33, a34, a42, a43, a44);
        float c12 = -det33(a21, a23, a24, a31, a33, a34, a41, a43, a44);
        float c13 =  det33(a21, a22, a24, a31, a32, a34, a41, a42, a44);
        float c14 = -det33(a21, a22, a23, a31, a32, a33, a41, a42, a43);

        float c21 = -det33(a12, a13, a14, a32, a33, a34, a42, a43, a44);
        float c22 =  det33(a11, a13, a14, a31, a33, a34, a41, a43, a44);
        float c23 = -det33(a11, a12, a14, a31, a32, a34, a41, a42, a44);
        float c24 =  det33(a11, a12, a13, a31, a32, a33, a41, a42, a43);

        float c31 =  det33(a12, a13, a14, a22, a23, a24, a42, a43, a44);
        float c32 = -det33(a11, a13, a14, a21, a23, a24, a41, a43, a44);
        float c33 =  det33(a11, a12, a14, a21, a22, a24, a41, a42, a44);
        float c34 = -det33(a11, a12, a13, a21, a22, a23, a41, a42, a43);

        float c41 = -det33(a12, a13, a14, a22, a23, a24, a32, a33, a34);
        float c42 =  det33(a11, a13, a14, a21, a23, a24, a31, a33, a34);
        float c43 = -det33(a11, a12, a14, a21, a22, a24, a31, a32, a34);
        float c44 =  det33(a11, a12, a13, a21, a22, a23, a31, a32, a33);

        // Laplace expansion along the first row, reusing that row's cofactors
        // (the same value det44 computes, without recomputing the minors).
        float det = a11*c11 + a12*c12 + a13*c13 + a14*c14;

        // inverse = adjugate / det = cofactor^T / det, so the inverse
        // transpose is simply cofactor / det.
        Matrix3D result = new Matrix3D();
        result.set(0,0, c11 / det);
        result.set(0,1, c12 / det);
        result.set(0,2, c13 / det);
        result.set(0,3, c14 / det);

        result.set(1,0, c21 / det);
        result.set(1,1, c22 / det);
        result.set(1,2, c23 / det);
        result.set(1,3, c24 / det);

        result.set(2,0, c31 / det);
        result.set(2,1, c32 / det);
        result.set(2,2, c33 / det);
        result.set(2,3, c34 / det);

        result.set(3,0, c41 / det);
        result.set(3,1, c42 / det);
        result.set(3,2, c43 / det);
        result.set(3,3, c44 / det);

        return result;
    }




    /**
     * Determinant of the 4x4 matrix whose elements are given row by row
     * (aRC = row R, column C, 1-based), computed by cofactor expansion along
     * the first row.
     */
    public static float det44(float a11, float a12, float a13, float a14, float a21, float a22, float a23, float a24, float a31, float a32, float a33, float a34, float a41, float a42, float a43, float a44) {
        float det = (a11*det33(a22, a23, a24, a32, a33, a34, a42, a43, a44) -
                     a12*det33(a21, a23, a24, a31, a33, a34, a41, a43, a44) +
                     a13*det33(a21, a22, a24, a31, a32, a34, a41, a42, a44) -
                     a14*det33(a21, a22, a23, a31, a32, a33, a41, a42, a43));
        return det;
    }

    /**
     * Determinant of the 3x3 matrix whose elements are given row by row
     * (aRC = row R, column C, 1-based), expanded along the first row.
     */
    public static float det33(float a11, float a12, float a13, float a21, float a22, float a23, float a31, float a32, float a33) {
        float det = a11*(a22*a33-a23*a32) + a12*(a31*a23 - a21*a33) + a13*(a21*a32 - a22*a31);
        return det;
    }

    /**
     * Post-multiplies this matrix by {@code s}: M = M * s. Passing
     * {@code this} squares the matrix correctly.
     *
     * @param s the right-hand operand
     */
    public final void compose(Matrix3D s)
    {
        // Snapshot s before writing: when s == this, later rows would
        // otherwise read elements already overwritten in earlier rows.
        float[] b = new float[16];
        for (int k = 0; k < 16; k++)
            b[k] = s.get(k);

        for (int i = 0; i < 16; i += 4) {
            float t0 = m[i  ];
            float t1 = m[i+1];
            float t2 = m[i+2];
            float t3 = m[i+3];
            m[i  ] = t0*b[0] + t1*b[4] + t2*b[ 8] + t3*b[12];
            m[i+1] = t0*b[1] + t1*b[5] + t2*b[ 9] + t3*b[13];
            m[i+2] = t0*b[2] + t1*b[6] + t2*b[10] + t3*b[14];
            m[i+3] = t0*b[3] + t1*b[7] + t2*b[11] + t3*b[15];
        }
    }

    /** Resets this matrix to the identity. */
    public void loadIdentity()
    {
        for (int i = 0; i < 16; i++) {
            // (i >> 2) is the row and (i & 3) the column of flat index i.
            if ((i >> 2) == (i & 3))
                m[i] = 1;
            else
                m[i] = 0;
        }
    }

    /**
     * Post-multiplies by a translation: M = M * T(tx, ty, tz).
     * Only column 3 changes.
     */
    public void translate(float tx, float ty, float tz)
    {
        m[ 3] += m[ 0]*tx + m[ 1]*ty + m[ 2]*tz;
        m[ 7] += m[ 4]*tx + m[ 5]*ty + m[ 6]*tz;
        m[11] += m[ 8]*tx + m[ 9]*ty + m[10]*tz;
        m[15] += m[12]*tx + m[13]*ty + m[14]*tz;
    }

    /**
     * Post-multiplies by a non-uniform scale: M = M * S(sx, sy, sz),
     * i.e. columns 0, 1 and 2 are scaled by sx, sy and sz respectively.
     */
    public void scale(float sx, float sy, float sz)
    {
        m[ 0] *= sx; m[ 1] *= sy; m[ 2] *= sz;
        m[ 4] *= sx; m[ 5] *= sy; m[ 6] *= sz;
        m[ 8] *= sx; m[ 9] *= sy; m[10] *= sz;
        m[12] *= sx; m[13] *= sy; m[14] *= sz;
    }

    /**
     * Post-multiplies by a rotation of {@code angle} radians about the axis
     * (ax, ay, az), using the standard axis-angle (Rodrigues) rotation matrix.
     * The axis need not be unit length; it is normalized here. Does nothing
     * if the angle is 0 or the axis has zero length.
     */
    public void rotate(float ax, float ay, float az, float angle)
    {
        float t0, t1, t2;

        if (angle == 0) return;          // return with m unmodified

        t0 = ax*ax + ay*ay + az*az;
        if (t0 == 0) return;             // degenerate axis

        float cosx = (float) Math.cos(angle);
        float sinx = (float) Math.sin(angle);
        t0 = 1f / ((float) Math.sqrt(t0));
        ax *= t0;
        ay *= t0;
        az *= t0;
        t0 = 1f - cosx;

        float r11 = ax*ax*t0 + cosx;
        float r22 = ay*ay*t0 + cosx;
        float r33 = az*az*t0 + cosx;

        t1 = ax*ay*t0;
        t2 = az*sinx;
        float r12 = t1 - t2;
        float r21 = t1 + t2;

        t1 = ax*az*t0;
        t2 = ay*sinx;
        float r13 = t1 + t2;
        float r31 = t1 - t2;

        t1 = ay*az*t0;
        t2 = ax*sinx;
        float r23 = t1 - t2;
        float r32 = t1 + t2;

        // M = M * R; R has no translation, so column 3 is unchanged.
        for (int i = 0; i < 16; i += 4) {
            t0 = m[i];
            t1 = m[i+1];
            t2 = m[i+2];
            m[i  ] = t0*r11 + t1*r21 + t2*r31;
            m[i+1] = t0*r12 + t1*r22 + t2*r32;
            m[i+2] = t0*r13 + t1*r23 + t2*r33;
        }
    }

    /**
     * Post-multiplies by a viewing transform with the eye at
     * (eyex, eyey, eyez) looking toward (atx, aty, atz). The
     * (upx, upy, upz) hint fixes the roll. The view rows are right, up and
     * -sightline, so the line of sight maps to -z. Does nothing if eye and
     * at coincide, or if up is parallel to the line of sight.
     */
    public void lookAt(float eyex, float eyey, float eyez,
                       float atx,  float aty,  float atz,
                       float upx,  float upy,  float upz)
    {
        float t0, t1, t2;

        /*
            .... a unit vector along the line of sight ....
        */
        atx -= eyex;
        aty -= eyey;
        atz -= eyez;

        t0 = atx*atx + aty*aty + atz*atz;
        if (t0 == 0) return;                // at and eye at same point
        t0 = (float) (1 / Math.sqrt(t0));
        atx *= t0;
        aty *= t0;
        atz *= t0;

        /*
            .... a unit vector to the right (sight x up) ....
        */
        float rightx, righty, rightz;
        rightx = aty*upz - atz*upy;
        righty = atz*upx - atx*upz;
        rightz = atx*upy - aty*upx;
        t0 = rightx*rightx + righty*righty + rightz*rightz;
        if (t0 == 0) return;                // up is parallel to the line of sight
        t0 = (float) (1 / Math.sqrt(t0));
        rightx *= t0;
        righty *= t0;
        rightz *= t0;


        /*
            .... a unit up vector orthogonal to sight and right ....
        */
        upx = righty*atz - rightz*aty;
        upy = rightz*atx - rightx*atz;
        upz = rightx*aty - righty*atx;


        /*
            .... find camera translation ....
        */
        float tx, ty, tz;
        tx = rightx*eyex + righty*eyey + rightz*eyez;
        ty = upx*eyex + upy*eyey + upz*eyez;
        tz = atx*eyex + aty*eyey + atz*eyez;

        /*
            .... do transform: M = M * V ....
        */
        for (int i = 0; i < 16; i += 4) {
            t0 = m[i];
            t1 = m[i+1];
            t2 = m[i+2];
            m[i  ] = t0*rightx + t1*upx - t2*atx;
            m[i+1] = t0*righty + t1*upy - t2*aty;
            m[i+2] = t0*rightz + t1*upz - t2*atz;
            m[i+3] -= t0*tx + t1*ty - t2*tz;
        }
    }

    /**
     * Post-multiplies by the perspective projection for the frustum
     * [left, right] x [bottom, top] between the near and far planes. The
     * projection's last row is (0, 0, 1, 0), so the output w equals the
     * incoming z. The y scale uses (bottom - top), which flips the y axis
     * (presumably so screen y grows downward -- confirm).
     */
    public void perspective(float left, float right,
                            float bottom, float top,
                            float near, float far)
    {
        float t0, t1, t2;

        t0 = 1f / (right - left);
        t1 = 1f / (bottom - top);
        t2 = 1f / (far - near);

        float m13 = -t0*(right + left);
        float m23 = -t1*(bottom + top);
        float m33 = t2*(far + near);

        near *= 2;
        float m11 = t0*near;
        float m22 = t1*near;
        float m34 = -t2*far*near;

        for (int i = 0; i < 16; i += 4) {
            t0 = m[i];
            t1 = m[i+1];
            t2 = m[i+2];
            m[i  ] = t0*m11;
            m[i+1] = t1*m22;
            m[i+2] = t0*m13 + t1*m23 + t2*m33 + m[i+3];
            m[i+3] = t2*m34;
        }
    }

    /**
     * Post-multiplies by the orthographic projection of the box
     * [left, right] x [bottom, top] x [near, far]. As in
     * {@link #perspective}, the y scale uses (bottom - top).
     */
    public void orthographic(float left, float right,
                             float bottom, float top,
                             float near, float far)
    {
        float t0, t1, t2;

        t0 = 1f / (right - left);
        t1 = 1f / (bottom - top);
        t2 = 1f / (far - near);

        float m11 = 2*t0;
        float m22 = 2*t1;
        float m33 = 2*t2;
        float m14 = -t0*(right + left);
        float m24 = -t1*(bottom + top);
        float m34 = -t2*(far + near);

        for (int i = 0; i < 16; i += 4) {
            t0 = m[i];
            t1 = m[i+1];
            t2 = m[i+2];
            m[i  ] = t0*m11;
            m[i+1] = t1*m22;
            m[i+2] = t2*m33;
            m[i+3] = t0*m14 + t1*m24 + t2*m34 + m[i+3];
        }
    }

    /** Returns the matrix as nested rows: "[ [a, b, c, d ], [...], ... ]". */
    public String toString()
    {
        return ("[ ["+m[ 0]+", "+m[ 1]+", "+m[ 2]+", "+m[ 3]+" ], ["+
                      m[ 4]+", "+m[ 5]+", "+m[ 6]+", "+m[ 7]+" ], ["+
                      m[ 8]+", "+m[ 9]+", "+m[10]+", "+m[11]+" ], ["+
                      m[12]+", "+m[13]+", "+m[14]+", "+m[15]+" ] ]");
    }
}
