package GfxLib;

/**
 * A 4x4 single-precision transformation matrix.
 *
 * <p>Storage is row-major: {@code m_elems[row][col]}. Points are treated as
 * column vectors, so {@link #transform} computes {@code M * (x, y, z, 1)} and
 * then divides x, y and z by the resulting w.
 *
 * <p>The transform methods ({@link #compose}, {@link #sparseCompose},
 * {@link #translate}, {@link #rotate}, {@link #scale}, {@link #shear},
 * {@link #lookAt}, {@link #orthographic}, {@link #perspective}) all
 * post-multiply ({@code this = this * T}). As a result, the transform applied
 * last is the first one to act on a point.
 *
 * <p>The index accessors {@link #get} and {@link #set} take
 * {@code (column, row)}, i.e. {@code get(i, j) == m_elems[j][i]}.
 */
public class Matrix3D
{
  //
  // Shared data
  //
  private static final float[][] IDENTITY_ELEMS = {
    {1, 0, 0, 0},
    {0, 1, 0, 0},
    {0, 0, 1, 0},
    {0, 0, 0, 1}
  };

  // Scratch matrix reused by orthographic(), perspective() and rotate() to
  // avoid a per-call allocation. It is shared by every instance, so those
  // methods must not run concurrently on different threads.
  // (Declared after IDENTITY_ELEMS: the constructor needs it initialised.)
  private static final Matrix3D temp1 = new Matrix3D();

  //
  // Public member variables
  //

  /** Caller-defined identifier; not read or written anywhere in this class. */
  public int ID;

  //
  // Private member variables
  //

  // Row-major 4x4 storage: m_elems[row][col]. compose(), sparseCompose() and
  // inverse() replace this array with a new one, so arrays obtained earlier
  // from getElements() stop aliasing the matrix after those calls.
  private float[][] m_elems;


  //
  // Class Constructors
  //

  /** Creates an identity matrix. */
  public Matrix3D()
  {
    m_elems = new float[4][4];
    loadIdentity();
  }

  /**
   * Creates an independent (deep) copy of another matrix.
   *
   * @param copy matrix whose elements are copied
   */
  public Matrix3D(Matrix3D copy)
  {
    m_elems = new float[4][4];
    this.copy(copy);
  }

  /**
   * Overwrites this matrix's elements with those of {@code copy}.
   * The rows are copied by value; no storage is shared afterwards.
   *
   * @param copy matrix whose elements are copied
   */
  public void copy(Matrix3D copy)
  {
    float[][] src = copy.getElements();
    for (int row = 0; row < 4; row++)
      {
        System.arraycopy(src[row], 0, m_elems[row], 0, 4);
      }
  }

  /**
   * Creates a viewport-style matrix for a raster.
   *
   * <p>x and y are scaled and offset by half the raster width/height, and
   * z is flattened to 0.
   *
   * @param r target raster
   */
  public Matrix3D(Raster r)
  {
    m_elems = new float[4][4];

    loadIdentity();
    // NOTE(review): if getWidth()/getHeight() return int, "/2" is integer
    // division and odd sizes lose half a pixel -- confirm this is intended.
    m_elems[0][0] = r.getWidth()/2;
    m_elems[1][1] = r.getHeight()/2;
    m_elems[0][3] = r.getWidth()/2;
    m_elems[1][3] = r.getHeight()/2;
    m_elems[2][3] = 0f;
    m_elems[2][2] = 0f;
  }

  /**
   * Creates a viewport-style matrix for a raster with a depth mapping
   * derived from {@code near} and {@code far}.
   *
   * @param r    target raster
   * @param near near depth value
   * @param far  far depth value
   */
  public Matrix3D(Raster r, float near, float far)
  {
    m_elems = new float[4][4];

    loadIdentity();
    m_elems[0][0] = r.getWidth()/2;
    m_elems[1][1] = r.getHeight()/2;
    m_elems[0][3] = r.getWidth()/2;
    m_elems[1][3] = r.getHeight()/2;
    // NOTE(review): for x and y, [k][k] is the scale and [k][3] the offset.
    // Here the z scale (256/(far-near)) sits in the offset slot [2][3] and
    // near*256/(far-near) in the scale slot [2][2]. This looks swapped --
    // verify against the rasteriser's depth-buffer expectations.
    m_elems[2][3] = 256/(far-near);
    m_elems[2][2] = near*256/(far-near);
  }

  //
  // Data access methods
  //

