import Raster;
import Point3D;
import java.lang.Math;

/**
 * A 4x4 homogeneous transform matrix for 3D points.
 *
 * <p>The matrix can be initialised with a variety of transforms, and further
 * transforms can be composed onto it to build more complex ones. Once the
 * matrix is set up, 3D points can be passed in and are transformed according
 * to the current matrix.
 *
 * <p>The representation is a row-major 4 by 4 array ({@code matrix4[row][column]}).
 * It is 4x4 rather than 3x3 so that translations and projections can be
 * expressed as matrix products. Input points are treated as (x, y, z, 1); the
 * fourth coordinate is factored out of the result by dividing x, y and z by
 * the transformed w, i.e. the returned point is scaled so that its fourth
 * coordinate would be 1.
 *
 * <p>Unless stated otherwise, the transform methods post-multiply:
 * {@code this = this * transform}.
 */
public class Matrix3D {

	/** Matrix entries, indexed {@code [row][column]}, always 4x4. */
	public float matrix4[][];

	//
	// Constructors
	//

	/** Creates an identity matrix. */
	public Matrix3D() {
		matrix4 = new float[4][4];
		loadIdentity();
	}

	/**
	 * Creates a matrix holding a copy of another matrix's entries.
	 *
	 * @param copy the matrix whose entries are copied
	 */
	public Matrix3D(Matrix3D copy) {
		matrix4 = new float[4][4];
		Copy(copy);
	}

	/**
	 * Creates a matrix with all sixteen entries set to the same value.
	 *
	 * @param val the value stored in every entry
	 */
	public Matrix3D(float val) {
		matrix4 = new float[4][4];
		fill(val);
	}

	/**
	 * Creates a viewport matrix sized to the given raster.
	 *
	 * <p>With the constants below, x in [-8, 8] maps to [0, width],
	 * y in [-8, 8] maps to [0, height], and z is mapped by
	 * {@code z' = (zmax / fn) * (z - n)}, so z = -8 maps to 0.
	 *
	 * @param r the raster whose width and height define the screen extent
	 */
	public Matrix3D(Raster r) {
		// a fresh float[4][4] is all zeros, so only non-zero entries are set
		matrix4 = new float[4][4];
		int width = r.getWidth();
		int height = r.getHeight();
		float w = (float) width;
		float h = (float) height;
		float l = -8;            // left edge of the source volume
		float t = -8;            // top edge of the source volume
		float rl = 16;           // right - left
		float bt = 16;           // bottom - top
		float n = -8;            // near plane
		float fn = -16;          // far - near
		float zmax = (float) -200;

		// scale terms on the diagonal
		matrix4[0][0] = w / rl;
		matrix4[1][1] = h / bt;
		matrix4[2][2] = zmax / fn;
		matrix4[3][3] = (float) 1;

		// offset terms in the last column
		matrix4[0][3] = ((-1) * l * w) / rl;
		matrix4[1][3] = ((-1) * t * h) / bt;
		matrix4[2][3] = ((-1) * n * zmax) / fn;
	}

	//
	// General interface methods
	//

	/**
	 * Sets entry {@code [i][j]} (zero-based row i, column j).
	 * Out-of-range indices are silently ignored.
	 *
	 * @param i     zero-based row, 0..3
	 * @param j     zero-based column, 0..3
	 * @param value the value to store
	 */
	public void Set(int i, int j, float value) {
		if ((i < 4) && (j < 4) && (i >= 0) && (j >= 0)) {
			matrix4[i][j] = value;
		}
	}

	/**
	 * Sets an entry using one-based, column-first indices: writes
	 * {@code matrix4[j-1][i-1]}. Out-of-range indices are silently ignored.
	 *
	 * @param i     one-based column, 1..4
	 * @param j     one-based row, 1..4
	 * @param value the value to store
	 */
	public void set(int i, int j, float value) {
		if ((i < 5) && (j < 5) && (i > 0) && (j > 0)) {
			matrix4[j-1][i-1] = value;
		}
	}

