import Raster;
import Point3D;

/**
 * A 4x4 homogeneous transformation matrix for 3D graphics.
 *
 * Elements are stored as {@code m[row][column]}. Points are treated as
 * column vectors, so every "this = this * X" method below post-multiplies:
 * the transformation added most recently is the first one applied to a point.
 */
public class Matrix3D
{
  /** Matrix elements indexed {@code m[row][column]}; every constructor allocates it as 4x4. */
  public float[][] m;

  //
  // Constructors
  //

  /** Creates a matrix initialized to the identity transform. */
  public Matrix3D()
  {
    m = new float[4][4];
    this.loadIdentity();
  }

  /**
   * Creates a matrix holding an element-by-element copy of another matrix.
   *
   * @param copy the matrix whose 16 elements are copied
   */
  public Matrix3D(Matrix3D copy)
  {
    m = new float[4][4];
    for (int row = 0; row < 4; row++)
      for (int col = 0; col < 4; col++)
        m[row][col] = copy.m[row][col];
  }

  /**
   * Creates a mapping from canonical space to the screen space of a raster:
   * identity, translated by half the raster size and then scaled by half the
   * raster size in x and y. The z scale factor is 0, so z is flattened.
   *
   * @param r the raster whose width and height define the mapping
   */
  public Matrix3D(Raster r)
  {
    m = new float[4][4];
    this.loadIdentity();
    this.translate(r.getWidth() / 2, r.getHeight() / 2, 0);
    this.scale(r.getWidth() / 2, r.getHeight() / 2, 0);
  }

  /**
   * Formats the matrix as four bracketed rows, one per line, with elements
   * separated by single spaces and the whole matrix wrapped in brackets.
   *
   * @return the textual form, terminated by a newline
   */
  public String toString()
  {
    StringBuilder text = new StringBuilder("[");
    for (int row = 0; row < 4; row++)
      {
        text.append("[");
        for (int col = 0; col < 4; col++)
          {
            text.append(Float.toString(m[row][col]));
            if (col != 3)
              text.append(" ");
          }
        // Rows after the first are indented one space to line up under the outer bracket.
        text.append((row == 3) ? "]" : "]\n ");
      }
    return text.append("]\n").toString();
  }

  //
  // General interface methods
  //

  /**
   * Sets one element. Note the argument order: column first, then row.
   *
   * @param i     column index (0..3)
   * @param j     row index (0..3)
   * @param value new value for element {@code m[j][i]}
   */
  public void set(int i, int j, float value)
  {
    m[j][i] = value;
  }

  /**
   * Returns one element. Note the argument order: column first, then row.
   *
   * @param i column index (0..3)
   * @param j row index (0..3)
   * @return element {@code m[j][i]}
   */
  public float get(int i, int j)
  {
    return m[j][i];
  }

  /**
   * Transforms points from the in array to the out array using the current
   * matrix, including the homogeneous divide:
   *
   *    for (i = 0; i < length; i++)
   *        out[start+i] = this * in[start+i]
   *
   * The points in {@code out} must already exist; only their x, y and z
   * fields are overwritten. {@code in} and {@code out} may be the same array
   * (or share point objects): the input coordinates are read before any
   * output field is written.
   *
   * @param in     source points
   * @param out    destination points
   * @param start  index of the first point to transform
   * @param length number of points to transform
   */
  public void transform(Point3D in[], Point3D out[], int start, int length)
  {
    for (int k = start; k < length + start; k++)
      {
        // Snapshot the input first so an aliased out[k] cannot corrupt it.
        float x = in[k].x;
        float y = in[k].y;
        float z = in[k].z;

        // Homogeneous w of the transformed point; the divide below applies perspective.
        float w = x * m[3][0] + y * m[3][1] + z * m[3][2] + m[3][3];

        out[k].x = (x * m[0][0] + y * m[0][1] + z * m[0][2] + m[0][3]) / w;
        out[k].y = (x * m[1][0] + y * m[1][1] + z * m[1][2] + m[1][3]) / w;
        out[k].z = (x * m[2][0] + y * m[2][1] + z * m[2][2] + m[2][3]) / w;
      }
  }