  /**
   * @param i column index (0-3)
   * @param j row index (0-3)
   * @return element at column {@code i}, row {@code j}
   */
  public final float get(int i, int j)
  {
    return m_elems[j][i];
  }

  /**
   * @return the live row-major backing array ({@code [row][col]}); writes
   *         through it modify this matrix until the array is next replaced
   */
  public final float[][] getElements()
  {
    return m_elems;
  }

  /**
   * @param i     column index (0-3)
   * @param j     row index (0-3)
   * @param value new value for column {@code i}, row {@code j}
   */
  public final void set(int i, int j, float value)
  {
    m_elems[j][i] = value;
  }

  //
  // Transforms
  //

  /**
   * Post-multiplies this matrix by {@code src}: {@code this = this * src}.
   * Passing {@code this} as {@code src} is safe (it squares the matrix).
   *
   * @param src right-hand operand
   */
  public final void compose(Matrix3D src)
  {
    float[][] newElems  = new float[4][4];
    float[][] srcElems  = src.getElements();
    float[][] tempElems = m_elems;

    for (int j = 0; j < 4; j++)
      {
        for (int i = 0; i < 4; i++)
          {
            newElems[j][i] = tempElems[j][0] * srcElems[0][i] +
              tempElems[j][1] * srcElems[1][i] +
              tempElems[j][2] * srcElems[2][i] +
              tempElems[j][3] * srcElems[3][i];
          }
      }

    m_elems = newElems;
  }

  /**
   * Post-multiplies by a viewing transform.
   *
   * <p>The view matrix has the right vector (line-of-sight x up) as row 0,
   * the recomputed up vector as row 1 and the negated line of sight as
   * row 2. Each row is translated by the eye position.
   *
   * <p>If {@code eye == at}, or {@code up} is parallel to the line of sight,
   * a normalisation divides by zero and the result contains NaN.
   */
  public void lookAt(float eyex, float eyey, float eyez,
                     float atx, float aty, float atz,
                     float upx, float upy, float upz)
  {
    // Line of sight l = at - eye.
    float lx = atx - eyex;
    float ly = aty - eyey;
    float lz = atz - eyez;
    float lNormal = (float)Math.sqrt(lx*lx + ly*ly + lz*lz);
    float lxHat = lx / lNormal;
    float lyHat = ly / lNormal;
    float lzHat = lz / lNormal;

    // Right vector r = l x up.
    float rx = (ly * upz) - (lz * upy);
    float ry = (lz * upx) - (lx * upz);
    float rz = (lx * upy) - (ly * upx);
    float rNormal = (float)Math.sqrt(rx*rx + ry*ry + rz*rz);
    float rxHat = rx / rNormal;
    float ryHat = ry / rNormal;
    float rzHat = rz / rNormal;

    // Orthogonalised up vector u = r x l.
    float ux = (ry * lz) - (rz * ly);
    float uy = (rz * lx) - (rx * lz);
    float uz = (rx * ly) - (ry * lx);
    float uNormal = (float)Math.sqrt(ux*ux + uy*uy + uz*uz);
    float uxHat = ux / uNormal;
    float uyHat = uy / uNormal;
    float uzHat = uz / uNormal;

    // The constructor already yields identity, so row 3 is (0, 0, 0, 1).
    Matrix3D viewMatrix = new Matrix3D();
    float[][] v = viewMatrix.getElements();

    v[0][0] = rxHat;
    v[0][1] = ryHat;
    v[0][2] = rzHat;
    v[0][3] = -((rxHat * eyex) + (ryHat * eyey) + (rzHat * eyez));

    v[1][0] = uxHat;
    v[1][1] = uyHat;
    v[1][2] = uzHat;
    v[1][3] = -((uxHat * eyex) + (uyHat * eyey) + (uzHat * eyez));

    v[2][0] = -lxHat;
    v[2][1] = -lyHat;
    v[2][2] = -lzHat;
    v[2][3] = ((lxHat * eyex) + (lyHat * eyey) + (lzHat * eyez));

    this.compose(viewMatrix);
  }

  /**
   * Post-multiplies by an orthographic projection mapping the box
   * [left,right] x [bottom,top] x [near,far] onto [-1,1] on each axis.
   * The y scale is 2/(bottom-top), so y is mirrored relative to a
   * bottom-to-top mapping.
   *
   * <p>Uses the shared scratch matrix; not thread-safe.
   */
  public void orthographic(float left, float right,
                           float bottom, float top,
                           float near, float far)
  {
    Matrix3D pMatrix = temp1;
    pMatrix.loadIdentity();
    float[][] p = pMatrix.getElements();

    p[0][0] =   2f/(right-left);
    p[0][3] =  -(right + left)/(right-left);
    p[1][1] =   2f/(bottom-top);
    p[1][3] =  -(bottom+top)/(bottom-top);
    p[2][2] =   2f/(far-near);
    p[2][3] =  -(far+near)/(far-near);

    this.compose(pMatrix);
  }