	/**
	 * Sets all sixteen entries at once; parameter {@code pRC} is stored at
	 * row R, column C.
	 */
	public void debugSet(float p00, float p01, float p02, float p03,
					     float p10, float p11, float p12, float p13,
						 float p20, float p21, float p22, float p23,
						 float p30, float p31, float p32, float p33) {
		float[][] values = {
			{ p00, p01, p02, p03 },
			{ p10, p11, p12, p13 },
			{ p20, p21, p22, p23 },
			{ p30, p31, p32, p33 }
		};
		for (int row = 0; row < 4; row++) {
			for (int col = 0; col < 4; col++) {
				Set(row, col, values[row][col]);
			}
		}
	}

	/**
	 * Returns entry {@code [i][j]} (zero-based row i, column j).
	 *
	 * @param i zero-based row, 0..3
	 * @param j zero-based column, 0..3
	 * @return the entry, or 0 if either index is out of range
	 */
	public float Get(int i, int j) {
		if ((i < 4) && (j < 4) && (i >= 0) && (j >= 0)) {
			return matrix4[i][j];
		}
		return (float)0.0;
	}

	/**
	 * Returns an entry using one-based, column-first indices: reads
	 * {@code matrix4[j-1][i-1]}.
	 *
	 * @param i one-based column, 1..4
	 * @param j one-based row, 1..4
	 * @return the entry, or 0 if either index is out of range
	 */
	public float get(int i, int j) {
		if ((i < 5) && (j < 5) && (i > 0) && (j > 0)) {
			return matrix4[j-1][i-1];
		}
		return (float)0.0;
	}

	/**
	 * Stores {@code val} in every entry. Mostly used to zero the matrix.
	 *
	 * @param val the value stored in all sixteen entries
	 */
	public void fill(float val) {
		for (int i = 0; i < 4; i++) {
			for (int j = 0; j < 4; j++) {
				matrix4[i][j] = val;
			}
		}
	}

	/**
	 * Copies all sixteen entries of another matrix into this one.
	 *
	 * @param copy the matrix to copy from
	 */
	public void Copy(Matrix3D copy) {
		for (int i = 0; i < 4; i++) {
			for (int j = 0; j < 4; j++) {
				matrix4[i][j] = copy.matrix4[i][j];
			}
		}
	}

	/**
	 * Transforms a 3D point by this matrix.
	 *
	 * <p>The input is treated as the homogeneous point (x, y, z, 1). The
	 * result is divided by the transformed w coordinate
	 * ({@code m30*x + m31*y + m32*z + m33}). If w is exactly zero the point
	 * has no finite image; in that case the translation column is left out
	 * and only the linear (upper-left 3x4 without column 3) part is returned.
	 *
	 * @param in the point to transform; it is not modified
	 * @return a new point holding the transformed coordinates
	 */
	public Point3D transformPoint(Point3D in) {
		Point3D ret = new Point3D();
		float x = in.p[0];
		float y = in.p[1];
		float z = in.p[2];

		// w must be a dot product of row 3 with (x, y, z, 1); summing row 3
		// alone would ignore the point and break projective transforms.
		float w = matrix4[3][0] * x + matrix4[3][1] * y
				+ matrix4[3][2] * z + matrix4[3][3];

		for (int i = 0; i < 3; i++) {
			float linear = matrix4[i][0] * x + matrix4[i][1] * y
					+ matrix4[i][2] * z;
			if (w != 0) {
				// normal case: add the translation term, then divide out w
				ret.p[i] = (linear + matrix4[i][3]) / w;
			} else {
				// on the 'illegal' plane: skip the fourth-dimension term
				ret.p[i] = linear;
			}
		}
		return ret;
	}

