/**
 * Matrix3D represents a 4x4 matrix of floats for performing various
 * operations on points, which are represented as vectors.
 *
 * The matrix is stored row-major (matrix[row][col]) and points are treated
 * as column vectors with an implicit w of 1 (see transform), so the
 * translation part lives in column 3.  Every transformation method
 * (translate, scale, skew, rotate, lookAt, perspective, orthographic,
 * compose) post-multiplies the current matrix: this = this * T.
 */
public class Matrix3D {

  // Rep Invariant: matrix is always a 4x4 array of floats.
  private float[][] matrix;

  // Scratch 4x4 matrix that translate/scale/skew/rotate fill in before
  // multiplying, so they do not allocate a fresh operand on every call.
  private float[][] tmat;

  /** Creates a matrix initialized to the identity transform. */
  public Matrix3D() {
    loadIdentity();
    tmat = new float[4][4];
  }

  /**
   * Creates an independent copy of another matrix.
   *
   * @param copy the matrix whose 16 elements are copied
   */
  public Matrix3D( Matrix3D copy ) {
    // clone() on a float[][] only duplicates the outer array, which would
    // leave both objects sharing the same row arrays (a set() on one would
    // show up in the other).  Copy row by row instead.
    matrix = new float[4][4];
    for (int row = 0; row < 4; row++) {
      System.arraycopy(copy.matrix[row], 0, matrix[row], 0, 4);
    }
    tmat = new float[4][4];
  }

  /**
   * Creates a mapping from canonical space to screen space: x and y are
   * scaled by half the raster width/height and offset by the same amount;
   * z and w pass through unchanged.
   *
   * @param r raster whose getWidth()/getHeight() define the screen size
   */
  public Matrix3D( Raster r ) {
    // adapted from lecture 12 slide 40
    float halfWidth  = ((float) r.getWidth())  / 2f;
    float halfHeight = ((float) r.getHeight()) / 2f;

    matrix = new float[][] { { halfWidth,         0f, 0f,  halfWidth },
                             {        0f, halfHeight, 0f, halfHeight },
                             {        0f,         0f, 1f,         0f },
                             {        0f,         0f, 0f,         1f } };
    tmat = new float[4][4];
  }

  /**
   * Stores a value in one element of the matrix.
   *
   * @param i     row index, 0..3
   * @param j     column index, 0..3
   * @param value new element value
   */
  public void set(int i, int j, float value) {
    matrix[i][j] = value;
  }

  /**
   * Reads one element of the matrix.
   *
   * @param i row index, 0..3
   * @param j column index, 0..3
   * @return the element at row i, column j
   */
  public float get(int i, int j) {
    return matrix[i][j];
  }

  /**
   * Transforms points from the in array to the out array using the current
   * matrix, including the homogeneous divide by the resulting w.  The
   * subset of points transformed begins at the start index and has the
   * specified length; out[i] receives the transform of in[i].
   *
   * Each result is a newly allocated Point3D and in[i] is read before
   * out[i] is written, so passing the same array for in and out replaces
   * the references without corrupting the computation.
   *
   * No check is made for w == 0; in that case the float division yields
   * infinite or NaN coordinates.
   *
   * @param in     source points
   * @param out    destination array, indexed the same way as in
   * @param start  index of the first point to transform
   * @param length number of points to transform
   */
  public void transform(Point3D in[], Point3D out[],
                        int start, int length) {
    final int end = start + length;
    for (int i = start; i < end; i++) {
      Point3D p = in[i];

      // w component of matrix * (x, y, z, 1)
      float factor = matrix[3][0]*p.x + matrix[3][1]*p.y + matrix[3][2]*p.z + matrix[3][3];

      out[i] = new Point3D( matrix[0][0]*p.x + matrix[0][1]*p.y +
                            matrix[0][2]*p.z + matrix[0][3],
                            matrix[1][0]*p.x + matrix[1][1]*p.y +
                            matrix[1][2]*p.z + matrix[1][3],
                            matrix[2][0]*p.x + matrix[2][1]*p.y +
                            matrix[2][2]*p.z + matrix[2][3] );
      out[i].x /= factor;
      out[i].y /= factor;
      out[i].z /= factor;
    }
  }

