// FILE:     Matrix3D.java
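// PURPOSE:  A 4x4 matrix class for 3D homogeneous transformations, providing
//           the primitive matrix operations (translate, scale, skew, rotate,
//           viewing and projection).
// METHOD:   The 4x4 matrix is stored row-major in a float[4][4] array; each
//           operation post-multiplies the current matrix (this = this * M).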
//
// MODS:     10.30.98 Jeremy Lueck (jlueck@mit.edu) -- original
//

/**
 * A class for representing a 3-dimensional matrix which contains methods
 * for doing primitive 3D matrix operations.
 */
public class Matrix3D {

    /** Row-major matrix storage, {@code data[row][column]}; always 4x4. */
    private float[][] data;

    //              //
    // Constructors //
    //              //

    /**
     * Creates a new matrix initialized to the 4x4 identity matrix.
     */
    public Matrix3D() {
        data = identityArray();
    }

    /**
     * Copy constructor: creates a new matrix holding the same 16 values as
     * {@code copy}. The two matrices do not share storage afterwards.
     *
     * @param copy the Matrix3D whose values are copied into this object
     */
    public Matrix3D(Matrix3D copy) {
        data = new float[4][4];
        for (int r = 0; r < 4; r++) {
            System.arraycopy(copy.data[r], 0, data[r], 0, 4);
        }
    }

    /**
     * Creates a viewport matrix that maps x from [-1, 1] onto
     * [0, r.width] and y from [-1, 1] onto [0, r.height], passing z
     * through unchanged.
     *
     * @param r the raster whose {@code width} and {@code height} define the
     *          target viewport
     */
    public Matrix3D(Raster r) {
        data = new float[4][4];
        float halfWidth = r.width / 2.0f;
        float halfHeight = r.height / 2.0f;
        data[0][0] = halfWidth;
        data[0][3] = halfWidth;
        data[1][1] = halfHeight;
        data[1][3] = halfHeight;
        data[2][2] = 1.0f;
        data[3][3] = 1.0f;
    }

    //                           //
    // General interface methods //
    //                           //

    /**
     * Sets the (r,c)th element in the matrix.
     *
     * @param r the row index, 0..3
     * @param c the column index, 0..3
     * @param value the new value for the (r,c)th element
     * @throws ArrayIndexOutOfBoundsException if r or c is outside 0..3
     */
    public void set(int r, int c, float value) {
        data[r][c] = value;
    }

    /**
     * Gets the value of the (r,c)th element in the matrix.
     *
     * @param r the row index, 0..3
     * @param c the column index, 0..3
     * @return the value of the (r,c)th element in the matrix
     * @throws ArrayIndexOutOfBoundsException if r or c is outside 0..3
     */
    public float get(int r, int c) {
        return data[r][c];
    }

    //
    // Primitive matrix operations
    //

    /**
     * Applies this matrix to {@code length} points starting at index
     * {@code start}, storing a newly allocated Point3D for each result at
     * the same index in {@code out}.
     *
     * <p>Only rows 0-2 of the matrix are used (each point is treated as
     * (x, y, z, 1)); row 3 is ignored, so no division by w is performed.
     * {@code in} and {@code out} may be the same array, because each input
     * is read before its slot is overwritten.
     *
     * @param in source points
     * @param out destination array; elements [start, start+length) are replaced
     * @param start index of the first point to transform
     * @param length number of points to transform
     */
    public void transform(Point3D[] in, Point3D[] out, int start, int length) {
        for (int i = 0; i < length; i++) {
            int k = i + start;
            Point3D p = in[k];
            float x = data[0][0] * p.x + data[0][1] * p.y
                    + data[0][2] * p.z + data[0][3];
            float y = data[1][0] * p.x + data[1][1] * p.y
                    + data[1][2] * p.z + data[1][3];
            float z = data[2][0] * p.x + data[2][1] * p.y
                    + data[2][2] * p.z + data[2][3];
            out[k] = new Point3D(x, y, z);
        }
    }

    /**
     * Post-multiplies {@code src} onto this matrix: {@code this = this * src}.
     *
     * @param src the matrix to compose; it is not modified
     */
    public void compose(Matrix3D src) {
        data = matrixMultiply(data, src.data);
    }

    /**
     * Loads the identity matrix into this Matrix3D object (in place).
     */
    public void loadIdentity() {
        for (int r = 0; r < 4; r++) {
            for (int c = 0; c < 4; c++) {
                data[r][c] = (r == c) ? 1.0f : 0.0f;
            }
        }
    }

    /**
     * Post-multiplies a translation by (tx, ty, tz): {@code this = this * T}.
     */
    public void translate(float tx, float ty, float tz) {
        float[][] trans = identityArray();
        trans[0][3] = tx;
        trans[1][3] = ty;
        trans[2][3] = tz;
        data = matrixMultiply(data, trans);
    }

    /**
     * Post-multiplies a scale by (sx, sy, sz) about the origin:
     * {@code this = this * S}.
     */
    public void scale(float sx, float sy, float sz) {
        float[][] scale = identityArray();
        scale[0][0] = sx;
        scale[1][1] = sy;
        scale[2][2] = sz;
        data = matrixMultiply(data, scale);
    }

    /**
     * Post-multiplies a skew (shear) matrix: {@code this = this * K}, where K
     * maps x' = x + kxy*y + kxz*z, y' = y + kyz*z, z' = z.
     */
    public void skew(float kxy, float kxz, float kyz) {
        float[][] skew = identityArray();
        skew[0][1] = kxy;
        skew[0][2] = kxz;
        skew[1][2] = kyz;
        data = matrixMultiply(data, skew);
    }

    /**
     * Post-multiplies a rotation of {@code angle} radians about the axis
     * (ax, ay, az) through the origin (right-hand rule):
     * {@code this = this * R}.
     *
     * <p>The axis is normalized first, because the quaternion-derived formula
     * below only yields a pure rotation for a unit axis. A zero-length axis
     * is left as-is and produces the identity rotation.
     *
     * @param ax x component of the rotation axis
     * @param ay y component of the rotation axis
     * @param az z component of the rotation axis
     * @param angle rotation angle in radians
     */
    public void rotate(float ax, float ay, float az, float angle) {
        float len = (float) Math.sqrt(ax * ax + ay * ay + az * az);
        if (len > 0.0f) {
            ax /= len;
            ay /= len;
            az /= len;
        }

        // Quaternion (w, a, b, c) = (cos(angle/2), sin(angle/2) * axis);
        // s holds 2w so that s*a etc. are the 2wa terms of the matrix.
        float sinA2 = (float) Math.sin(angle / 2.0);
        float s     = (float) (2.0f * Math.cos(angle / 2.0));
        float a     = sinA2 * ax;
        float b     = sinA2 * ay;
        float c     = sinA2 * az;
        float aa2   = 2 * a * a;
        float bb2   = 2 * b * b;
        float cc2   = 2 * c * c;
        float sa    = s * a;
        float sb    = s * b;
        float sc    = s * c;
        float ab2   = 2 * a * b;
        float ac2   = 2 * a * c;
        float bc2   = 2 * b * c;

        float[][] rot = new float[4][4];
        rot[0][0] = 1.0f - bb2 - cc2;
        rot[0][1] = ab2 - sc;
        rot[0][2] = ac2 + sb;
        rot[1][0] = ab2 + sc;
        rot[1][1] = 1.0f - aa2 - cc2;
        rot[1][2] = bc2 - sa;
        rot[2][0] = ac2 - sb;
        rot[2][1] = bc2 + sa;
        rot[2][2] = 1.0f - aa2 - bb2;
        rot[3][3] = 1.0f;

        data = matrixMultiply(data, rot);
    }

    /**
     * Post-multiplies a viewing transform for an eye at (eyex, eyey, eyez)
     * looking toward (atx, aty, atz) with the given up vector:
     * {@code this = this * V}.
     *
     * <p>With l the unit direction from eye to target, r = normalize(l x up)
     * and u = normalize(r x l), the rows of V are r, u and -l. Each row has a
     * translation that moves the eye to the origin, so the viewing direction
     * maps onto the -z axis.
     *
     * <p>NOTE(review): {@link #perspective} and {@link #orthographic} map
     * z = near to -1 and z = far to +1. With the viewing direction on -z,
     * that only lines up if callers pass negative near/far values. Confirm
     * the intended convention before changing either side.
     *
     * <p>If eye equals at, or up is parallel to the viewing direction, a
     * zero-length vector is normalized and the matrix fills with NaN.
     */
    public void lookAt(float eyex, float eyey, float eyez,
                       float atx,  float aty,  float atz,
                       float upx,  float upy,  float upz) {

        // l-hat: unit view direction
        float[] l = {atx - eyex, aty - eyey, atz - eyez};
        normalize(l);

        // r-hat: unit right vector
        float[] up = {upx, upy, upz};
        float[] r = crossProduct(l, up);
        normalize(r);

        // u-hat: re-orthogonalized up vector
        float[] u = crossProduct(r, l);
        normalize(u);

        float[] eye = {eyex, eyey, eyez};

        float[][] lookat = identityArray();
        lookat[0][0] = r[0];
        lookat[0][1] = r[1];
        lookat[0][2] = r[2];
        lookat[0][3] = -dotProduct(r, eye);
        lookat[1][0] = u[0];
        lookat[1][1] = u[1];
        lookat[1][2] = u[2];
        lookat[1][3] = -dotProduct(u, eye);
        lookat[2][0] = -l[0];
        lookat[2][1] = -l[1];
        lookat[2][2] = -l[2];
        lookat[2][3] = dotProduct(l, eye);
        lookat[3][3] = 1.0f;

        data = matrixMultiply(data, lookat);
    }

    /**
     * Post-multiplies a perspective projection for the frustum bounded by
     * left/right, bottom/top and near/far: {@code this = this * P}.
     *
     * <p>Row 3 of P is [0 0 1 0], so w = z and the projection only takes
     * effect after a division by w; {@link #transform} does not perform that
     * division. After the divide, z = near maps to -1 and z = far to +1.
     *
     * <p>The y terms use (bottom - top) rather than (top - bottom). This
     * flips y, presumably so that y grows downward in raster space
     * (confirm).
     */
    public void perspective(float left, float right,
                            float bottom, float top,
                            float near, float far) {
        float rml = right - left;
        float bmt = bottom - top;
        float fmn = far - near;

        float[][] persp = identityArray();
        persp[0][0] = 2.0f * near / rml;
        persp[0][2] = -(right + left) / rml;
        persp[1][1] = 2.0f * near / bmt;
        persp[1][2] = -(bottom + top) / bmt;
        persp[2][2] = (far + near) / fmn;
        persp[2][3] = -2.0f * far * near / fmn;
        persp[3][2] = 1.0f;
        persp[3][3] = 0.0f;

        data = matrixMultiply(data, persp);
    }

    /**
     * Post-multiplies an orthographic projection that maps the box bounded by
     * left/right, bottom/top and near/far onto [-1, 1] in each axis:
     * {@code this = this * O}.
     *
     * <p>left maps to -1, bottom to +1 and near to -1. Using (bottom - top)
     * flips y, as in {@link #perspective}.
     */
    public void orthographic(float left, float right,
                             float bottom, float top,
                             float near, float far) {
        float rml = right - left;
        float bmt = bottom - top;
        float fmn = far - near;

        float[][] ortho = new float[4][4];
        ortho[0][0] = 2.0f / rml;
        ortho[1][1] = 2.0f / bmt;
        ortho[2][2] = 2.0f / fmn;
        ortho[3][3] = 1.0f;
        ortho[0][3] = -(right + left) / rml;
        ortho[1][3] = -(bottom + top) / bmt;
        ortho[2][3] = -(far + near) / fmn;

        data = matrixMultiply(data, ortho);
    }

    /**
     * Returns the matrix as four lines, one per row, each formatted as
     * {@code [a\tb\tc\td]} and terminated by a newline.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (float[] row : data) {
            sb.append('[').append(row[0])
              .append('\t').append(row[1])
              .append('\t').append(row[2])
              .append('\t').append(row[3])
              .append("]\n");
        }
        return sb.toString();
    }

    //
    // Private helper methods.
    //

    /** Returns a freshly allocated 4x4 identity array. */
    private static float[][] identityArray() {
        float[][] m = new float[4][4];
        for (int i = 0; i < 4; i++) {
            m[i][i] = 1.0f;
        }
        return m;
    }

    /** Returns a new array holding the 4x4 product a * b; inputs are not modified. */
    private static float[][] matrixMultiply(float[][] a, float[][] b) {
        float[][] tmp = new float[4][4];
        for (int r = 0; r < 4; r++) {
            for (int c = 0; c < 4; c++) {
                tmp[r][c] = a[r][0] * b[0][c] + a[r][1] * b[1][c]
                          + a[r][2] * b[2][c] + a[r][3] * b[3][c];
            }
        }
        return tmp;
    }

    /**
     * Scales the 3-vector {@code a} to unit length in place. A zero vector
     * becomes NaN (division by zero).
     */
    private static void normalize(float[] a) {
        float norm = (float) Math.sqrt(a[0] * a[0] + a[1] * a[1] + a[2] * a[2]);
        a[0] /= norm;
        a[1] /= norm;
        a[2] /= norm;
    }

    /** Returns the cross product a x b of two 3-vectors as a new array. */
    private static float[] crossProduct(float[] a, float[] b) {
        return new float[] {
            a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0]
        };
    }

    /** Returns the dot product of two 3-vectors. */
    private static float dotProduct(float[] a, float[] b) {
        return a[0] * b[0] + a[1] * b[1] + a[2] * b[2];
    }

}