	/**
	 * Transforms a run of points from one array into another:
	 *
	 * <pre>
	 *   for (i = 0; i &lt; length; i++)
	 *       out[start+i] = this * in[start+i]
	 * </pre>
	 *
	 * Each written slot of {@code out} receives a newly created point, so
	 * {@code in} and {@code out} may be the same array.
	 *
	 * @param in     source points
	 * @param out    destination array
	 * @param start  index of the first point to transform
	 * @param length number of points to transform
	 */
	public void transform(Point3D in[], Point3D out[], int start, int length) {
		for (int i = 0; i < length; i++) {
			out[start+i] = transformPoint(in[start+i]);
		}
	}

	/**
	 * Post-multiplies this matrix by another: {@code this = this * src}.
	 *
	 * <p>The product is accumulated in a scratch array before being written
	 * back, so {@code src} may be this same matrix (or share its array).
	 *
	 * @param src the right-hand operand; it is not modified unless it
	 *            aliases this matrix
	 */
	public final void compose(Matrix3D src) {
		float[][] product = new float[4][4];
		for (int i = 0; i < 4; i++) {
			for (int j = 0; j < 4; j++) {
				float sum = (float)0.0;
				for (int k = 0; k < 4; k++) {
					sum += matrix4[i][k] * src.matrix4[k][j];
				}
				product[i][j] = sum;
			}
		}
		// write back into the existing rows so outside references to
		// matrix4 keep seeing this matrix's contents
		for (int i = 0; i < 4; i++) {
			System.arraycopy(product[i], 0, matrix4[i], 0, 4);
		}
	}

	/** Resets this matrix to the identity transform. */
	public void loadIdentity() {
		for (int i = 0; i < 4; i++) {
			for (int j = 0; j < 4; j++) {
				matrix4[i][j] = (i == j) ? (float)1.0 : (float)0.0;
			}
		}
	}

	/**
	 * Post-multiplies this matrix by a translation: {@code this = this * t}.
	 *
	 * <p>The translation is composed like every other transform in this
	 * class, so it is applied in the coordinate frame of the matrix built
	 * so far rather than simply added to the last column.
	 *
	 * @param tx translation along x
	 * @param ty translation along y
	 * @param tz translation along z
	 */
	public void translate(float tx, float ty, float tz) {
		Matrix3D transMat = new Matrix3D();
		transMat.Set(0,3, tx);
		transMat.Set(1,3, ty);
		transMat.Set(2,3, tz);
		compose(transMat);
	}

	/**
	 * Multiplies every one of the sixteen entries by {@code s}.
	 *
	 * <p>This is a scalar multiple of the whole matrix, including the
	 * homogeneous row. Because {@link #transformPoint} divides by w, it does
	 * not change where points land when w is non-zero; use
	 * {@link #scale(float, float, float)} or {@link #scale3} to change the
	 * size of transformed geometry.
	 *
	 * @param s the factor applied to every entry
	 */
	public void scale(float s) {
		for (int i = 0; i < 4; i++) {
			for (int j = 0; j < 4; j++) {
				matrix4[i][j] *= s;
			}
		}
	}

	/**
	 * Multiplies the nine entries of the upper-left 3x3 sub-matrix by
	 * {@code s}; the last row and column are left alone.
	 *
	 * @param s the factor applied to the 3x3 sub-matrix
	 */
	public void scale3(float s) {
		for (int i = 0; i < 3; i++) {
			for (int j = 0; j < 3; j++) {
				matrix4[i][j] *= s;
			}
		}
	}

	/**
	 * Post-multiplies this matrix by a per-axis scale:
	 * {@code this = this * scale}.
	 *
	 * @param sx scale factor along x
	 * @param sy scale factor along y
	 * @param sz scale factor along z
	 */
	public void scale(float sx, float sy, float sz) {
		Matrix3D scaleMat = new Matrix3D();
		scaleMat.Set(0,0, sx);
		scaleMat.Set(1,1, sy);
		scaleMat.Set(2,2, sz);
		compose(scaleMat);
	}