  /**
   * this = this * src
   *
   * The product is built in a fresh array that then replaces {@code m}, so
   * composing a matrix with itself is safe.
   *
   * @param src the right-hand operand
   */
  public final void compose(Matrix3D src)
  {
    float[][] product = new float[4][4];
    for (int row = 0; row < 4; row++)
      for (int col = 0; col < 4; col++)
        {
          float sum = 0;
          for (int k = 0; k < 4; k++)
            sum += m[row][k] * src.m[k][col];
          product[row][col] = sum;
        }
    m = product;
  }

  /** this = identity */
  public void loadIdentity()
  {
    for (int row = 0; row < 4; row++)
      for (int col = 0; col < 4; col++)
        m[row][col] = (row == col) ? 1 : 0;
  }

  /**
   * this = this * translate(tx, ty, tz)
   *
   * Only the last column changes: each row gains its own dot product with
   * the translation vector.
   */
  public void translate(float tx, float ty, float tz)
  {
    for (int row = 0; row < 4; row++)
      m[row][3] += tx * m[row][0] + ty * m[row][1] + tz * m[row][2];
  }

  /**
   * this = this * scale(sx, sy, sz)
   *
   * Scales the first three columns of every row, including the bottom row,
   * so the result is also correct when the matrix holds a projection.
   */
  public void scale(float sx, float sy, float sz)
  {
    for (int row = 0; row < 4; row++)
      {
        m[row][0] *= sx;
        m[row][1] *= sy;
        m[row][2] *= sz;
      }
  }

  /**
   * this = this * skew, where skew is the identity with kxy at [0][1],
   * kxz at [0][2] and kyz at [1][2].
   */
  public void skew(float kxy, float kxz, float kyz)
  {
    for (int row = 0; row < 4; row++)
      {
        // Update column 2 first: it needs the not-yet-modified column 1.
        m[row][2] += m[row][0] * kxz + m[row][1] * kyz;
        m[row][1] += m[row][0] * kxy;
      }
  }

  /**
   * this = this * rotate, a rotation by {@code angle} about the axis
   * (ax, ay, az). The axis is normalized here and must not be the zero
   * vector; a zero-length axis produces NaN entries.
   *
   * @param ax    x component of the rotation axis
   * @param ay    y component of the rotation axis
   * @param az    z component of the rotation axis
   * @param angle rotation angle in radians
   */
  public void rotate(float ax, float ay, float az, float angle)
  {
    double axisLen = java.lang.Math.sqrt(ax * ax + ay * ay + az * az);
    ax /= axisLen;
    ay /= axisLen;
    az /= axisLen;

    float c = (float) java.lang.Math.cos(angle);
    float s = (float) java.lang.Math.sin(angle);
    float t = 1 - c;

    // Axis-angle rotation: t * (a a^T) + c * I + s * [a]x, indexed [row][column].
    float[][] rot = {
      { t * ax * ax + c,      t * ax * ay - s * az, t * ax * az + s * ay },
      { t * ay * ax + s * az, t * ay * ay + c,      t * ay * az - s * ax },
      { t * az * ax - s * ay, t * az * ay + s * ax, t * az * az + c      }
    };

    // Post-multiply by the rotation embedded in a 4x4 identity: the first
    // three columns of every row are recombined, the last column is untouched.
    float[][] product = new float[4][4];
    for (int row = 0; row < 4; row++)
      {
        for (int col = 0; col < 3; col++)
          {
            float sum = 0;
            for (int k = 0; k < 3; k++)
              sum += m[row][k] * rot[k][col];
            product[row][col] = sum;
          }
        product[row][3] = m[row][3];
      }
    m = product;
  }