  /**
   * Post-multiplies by a perspective projection.
   *
   * <p>The resulting w equals the input z ([3][2] = 1, [3][3] = 0). After
   * the divide in {@link #transform}, z = near maps to -1 and z = far maps
   * to +1. As in {@link #orthographic}, the y terms use (bottom-top).
   *
   * <p>Uses the shared scratch matrix; not thread-safe.
   */
  public void perspective(float left, float right,
                          float bottom, float top,
                          float near, float far)
  {
    Matrix3D pMatrix = temp1;
    pMatrix.loadIdentity();
    float[][] p = pMatrix.getElements();

    p[0][0] =   (2f * near)/(right - left);
    p[0][2] =  -(right + left)/(right-left);
    p[1][1] =   (2f*near)/(bottom-top);
    p[1][2] =  -(bottom+top)/(bottom-top);
    p[2][2] =   (far+near)/(far-near);
    p[2][3] =  -(2f*far*near)/(far-near);
    p[3][2] =   1f;
    p[3][3] =   0f;

    this.compose(pMatrix);
  }

  /**
   * Post-multiplies by a rotation of {@code angle} radians about the axis
   * (ax, ay, az). The axis need not be unit length; it is normalised here.
   * A zero-length axis divides by zero and yields NaN.
   *
   * <p>Uses the shared scratch matrix; not thread-safe.
   */
  public void rotate(float ax, float ay, float az, float angle)
  {
    Matrix3D rMatrix = temp1;
    rMatrix.loadIdentity();
    float[][] r = rMatrix.getElements();

    float cosTheta = (float)Math.cos(angle);
    float sinTheta = (float)Math.sin(angle);

    // Normalize the axis of rotation.
    float aNormal = (float)Math.sqrt(ax*ax + ay*ay + az*az);
    float axHat   = ax/aNormal;
    float ayHat   = ay/aNormal;
    float azHat   = az/aNormal;

    // Shared (1 - cos) * axis factors of the axis-angle rotation matrix.
    float symX = (1f - cosTheta)*axHat;
    float symY = (1f - cosTheta)*ayHat;

    r[0][0] = symX*axHat + cosTheta;
    r[0][1] = symX*ayHat + (-azHat*sinTheta);
    r[0][2] = symX*azHat + ( ayHat*sinTheta);
    r[1][0] = symX*ayHat + ( azHat*sinTheta);
    r[1][1] = symY*ayHat + cosTheta;
    r[1][2] = symY*azHat + (-axHat*sinTheta);
    r[2][0] = symX*azHat + (-ayHat*sinTheta);
    r[2][1] = symY*azHat + ( axHat*sinTheta);
    r[2][2] = (1f-cosTheta)*azHat*azHat + cosTheta;

    // Pure 3x3 rotation: the cheaper sparse product is exact here.
    this.sparseCompose(rMatrix);
  }

  /**
   * Post-multiplies by a scale matrix diag(sx, sy, sz, 1).
   *
   * <p>This scales columns 0, 1 and 2 of every row, consistent with
   * {@link #translate} and {@link #compose}. The translation column is left
   * untouched, so scaling happens in the local frame established by earlier
   * calls.
   */
  public void scale(float sx, float sy, float sz)
  {
    for (int row = 0; row < 4; row++)
      {
        float[] r = m_elems[row];
        r[0] *= sx;
        r[1] *= sy;
        r[2] *= sz;
      }
  }

  /**
   * Post-multiplies by the shear matrix with x' = x + kxy*y + kxz*z and
   * y' = y + kyz*z.
   */
  public void shear(float kxy, float kxz, float kyz)
  {
    float[][] e = m_elems;

    // Column 2 must be updated first: it reads the original column 1.
    for (int row = 0; row < 4; row++)
      {
        e[row][2] += kxz*e[row][0] + kyz*e[row][1];
      }
    for (int row = 0; row < 4; row++)
      {
        e[row][1] += kxy*e[row][0];
      }
  }