	/**
	 * Post-multiplies this matrix by a shear: {@code this = this * shear}.
	 * The shear matrix is the identity with the three upper-triangle
	 * entries set.
	 *
	 * @param kxy entry [0][1]: amount of y added to x
	 * @param kxz entry [0][2]: amount of z added to x
	 * @param kyz entry [1][2]: amount of z added to y
	 */
	public void shear(float kxy, float kxz, float kyz) {
		Matrix3D shearMat = new Matrix3D();
		shearMat.Set(0,1, kxy);
		shearMat.Set(0,2, kxz);
		shearMat.Set(1,2, kyz);
		compose(shearMat);
	}

	/**
	 * Post-multiplies this matrix by the skew-symmetric matrix built by
	 * {@link #skewsym}: {@code this = this * skewsym(kxy, kxz, kyz)}.
	 *
	 * @param kxy passed to {@code skewsym} as its first component
	 * @param kxz passed to {@code skewsym} as its second component
	 * @param kyz passed to {@code skewsym} as its third component
	 */
	public void skew(float kxy, float kxz, float kyz) {
		compose(skewsym(kxy, kxz, kyz));
	}

	/**
	 * Builds the symmetric outer-product matrix of a vector with itself:
	 * the upper-left 3x3 holds {@code a * a^T}, the rest is identity.
	 * This matrix is not modified.
	 *
	 * @param ax x component of the vector
	 * @param ay y component of the vector
	 * @param az z component of the vector
	 * @return a new matrix
	 */
	public Matrix3D symmetric(float ax, float ay, float az) {
		float[] a = { ax, ay, az };
		Matrix3D ret = new Matrix3D();
		for (int row = 0; row < 3; row++) {
			for (int col = 0; col < 3; col++) {
				ret.Set(row, col, a[row] * a[col]);
			}
		}
		return ret;
	}

	/**
	 * Builds the skew-symmetric (cross-product) matrix of a vector, so that
	 * transforming a point b by the result yields {@code a x b}. Entry
	 * [3][3] is 1 and all other entries outside the 3x3 block are 0.
	 * This matrix is not modified.
	 *
	 * @param ax x component of the vector
	 * @param ay y component of the vector
	 * @param az z component of the vector
	 * @return a new matrix
	 */
	public Matrix3D skewsym(float ax, float ay, float az) {
		Matrix3D ret = new Matrix3D((float)0.0);
		ret.Set(0,1, (-1 * az));
		ret.Set(0,2, ay);
		ret.Set(1,0, az);
		ret.Set(1,2, (-1 * ax));
		ret.Set(2,0, (-1 * ay));
		ret.Set(2,1, ax);
		ret.Set(3,3, (float) 1);
		return ret;
	}

	/**
	 * Adds another matrix to this one, entry by entry (all sixteen entries).
	 *
	 * @param src the matrix to add
	 */
	public void add(Matrix3D src) {
		for (int i = 0; i < 4; i++) {
			for (int j = 0; j < 4; j++) {
				matrix4[i][j] += src.matrix4[i][j];
			}
		}
	}

	/**
	 * Adds the upper-left 3x3 sub-matrix of another matrix to this one;
	 * the last row and column are left alone.
	 *
	 * @param src the matrix to add
	 */
	public void add3(Matrix3D src) {
		for (int i = 0; i < 3; i++) {
			for (int j = 0; j < 3; j++) {
				matrix4[i][j] += src.matrix4[i][j];
			}
		}
	}