  /**
   * Post-multiplies this matrix by another: this = this * src.
   *
   * @param src right-hand operand; it is not modified
   */
  public final void compose(Matrix3D src) {
    matrix = multiply(matrix, src.matrix);
  }

  /** Resets this matrix to the identity: this = identity. */
  public void loadIdentity() {
    matrix = new float[][] { { 1f, 0f, 0f, 0f },
                             { 0f, 1f, 0f, 0f },
                             { 0f, 0f, 1f, 0f },
                             { 0f, 0f, 0f, 1f } };
  }

  /**
   * Post-multiplies this matrix by a translation.
   *
   * @param tx translation along x
   * @param ty translation along y
   * @param tz translation along z
   */
  public void translate(float tx, float ty, float tz) {
    // { 1, 0, 0, tx }
    // { 0, 1, 0, ty }
    // { 0, 0, 1, tz }
    // { 0, 0, 0, 1  }
    tmat[0][0]=1f; tmat[0][1]=0f; tmat[0][2]=0f; tmat[0][3]=tx;
    tmat[1][0]=0f; tmat[1][1]=1f; tmat[1][2]=0f; tmat[1][3]=ty;
    tmat[2][0]=0f; tmat[2][1]=0f; tmat[2][2]=1f; tmat[2][3]=tz;
    tmat[3][0]=0f; tmat[3][1]=0f; tmat[3][2]=0f; tmat[3][3]=1f;

    matrix = multiply(matrix, tmat);
  }

  /**
   * Post-multiplies this matrix by an axis-aligned scale.
   *
   * @param sx scale factor along x
   * @param sy scale factor along y
   * @param sz scale factor along z
   */
  public void scale(float sx, float sy, float sz) {
    // { sx, 0,  0,  0 }
    // { 0,  sy, 0,  0 }
    // { 0,  0,  sz, 0 }
    // { 0,  0,  0,  1 }
    tmat[0][0]=sx; tmat[0][1]=0f; tmat[0][2]=0f; tmat[0][3]=0f;
    tmat[1][0]=0f; tmat[1][1]=sy; tmat[1][2]=0f; tmat[1][3]=0f;
    tmat[2][0]=0f; tmat[2][1]=0f; tmat[2][2]=sz; tmat[2][3]=0f;
    tmat[3][0]=0f; tmat[3][1]=0f; tmat[3][2]=0f; tmat[3][3]=1f;

    matrix = multiply(matrix, tmat);
  }

  /**
   * Post-multiplies this matrix by a skew (shear).
   *
   * @param kxy amount of y added to x
   * @param kxz amount of z added to x
   * @param kyz amount of z added to y
   */
  public void skew(float kxy, float kxz, float kyz) {
    // taken from 6.837 lecture 12 slide 24
    // { 1, kxy, kxz, 0 }
    // { 0, 1,   kyz, 0 }
    // { 0, 0,   1,   0 }
    // { 0, 0,   0,   1 }
    tmat[0][0]=1f; tmat[0][1]=kxy; tmat[0][2]=kxz; tmat[0][3]=0f;
    tmat[1][0]=0f; tmat[1][1]=1f;  tmat[1][2]=kyz; tmat[1][3]=0f;
    tmat[2][0]=0f; tmat[2][1]=0f;  tmat[2][2]=1f;  tmat[2][3]=0f;
    tmat[3][0]=0f; tmat[3][1]=0f;  tmat[3][2]=0f;  tmat[3][3]=1f;

    matrix = multiply(matrix, tmat);
  }

