/////////////////////
// Sidney Chang    //
// 6.837 Project 3 //
// TA: Jacob       //
// 11/8/98         //
/////////////////////

/**
 * A 4x4 homogeneous transformation matrix stored as matrix[row][column].
 *
 * The modelling and viewing methods (translate, scale, skew, rotate, lookAt,
 * perspective, orthographic) all post-multiply: this = this * transform.
 */
public class Matrix3D {

  // Variables
  /** The 4x4 coefficients, indexed [row][column]. */
  public float[][] matrix;

  // Constructors
  /** Creates an identity transform. */
  public Matrix3D() {
    matrix = new float[4][4];
    setIdentity(matrix);
  }

  /**
   * Creates an element-wise copy of another matrix; the two matrices do not
   * share storage afterwards.
   *
   * @param copy the matrix to copy from
   */
  public Matrix3D(Matrix3D copy) {
    matrix = new float[4][4];
    for (int row = 0; row < 4; row++) {
      System.arraycopy(copy.matrix[row], 0, matrix[row], 0, 4);
    }
  }

  /**
   * Creates the mapping from canonical space to screen space for a raster:
   * x is scaled and offset by half the width, y is negated, scaled and offset
   * by half the height, and z is scaled and offset by 0.5.
   *
   * @param r the raster whose width and height define the screen space
   */
  public Matrix3D(Raster r) {
    // Divide by 2.0f rather than 2: if the Raster getters return integers
    // (their declaration is not visible here) an integer division would
    // silently drop the half pixel for odd dimensions.
    float halfWidth = r.getWidth() / 2.0f;
    float halfHeight = r.getHeight() / 2.0f;

    matrix = new float[4][4];       // every entry not set below stays 0
    matrix[0][0] = halfWidth;
    matrix[0][3] = halfWidth;
    matrix[1][1] = -halfHeight;
    matrix[1][3] = halfHeight;
    matrix[2][2] = 0.5f;
    matrix[2][3] = 0.5f;
    matrix[3][3] = 1;
  }


  // General interface methods
  /**
   * Sets element [i][j] (row, column) to value.
   *
   * @throws ArrayIndexOutOfBoundsException if i or j is outside 0..3
   */
  public void set(int i, int j, float value) {
    checkIndex(i, j);
    matrix[i][j] = value;
  }

  /**
   * Returns element [i][j] (row, column).
   *
   * @throws ArrayIndexOutOfBoundsException if i or j is outside 0..3
   */
  public float get(int i, int j) {
    checkIndex(i, j);
    return matrix[i][j];
  }

  /**
   * Transforms points from the in array to the out array using the current
   * matrix. The subset of points transformed begins at the start index and
   * has the specified length:
   *
   *    for (i = 0; i < length; i++)
   *        out[start+i] = this * in[start+i]
   *
   * Each written slot of out receives a newly created Point3D.
   */
  public void transform(Point3D[] in, Point3D[] out, int start, int length) {
    int end = start + length;
    for (int k = start; k < end; k++) {
      out[k] = multiply(in[k]);
    }
  }

  /**
   * Returns this * p as a new point, treating p as (x, y, z, 1) and dividing
   * the result by the homogeneous coordinate w. The argument is not modified.
   * If w evaluates to 0 the division yields non-finite coordinates.
   */
  public Point3D multiply(Point3D p) {
    float[] r0 = matrix[0];
    float[] r1 = matrix[1];
    float[] r2 = matrix[2];
    float[] r3 = matrix[3];

    Point3D result = new Point3D();
    result.x = r0[0]*p.x + r0[1]*p.y + r0[2]*p.z + r0[3];
    result.y = r1[0]*p.x + r1[1]*p.y + r1[2]*p.z + r1[3];
    result.z = r2[0]*p.x + r2[1]*p.y + r2[2]*p.z + r2[3];
    float w = r3[0]*p.x + r3[1]*p.y + r3[2]*p.z + r3[3];

    // Dividing by exactly 1 would change nothing, so skip it.
    if (w != 1.0f) {
      result.x = result.x/w;
      result.y = result.y/w;
      result.z = result.z/w;
    }
    return result;
  }

  /**
   * this = this * src. The product is accumulated in a scratch array and then
   * copied back into the existing storage, so the call is also correct when
   * src is this very matrix (squaring) and anyone holding a reference to the
   * matrix array keeps seeing the live values.
   */
  public final void compose(Matrix3D src) {
    float[][] a = matrix;
    float[][] b = src.matrix;
    float[][] product = new float[4][4];
    for (int i = 0; i < 4; i++) {
      for (int j = 0; j < 4; j++) {
        product[i][j] =
          a[i][0]*b[0][j]+
          a[i][1]*b[1][j]+
          a[i][2]*b[2][j]+
          a[i][3]*b[3][j];
      }
    }
    for (int i = 0; i < 4; i++) {
      System.arraycopy(product[i], 0, a[i], 0, 4);
    }
  }