	/**
	 * Post-multiplies this matrix by a rotation about an arbitrary axis:
	 * {@code this = this * rotate}.
	 *
	 * <p>Uses the expanded form of
	 * {@code symm(a)*(1-cos) + skew(a)*sin + identity*cos}
	 * with the axis normalised first. A zero-length axis defines no
	 * rotation, so the matrix is left unchanged in that case.
	 *
	 * @param ax    x component of the rotation axis
	 * @param ay    y component of the rotation axis
	 * @param az    z component of the rotation axis
	 * @param angle rotation angle in radians
	 */
	public void rotate(float ax, float ay, float az, float angle) {
		// Normalize vector
		float length = (float)Math.sqrt((ax*ax)+(ay*ay)+(az*az));
		if (length == 0) {
			// dividing by zero would fill the matrix with NaN
			return;
		}
		ax = ax / length;
		ay = ay / length;
		az = az / length;

		float cost = (float)Math.cos(angle);
		float sint = (float)Math.sin(angle);
		float f1 = 1 - cost;

		// upper-left 3x3 of the axis-angle rotation matrix
		float r00 = ((ax * ax) * f1) + cost;
		float r01 = (ax * ay * f1) - (az * sint);
		float r02 = (ax * az * f1) + (ay * sint);
		float r10 = (ax * ay * f1) + (az * sint);
		float r11 = (ay * ay * f1) + cost;
		float r12 = (ay * az * f1) - (ax * sint);
		float r20 = (ax * az * f1) - (ay * sint);
		float r21 = (ay * az * f1) + (ax * sint);
		float r22 = (az * az * f1) + cost;

		Matrix3D rotate = new Matrix3D();
		rotate.debugSet(r00, r01, r02, 0,
						r10, r11, r12, 0,
						r20, r21, r22, 0,
						0,   0,   0,   1);

		compose(rotate);
	}

	/**
	 * Applies this matrix to a point. Intended for use on a matrix produced
	 * by {@link #skewsym}, in which case the result is the cross product of
	 * the skew vector with {@code b}.
	 *
	 * @param b the right-hand operand of the cross product
	 * @return a new point holding the result
	 */
	public Point3D cross(Point3D b) {
		return transformPoint(b);
	}

	/**
	 * Post-multiplies this matrix by a viewing transform:
	 * {@code this = this * lookat}.
	 *
	 * <p>The lookat matrix is
	 * <pre>
	 *   [    r^     -r^ . eye ]
	 *   [    u^     -u^ . eye ]
	 *   [   -l^      l^ . eye ]
	 *   [ 0  0  0       1     ]
	 * </pre>
	 * where ^ means normalised and
	 * <pre>
	 *   l = at - eye
	 *   r = l cross up
	 *   u = r cross l
	 * </pre>
	 * NOTE(review): if {@code up} is parallel to {@code at - eye}, r is the
	 * zero vector; the result then depends on how Point3D.normalize treats a
	 * zero-length vector -- confirm against Point3D.
	 *
	 * @param eyex x of the eye position
	 * @param eyey y of the eye position
	 * @param eyez z of the eye position
	 * @param atx  x of the point being looked at
	 * @param aty  y of the point being looked at
	 * @param atz  z of the point being looked at
	 * @param upx  x of the up direction
	 * @param upy  y of the up direction
	 * @param upz  z of the up direction
	 */
	public void lookAt(float eyex, float eyey, float eyez,
		float atx,  float aty,  float atz,
		float upx,  float upy,  float upz) {

		Point3D up = new Point3D(upx, upy, upz);
		Point3D l = new Point3D(atx, aty, atz);
		Point3D eye = new Point3D(eyex, eyey, eyez);
		l.subtract(eye);

		// cross products are taken on the un-normalised vectors
		Matrix3D crossMatUp = skewsym(l.p[0], l.p[1], l.p[2]);
		Point3D r = crossMatUp.cross(up);
		Matrix3D crossMatR = skewsym(r.p[0], r.p[1], r.p[2]);
		Point3D u = crossMatR.cross(l);
		r.normalize();
		u.normalize();
		l.normalize();

		Matrix3D lookat = new Matrix3D((float) 0.0);

		// row 0: r, then negate r in place to get -r . eye
		lookat.Set(0,0, r.getX());
		lookat.Set(0,1, r.getY());
		lookat.Set(0,2, r.getZ());
		r.negate();
		lookat.Set(0,3, r.dot(eye));

		// row 1: u, then negate u in place to get -u . eye
		lookat.Set(1,0, u.getX());
		lookat.Set(1,1, u.getY());
		lookat.Set(1,2, u.getZ());
		u.negate();
		lookat.Set(1,3, u.dot(eye));

		// row 2: l . eye is taken first, then l is negated for -l
		lookat.Set(2,3, l.dot(eye));
		l.negate();
		lookat.Set(2,0, l.getX());
		lookat.Set(2,1, l.getY());
		lookat.Set(2,2, l.getZ());

		lookat.Set(3,3, (float)1.0);

		compose(lookat);
	}