  /**
   * Post-multiplies this matrix by a rotation of angle radians about the
   * axis (ax, ay, az).  The axis is normalized first, so it need not be
   * unit length; a zero-length axis produces NaN entries because the
   * normalization divides by zero.
   *
   * @param ax    x component of the rotation axis
   * @param ay    y component of the rotation axis
   * @param az    z component of the rotation axis
   * @param angle rotation angle in radians (passed to Math.cos/Math.sin)
   */
  public void rotate(float ax, float ay, float az, float angle) {
    // normalize the axis
    float[] a = normalize(new float[] { ax, ay, az });
    ax = a[0];  ay = a[1];  az = a[2];

    // Axis-angle rotation split into its three parts:
    //   cos  -> identity term, sym -> symmetric (a a^T) term,
    //   skew -> antisymmetric (cross-product) term.
    // Only the upper-left 3x3 is affected; row/column 3 stay identity.
    float cos  = (float) Math.cos(angle);
    float sym  = 1f - cos;
    float skew = (float) Math.sin(angle);

    tmat[0][0]= ax * ax * sym + cos;
    tmat[0][1]= -az * skew + ax * ay * sym;
    tmat[0][2]= ay * skew + ax*az * sym;
    tmat[0][3]= 0f;

    tmat[1][0]= az*skew + ax*ay*sym;
    tmat[1][1]= ay*ay*sym + cos;
    tmat[1][2]= -ax*skew + ay*az*sym;
    tmat[1][3]= 0f;

    tmat[2][0]= -ay*skew + ax*az*sym;
    tmat[2][1]= ax*skew + ay*az*sym;
    tmat[2][2]= az*az*sym + cos;
    tmat[2][3]= 0f;

    tmat[3][0]= 0f;
    tmat[3][1]= 0f;
    tmat[3][2]= 0f;
    tmat[3][3]= 1f;

    matrix = multiply(matrix, tmat);
  }

  /**
   * Post-multiplies this matrix by a viewing transform that places the eye
   * at (eyex, eyey, eyez) looking toward (atx, aty, atz).
   *
   * The up vector is only a hint: the actual camera up direction is
   * recomputed perpendicular to the view direction, so up does not have to
   * be exactly orthogonal to (at - eye).  If up is parallel to the view
   * direction, or eye equals at, the basis is degenerate and the result
   * contains NaN.
   *
   * @param eyex eye position x
   * @param eyey eye position y
   * @param eyez eye position z
   * @param atx  look-at target x
   * @param aty  look-at target y
   * @param atz  look-at target z
   * @param upx  up hint x
   * @param upy  up hint y
   * @param upz  up hint z
   */
  public void lookAt(float eyex, float eyey, float eyez,
                     float atx,  float aty,  float atz,
                     float upx,  float upy,  float upz) {
    float[] look = { atx-eyex, aty-eyey, atz-eyez };
    float[] up   = { upx, upy, upz };
    float[] eye  = { eyex, eyey, eyez };

    // r = l cross up
    float[] r = normalize(cross(look, up));
    float[] l = normalize(look);
    // Re-derive up from r and l so that (r, u, l) is orthonormal even when
    // the supplied up hint is not perpendicular to the view direction.
    // r and l are unit length and perpendicular, so u is already unit.
    float[] u = cross(r, l);

    float[][] tmatrix = { {  r[0],  r[1],  r[2], -dot(r,eye) },
                          {  u[0],  u[1],  u[2], -dot(u,eye) },
                          { -l[0], -l[1], -l[2],  dot(l,eye) },
                          {    0f,    0f,    0f,          1f } };
    matrix = multiply(matrix, tmatrix);
  }

  // Assume the following projection transformations transform points
  // into the canonical viewing space