  /**
   * this = this * lookat
   *
   * eye is the location of the eye, at is the point of focus in world space,
   * and up is a vector in world space which should map to (0, 1, 0) in eye
   * space. The result has +y up and -z forward;
   * lookAt(0,0,0, 0,0,-1, 0,1,0) is the identity.
   *
   * eye must differ from at, and up must not be parallel to the viewing
   * direction; otherwise a normalization divides by zero and the matrix
   * fills with NaN.
   */
  public void lookAt(float eyex, float eyey, float eyez,
                     float atx,  float aty,  float atz,
                     float upx,  float upy,  float upz)
  {
    // l: unit viewing direction, from the eye toward the focus point.
    float lx = atx - eyex;
    float ly = aty - eyey;
    float lz = atz - eyez;
    float mag = (float) java.lang.Math.sqrt(lx * lx + ly * ly + lz * lz);
    lx /= mag;
    ly /= mag;
    lz /= mag;

    // r = l x up: unit vector pointing to the viewer's right.
    float rx = ly * upz - lz * upy;
    float ry = lz * upx - lx * upz;
    float rz = lx * upy - ly * upx;
    mag = (float) java.lang.Math.sqrt(rx * rx + ry * ry + rz * rz);
    rx /= mag;
    ry /= mag;
    rz /= mag;

    // u = r x l: the true up vector, perpendicular to both r and l.
    float ux = ry * lz - rz * ly;
    float uy = rz * lx - rx * lz;
    float uz = rx * ly - ry * lx;
    mag = (float) java.lang.Math.sqrt(ux * ux + uy * uy + uz * uz);
    ux /= mag;
    uy /= mag;
    uz /= mag;

    // Rows are the eye-space axes (r, u, -l); the last column moves the eye
    // to the origin, expressed in those axes.
    Matrix3D lookat = new Matrix3D();

    lookat.m[0][0] = rx;
    lookat.m[0][1] = ry;
    lookat.m[0][2] = rz;
    lookat.m[0][3] = - rx * eyex - ry * eyey - rz * eyez;

    lookat.m[1][0] = ux;
    lookat.m[1][1] = uy;
    lookat.m[1][2] = uz;
    lookat.m[1][3] = - ux * eyex - uy * eyey - uz * eyez;

    lookat.m[2][0] = -lx;
    lookat.m[2][1] = -ly;
    lookat.m[2][2] = -lz;
    lookat.m[2][3] = lx * eyex + ly * eyey + lz * eyez;

    this.compose(lookat);
  }

  //
  // Assume the following projection transformations
  // transform points into the canonical viewing space
  //

  /**
   * this = this * persp
   *
   * The projection sets w = z. After the homogeneous divide, the window
   * (left, top) at z = near maps to (-1, -1), (right, bottom) maps to (1, 1),
   * z = near maps to -1 and z = far maps to 1.
   *
   * right != left, bottom != top and far != near are required; equal values
   * divide by zero.
   */
  public void perspective(float left, float right,
                          float bottom, float top,
                          float near, float far)
  {
    Matrix3D persp = new Matrix3D();
    persp.m[0][0] = (2 * near)/(right-left);
    persp.m[1][1] = (2 * near)/(bottom-top);
    persp.m[0][2] = -(right+left)/(right-left);
    persp.m[1][2] = -(bottom+top)/(bottom-top);
    persp.m[2][2] = (far+near)/(far-near);
    persp.m[3][2] = 1;
    persp.m[2][3] = (-2 * far * near)/(far-near);
    persp.m[3][3] = 0;
    this.compose(persp);
  }

  /**
   * this = this * ortho
   *
   * Takes left, top, near to -1, -1, -1 and
   *       right, bottom, far to 1, 1, 1
   *
   * right != left, bottom != top and far != near are required; equal values
   * divide by zero.
   */
  public void orthographic(float left, float right,
                           float bottom, float top,
                           float near, float far)
  {
    // Translate before scaling: the offsets are already in normalized units,
    // so they must be added after the point has been scaled, which with
    // post-multiplication means the translation is composed first.
    this.translate(-(right+left)/(right-left),
                   -(top+bottom)/(bottom-top),
                   -(far+near)/(far-near));
    this.scale(2/(right-left), 2/(bottom-top), 2/(far-near));
  }
}