  /**
   * Post-multiplies by only the upper-left 3x3 block of {@code src}.
   *
   * <p>This assumes src's row 3 and column 3 are (0, 0, 0, 1): src's
   * translation and projective terms are ignored, and this matrix's
   * column 3 is kept as is.
   *
   * @param src right-hand operand (3x3 linear part used)
   */
  public final void sparseCompose(Matrix3D src)
  {
    float[][] newElems  = new float[4][4];
    float[][] srcElems  = src.getElements();
    float[][] tempElems = m_elems;

    for (int j = 0; j < 4; j++)
      {
        for (int i = 0; i < 3; i++)
          {
            newElems[j][i] = tempElems[j][0] * srcElems[0][i] +
              tempElems[j][1] * srcElems[1][i] +
              tempElems[j][2] * srcElems[2][i];
          }
        newElems[j][3] = tempElems[j][3];
      }
    m_elems = newElems;
  }

  /**
   * Transforms the points {@code in[start .. start+length-1]} into the same
   * indices of {@code out}.
   *
   * <p>Each output is {@code M * (x, y, z, 1)} divided by the resulting w
   * (when w != 1). The {@code argb} and {@code m_intensity} fields are copied
   * through unchanged. Each point's inputs are read before its output is
   * written, so {@code in} and {@code out} may be the same array.
   */
  public void transform(Point3D in[], Point3D out[], int start, int length)
  {
    float[][] e = m_elems;
    int end = start + length;

    for (int k = start; k < end; k++)
      {
        float x = in[k].x;
        float y = in[k].y;
        float z = in[k].z;

        float w = (x * e[3][0] +
                   y * e[3][1] +
                   z * e[3][2] +
                   e[3][3]);
        out[k].x = (x * e[0][0] +
                    y * e[0][1] +
                    z * e[0][2] +
                    e[0][3]);
        out[k].y = (x * e[1][0] +
                    y * e[1][1] +
                    z * e[1][2] +
                    e[1][3]);
        out[k].z = (x * e[2][0] +
                    y * e[2][1] +
                    z * e[2][2] +
                    e[2][3]);
        if (w != 1)
          {
            out[k].x /= w;
            out[k].y /= w;
            out[k].z /= w;
          }
        out[k].argb        = in[k].argb;
        out[k].m_intensity = in[k].m_intensity;
      }
  }

  /** Post-multiplies by a translation of (tx, ty, tz). */
  public void translate(float tx, float ty, float tz)
  {
    for (int j = 0; j < 4; j++)
      {
        float[] row = m_elems[j];
        row[3] = row[0]*tx + row[1]*ty +
                 row[2]*tz + row[3];
      }
  }

  /**
   * @return "[" followed by each row as " a b c d" plus a newline and a
   *         space, then "]"
   */
  @Override
  public String toString()
  {
    StringBuilder out = new StringBuilder("[");

    for (int j = 0; j < 4; j++)
      {
        for (int i = 0; i < 4; i++)
          {
            out.append(' ').append(m_elems[j][i]);
          }
        out.append("\n ");
      }
    out.append(']');

    return out.toString();
  }

  //
  // Utility methods
  //

  /**
   * Adds {@code m2} into {@code m1} element-wise over the leading
   * {@code rows} x {@code cols} sub-block (m1 is modified in place).
   */
  public static final void add(Matrix3D m1, Matrix3D m2, int rows, int cols)
  {
    float[][] m1Elements = m1.getElements();
    float[][] m2Elements = m2.getElements();

    for (int j = 0; j < rows; j++)
      {
        for (int i = 0; i < cols; i++)
          {
            m1Elements[j][i] += m2Elements[j][i];
          }
      }
  }

  /** Resets this matrix to identity (in place; the backing array is kept). */
  public void loadIdentity()
  {
    for (int row = 0; row < 4; row++)
      {
        System.arraycopy(IDENTITY_ELEMS[row], 0, m_elems[row], 0, 4);
      }
  }

  /**
   * Multiplies the leading {@code rows} x {@code cols} sub-block by
   * {@code scalar}, in place.
   */
  public final void mulByScalar(float scalar, int rows, int cols)
  {
    for (int j = 0; j < rows; j++)
      {
        for (int i = 0; i < cols; i++)
          {
            m_elems[j][i] *= scalar;
          }
      }
  }