  /** this = identity. */
  public void loadIdentity() {
    setIdentity(matrix);
  }

  /** this = this * translation by (tx, ty, tz). */
  public void translate(float tx, float ty, float tz) {
    Matrix3D t = new Matrix3D();
    t.matrix[0][3] = tx;
    t.matrix[1][3] = ty;
    t.matrix[2][3] = tz;
    compose(t);
  }

  /** this = this * scale by (sx, sy, sz) along the x, y and z axes. */
  public void scale(float sx, float sy, float sz) {
    Matrix3D s = new Matrix3D();
    s.matrix[0][0] = sx;
    s.matrix[1][1] = sy;
    s.matrix[2][2] = sz;
    compose(s);
  }

  /**
   * this = this * skew, where the upper-left 3x3 of skew is the
   * skew-symmetric matrix
   *
   *      0   -kyz   kxz
   *     kyz    0   -kxy
   *    -kxz   kxy    0
   *
   * and the fourth row and column are those of the identity.
   */
  public void skew(float kxy, float kxz, float kyz) {
    Matrix3D k = new Matrix3D();
    float[][] m = k.matrix;
    m[0][0] = 0;    m[0][1] = -kyz; m[0][2] = kxz;
    m[1][0] = kyz;  m[1][1] = 0;    m[1][2] = -kxy;
    m[2][0] = -kxz; m[2][1] = kxy;  m[2][2] = 0;
    compose(k);
  }

  /**
   * this = this * rotation by angle (radians, as taken by Math.cos/Math.sin)
   * about the axis (ax, ay, az). The axis does not need to be unit length.
   * A zero-length axis defines no rotation: it is reported on System.err and
   * the matrix is left unchanged instead of being filled with NaN.
   */
  public void rotate(float ax, float ay, float az, float angle) {
    float a_mag = (float)(Math.sqrt(ax*ax+ay*ay+az*az));
    if (a_mag == 0.0f) {
      System.err.println("Invalid rotation axis, vector has zero length");
      return;
    }
    ax = ax/a_mag;
    ay = ay/a_mag;
    az = az/a_mag;

    // symmetric matrix of the unit axis (outer product of the axis with itself)
    float[][] sym = {
      { ax*ax, ax*ay, ax*az },
      { ax*ay, ay*ay, ay*az },
      { ax*az, ay*az, az*az }
    };

    // skew symmetric matrix of the unit axis
    float[][] skew = {
      {   0, -az,  ay },
      {  az,   0, -ax },
      { -ay,  ax,   0 }
    };

    // an identity matrix
    float[][] ident = {
      { 1, 0, 0 },
      { 0, 1, 0 },
      { 0, 0, 1 }
    };

    // rotation matrix is
    // symmetric*(1-cos(angle))+skew*sin(angle)+identity*cos(angle)
    double cos = Math.cos(angle);
    double sin = Math.sin(angle);
    Matrix3D rot = new Matrix3D();
    for (int i = 0; i < 3; i++) {
      for (int j = 0; j < 3; j++) {
        rot.matrix[i][j] = (float)(sym[i][j]*(1-cos)+
          skew[i][j]*sin+
          ident[i][j]*cos);
      }
    }

    compose(rot);
  }