  /**
   * Post-multiplies this matrix by a perspective projection.  The resulting
   * w equals the incoming z, and after the homogeneous divide z == near
   * maps to -1 and z == far maps to +1.  The y terms are divided by
   * (bottom - top).
   *
   * Degenerate volumes (right == left, bottom == top, far == near) divide
   * by zero and yield infinite/NaN entries.
   *
   * @param left   left edge of the view volume at the near plane
   * @param right  right edge of the view volume at the near plane
   * @param bottom bottom edge of the view volume at the near plane
   * @param top    top edge of the view volume at the near plane
   * @param near   distance to the near plane
   * @param far    distance to the far plane
   */
  public void perspective(float left, float right,
                          float bottom, float top,
                          float near, float far) {
    // copied from 6.837 lecture 12 slide 44
    float[][] tmatrix =
    { { 2f*near/(right-left), 0f, -(right+left)/(right-left), 0f },
      { 0f, 2f*near/(bottom-top), -(bottom+top)/(bottom-top), 0f },
      { 0f, 0f, (far+near)/(far-near),   -2f*far*near/(far-near) },
      { 0f, 0f,                    1f,                        0f } };

    matrix = multiply(matrix, tmatrix);
  }

  /**
   * Post-multiplies this matrix by an orthographic projection that maps
   * [left, right] and [near, far] to [-1, 1] in x and z respectively; the
   * y terms are divided by (bottom - top).
   *
   * Degenerate volumes (right == left, bottom == top, far == near) divide
   * by zero and yield infinite/NaN entries.
   *
   * @param left   left edge of the view volume
   * @param right  right edge of the view volume
   * @param bottom bottom edge of the view volume
   * @param top    top edge of the view volume
   * @param near   near plane
   * @param far    far plane
   */
  public void orthographic(float left  , float right,
                           float bottom, float top,
                           float near  , float far) {
    // copied from 6.837 lecture 12 slide 44
    float[][] tmatrix =
    { { 2f/(right-left), 0f, 0f, -(right+left)/(right-left) },
      { 0f, 2f/(bottom-top), 0f, -(bottom+top)/(bottom-top) },
      { 0f, 0f, 2f/(far-near),       -(far+near)/(far-near) },
      { 0f, 0f,            0f,                           1f } };

    matrix = multiply(matrix, tmatrix);
  }

  /**
   * Returns the product src1 * src2 as a newly allocated 4x4 matrix.
   * Neither argument is modified, so either may be this.matrix or tmat.
   * Note that order of arguments is significant for matrix multiplies.
   */
  private float[][] multiply(float[][] src1, float[][] src2) {
    float[][] product = new float[4][4];

    for (int row = 0; row < 4; row++) {
      for (int col = 0; col < 4; col++) {
        product[row][col] =
          src1[row][0]*src2[0][col] +
          src1[row][1]*src2[1][col] +
          src1[row][2]*src2[2][col] +
          src1[row][3]*src2[3][col];
      }
    }

    return product;
  }

  /** Dot product.  Requires: a and b are 3 element arrays. */
  private float dot(float[] a, float[] b) {
    return a[0]*b[0] + a[1]*b[1] + a[2]*b[2];
  }

  /** Cross product a x b.  Requires: a and b are 3 element arrays. */
  private float[] cross(float[] a, float[] b) {
    return new float[] { a[1]*b[2] - a[2]*b[1],
                         a[2]*b[0] - a[0]*b[2],
                         a[0]*b[1] - a[1]*b[0] };
  }

  /**
   * Returns a new unit-length vector in the direction of a.
   * Requires: a is a 3 element array.  A zero-length input is not
   * special-cased and yields NaN components (0/0).
   */
  private float[] normalize(float[] a) {
    float factor = (float) Math.sqrt(a[0]*a[0] + a[1]*a[1] + a[2]*a[2]);
    return new float[] { a[0]/factor, a[1]/factor, a[2]/factor };
  }

  /**
   * Prints the 16 elements to System.out as four tab-separated lines.
   *
   * NOTE(review): each printed line holds get(0,j) .. get(3,j), i.e. one
   * COLUMN of the matrix, so the output is the transpose of the stored
   * layout.  Kept as-is to preserve existing output; confirm whether this
   * is intended.
   */
  public void print() {
    StringBuilder text = new StringBuilder();
    for (int j = 0; j < 4; j++) {
      if (j > 0) {
        text.append("\n");
      }
      for (int i = 0; i < 4; i++) {
        if (i > 0) {
          text.append("\t");
        }
        text.append(get(i, j));
      }
    }
    System.out.println(text.toString());
  }

}