  /**
   * Returns the normal matrix of {@code m}: the inverse-transpose of its
   * upper-left 3x3 block, embedded in a 4x4 identity.
   *
   * <p>This is the closed form: element [row][col] is the (row, col)
   * cofactor divided by the determinant. A singular block divides by zero
   * and yields non-finite values.
   */
  public static Matrix3D makeNormalMatrix(Matrix3D m)
  {
    float[][] a = m.getElements();

    // Determinant of the 3x3 block (computed once and reused).
    float det = (a[0][0]*a[1][1]*a[2][2]) -
                (a[0][0]*a[1][2]*a[2][1]) -
                (a[1][0]*a[0][1]*a[2][2]) +
                (a[1][0]*a[0][2]*a[2][1]) +
                (a[2][0]*a[0][1]*a[1][2]) -
                (a[2][0]*a[0][2]*a[1][1]);

    // Row 3 and column 3 keep the identity values from the constructor.
    Matrix3D ret = new Matrix3D();
    float[][] r = ret.getElements();

    r[0][0] =  ((a[1][1]*a[2][2]) - (a[1][2]*a[2][1])) / det;
    r[0][1] = -((a[1][0]*a[2][2]) - (a[1][2]*a[2][0])) / det;
    r[0][2] =  ((a[1][0]*a[2][1]) - (a[1][1]*a[2][0])) / det;

    r[1][0] = -((a[0][1]*a[2][2]) - (a[0][2]*a[2][1])) / det;
    r[1][1] =  ((a[0][0]*a[2][2]) - (a[0][2]*a[2][0])) / det;
    r[1][2] = -((a[0][0]*a[2][1]) - (a[0][1]*a[2][0])) / det;

    r[2][0] =  ((a[0][1]*a[1][2]) - (a[0][2]*a[1][1])) / det;
    r[2][1] = -((a[0][0]*a[1][2]) - (a[0][2]*a[1][0])) / det;
    r[2][2] =  ((a[0][0]*a[1][1]) - (a[0][1]*a[1][0])) / det;

    return ret;
  }

  /**
   * Same result as {@link #makeNormalMatrix}, computed numerically: copy
   * the upper-left 3x3 block of {@code m} into an identity, then
   * {@link #inverse} and {@link #transpose}.
   */
  public static Matrix3D makeNormalMatrix2(Matrix3D m)
  {
    Matrix3D ret = new Matrix3D();
    float[][] src = m.getElements();
    float[][] dst = ret.getElements();

    // Copy m's 3x3 block into ret; row 3 and column 3 stay identity.
    for (int row = 0; row < 3; row++)
      {
        System.arraycopy(src[row], 0, dst[row], 0, 3);
      }

    ret.inverse();
    ret.transpose();

    return ret;
  }

  /** Transposes this matrix in place. */
  public void transpose()
  {
    for (int j = 0; j < 4; j++)
      {
        for (int i = j + 1; i < 4; i++)
          {
            float tmp     = m_elems[i][j];
            m_elems[i][j] = m_elems[j][i];
            m_elems[j][i] = tmp;
          }
      }
  }

  /**
   * Replaces this matrix with its inverse, using Gauss-Jordan elimination
   * with partial pivoting. Pivoting keeps it working when a diagonal entry
   * is zero, e.g. axis swaps and 90-degree rotations.
   *
   * <p>No exception is thrown for a singular matrix; the result then
   * contains infinities/NaN. The old backing array is consumed: it is left
   * reduced and no longer belongs to this matrix.
   */
  public void inverse()
  {
    float[][] a   = m_elems;                      // reduced to identity
    float[][] inv = new Matrix3D().getElements(); // becomes the inverse

    for (int col = 0; col < 4; col++)
      {
        // Choose the row with the largest magnitude in this column.
        int pivot = col;
        float best = Math.abs(a[col][col]);
        for (int r = col + 1; r < 4; r++)
          {
            float v = Math.abs(a[r][col]);
            if (v > best)
              {
                best  = v;
                pivot = r;
              }
          }
        if (pivot != col)
          {
            float[] t = a[col];   a[col]   = a[pivot];   a[pivot]   = t;
            t         = inv[col]; inv[col] = inv[pivot]; inv[pivot] = t;
          }

        // Scale the pivot row so the pivot becomes 1.
        float scale = 1f / a[col][col];
        for (int k = 0; k < 4; k++)
          {
            a[col][k]   *= scale;
            inv[col][k] *= scale;
          }

        // Clear this column in every other row.
        for (int r = 0; r < 4; r++)
          {
            if (r == col)
              {
                continue;
              }
            float factor = a[r][col];
            if (factor != 0)
              {
                for (int k = 0; k < 4; k++)
                  {
                    a[r][k]   -= a[col][k]   * factor;
                    inv[r][k] -= inv[col][k] * factor;
                  }
              }
          }
      }

    m_elems = inv;
  }
}
  