	//
	// Assume the following projection transformations
	// transform points into the canonical viewing space
	//

	/**
	 * Intended to post-multiply this matrix by a perspective projection;
	 * currently it only flips the y axis.
	 *
	 * <p>The original author could not get the projective matrix working and
	 * replaced it with a plain y-axis flip. That behaviour is kept here:
	 * {@code this = this * diag(1, -1, 1, 1)}. The matrix that was attempted
	 * had these non-zero entries:
	 * <pre>
	 *   [0][0] = 2*near / (right - left)     [0][2] = -(right + left) / (right - left)
	 *   [1][1] = 2*near / (bottom - top)     [1][2] = -(bottom + top) / (bottom - top)
	 *   [2][2] = (far + near) / (far - near) [2][3] = -(2*far*near) / (far - near)
	 *   [3][2] = 1
	 * </pre>
	 * NOTE(review): re-enabling it needs its sign conventions checked against
	 * {@link #lookAt}, which stores -l in row 2.
	 *
	 * <p>If any pair of opposite planes coincides the call does nothing.
	 *
	 * @param left   left clipping plane (only checked against right)
	 * @param right  right clipping plane (only checked against left)
	 * @param bottom bottom clipping plane (only checked against top)
	 * @param top    top clipping plane (only checked against bottom)
	 * @param near   near clipping plane (only checked against far)
	 * @param far    far clipping plane (only checked against near)
	 */
	public void perspective(float left, float right,
								float bottom, float top,
								float near, float far) {
		// double check inputs
		if ((left == right) ||
			(bottom == top) ||
			(near == far))
			return;

		// identity with the y axis flipped
		Matrix3D persp = new Matrix3D();
		persp.Set(1,1, (float) -1);

		compose(persp);
	}

	/**
	 * Post-multiplies this matrix by an orthographic projection:
	 * {@code this = this * ortho}.
	 *
	 * <p>The y terms are built from {@code bottom - top}, so the y axis is
	 * flipped relative to the x and z terms. If any pair of opposite planes
	 * coincides the call does nothing.
	 *
	 * @param left   left clipping plane
	 * @param right  right clipping plane
	 * @param bottom bottom clipping plane
	 * @param top    top clipping plane
	 * @param near   near clipping plane
	 * @param far    far clipping plane
	 */
	public void orthographic(float left, float right,
		float bottom, float top,
		float near, float far) {
		// double check inputs
		if ((left == right) ||
			(bottom == top) ||
			(near == far))
			return;

		Matrix3D ortho = new Matrix3D((float) 0.0);

		// scale terms
		ortho.Set(0,0, (float) (2 / (right - left)));
		ortho.Set(1,1, (float) (2 / (bottom - top)));
		ortho.Set(2,2, (float) (2 / (far - near)));

		// offset terms that centre the volume on the origin
		ortho.Set(0,3, (float) ((-(right + left)) / (right - left)));
		ortho.Set(1,3, (float) ((-(bottom + top)) / (bottom - top)));
		ortho.Set(2,3, (float) ((-(far + near)) / (far - near)));

		ortho.Set(3,3, (float) 1.0);

		compose(ortho);
	}
}