  /**
   * this = this * lookat, the viewing transform for an eye at
   * (eyex, eyey, eyez) looking towards (atx, aty, atz) with the given up
   * vector. The rows of the transform are the unit right, up and negated
   * look directions, each paired with the matching eye offset.
   *
   * If eye and at coincide, or up is parallel to the look direction, a
   * normalisation divides by zero and the result contains non-finite values.
   */
  public void lookAt(float eyex, float eyey, float eyez,
                     float atx,  float aty,  float atz,
                     float upx,  float upy,  float upz) {
    // look vector l = at - eye, and its unit form
    float lx = atx - eyex;
    float ly = aty - eyey;
    float lz = atz - eyez;
    float lMag = (float)(Math.sqrt(lx*lx+ly*ly+lz*lz));
    float lUnitX = lx/lMag;
    float lUnitY = ly/lMag;
    float lUnitZ = lz/lMag;

    // right vector r = l x up, and its unit form
    float rx = -lz*upy +  ly*upz;
    float ry =  lz*upx + -lx*upz;
    float rz = -ly*upx +  lx*upy;
    float rMag = (float)(Math.sqrt(rx*rx+ry*ry+rz*rz));
    float rUnitX = rx/rMag;
    float rUnitY = ry/rMag;
    float rUnitZ = rz/rMag;

    // true up vector u = r x l (from the unnormalised r and l), and its unit form
    float ux = -rz*ly +  ry*lz;
    float uy =  rz*lx + -rx*lz;
    float uz = -ry*lx +  rx*ly;
    float uMag = (float)(Math.sqrt(ux*ux+uy*uy+uz*uz));
    float uUnitX = ux/uMag;
    float uUnitY = uy/uMag;
    float uUnitZ = uz/uMag;

    Matrix3D view = new Matrix3D();   // bottom row stays (0, 0, 0, 1)
    float[][] m = view.matrix;

    m[0][0] = rUnitX;
    m[0][1] = rUnitY;
    m[0][2] = rUnitZ;
    m[0][3] = -(rUnitX*eyex + rUnitY*eyey + rUnitZ*eyez);

    m[1][0] = uUnitX;
    m[1][1] = uUnitY;
    m[1][2] = uUnitZ;
    m[1][3] = -(uUnitX*eyex + uUnitY*eyey + uUnitZ*eyez);

    m[2][0] = -lUnitX;
    m[2][1] = -lUnitY;
    m[2][2] = -lUnitZ;
    m[2][3] = lUnitX*eyex + lUnitY*eyey + lUnitZ*eyez;

    compose(view);
  }

  // Assume the following projection transformations
  // transform points into the canonical viewing space

  /**
   * this = this * persp, the perspective projection for the frustum given by
   * left/right, bottom/top and near/far. The last row is (0, 0, 1, 0), so the
   * homogeneous coordinate of a transformed point equals its z.
   *
   * Equal left/right, bottom/top or near/far values divide by zero and
   * produce non-finite entries.
   */
  public void perspective(float left, float right,
                          float bottom, float top,
                          float near, float far) {
    Matrix3D persp = new Matrix3D();  // entries not set below stay 0
    float[][] m = persp.matrix;

    m[0][0] = (2*near)/(right-left);
    m[0][2] = -(right+left)/(right-left);

    m[1][1] = -(2*near)/(bottom-top);
    m[1][2] = (bottom+top)/(bottom-top);

    m[2][2] = (far+near)/(far-near);
    m[2][3] = -(2*far*near)/(far-near);

    m[3][2] = 1;
    m[3][3] = 0;

    compose(persp);
  }

  /**
   * this = this * ortho, the orthographic projection for the box given by
   * left/right, bottom/top and near/far.
   *
   * Equal left/right, bottom/top or near/far values divide by zero and
   * produce non-finite entries.
   */
  public void orthographic(float left, float right,
                           float bottom, float top,
                           float near, float far) {
    Matrix3D ortho = new Matrix3D();  // off-diagonal entries not set below stay 0, [3][3] stays 1
    float[][] m = ortho.matrix;

    m[0][0] = 2/(right-left);
    m[0][3] = -(right+left)/(right-left);

    m[1][1] = -2/(bottom-top);
    m[1][3] = (bottom+top)/(bottom-top);

    m[2][2] = 2/(far-near);
    m[2][3] = -(far+near)/(far-near);

    compose(ortho);
  }

  /**
   * Formats the matrix as four lines of space-separated values, one per row,
   * enclosed in square brackets.
   */
  @Override
  public String toString() {
    StringBuilder text = new StringBuilder("[ ");
    for (int i = 0; i < 4; i++) {
      if (i > 0) {
        text.append("\n  ");
      }
      for (int j = 0; j < 4; j++) {
        if (j > 0) {
          text.append(" ");
        }
        text.append(matrix[i][j]);
      }
    }
    return text.append(" ]").toString();
  }

  // Private helpers

  /** Overwrites the 4x4 array m with the identity matrix. */
  private static void setIdentity(float[][] m) {
    for (int i = 0; i < 4; i++) {
      for (int j = 0; j < 4; j++) {
        m[i][j] = (i == j) ? 1 : 0;
      }
    }
  }

  /** Rejects any (row, column) pair that lies outside the 4x4 matrix. */
  private static void checkIndex(int i, int j) {
    if (i < 0 || i >= 4 || j < 0 || j >= 4) {
      throw new ArrayIndexOutOfBoundsException(
        "Invalid index [" + i + "][" + j + "], matrix is 4x4");
    }
  }

}

