//
//
// Matrix3D
//
//
/**
 * A 4x4 homogeneous transformation matrix for 3D points and vertices.
 *
 * Elements are stored as mat[row][column]. Note that the public accessors
 * get(i, j) / set(i, j, value) take the COLUMN index first and the ROW
 * index second, i.e. they address mat[j][i].
 *
 * All "this = this * X" methods post-multiply the current matrix, so the
 * transform added last is the first one applied to a point.
 */
public class Matrix3D 
{
	// Indexing example:
	//	0	0	0	0
	//	0	0	1	0
	//	0	0	0	0
	//	0	0	0	0
	// This matrix is all zeros except mat[1][2] = 1 (row 1, column 2).
	protected float mat[][];

	//
	// Constructors
	//

	/**
	 * Creates an identity transform.
	 */
	public Matrix3D()
	{
		mat = new float[4][4];
		fillIdentity();
	}

	/**
	 * Creates a matrix from 16 values given in reading order:
	 *
	 *	a	b	c	d
	 *	e	f	g	h
	 *	i	j	k	l
	 *	m	n	o	p
	 */
	public Matrix3D(
		float a,float b,float c,float d,
		float e,float f,float g,float h,
		float i,float j,float k,float l,
		float m,float n,float o,float p
		)
	{
		mat = new float[][] {
			{ a, b, c, d },
			{ e, f, g, h },
			{ i, j, k, l },
			{ m, n, o, p }
		};
	}

	/**
	 * Creates a matrix from a 3x3 linear part, with no translation:
	 *
	 *	a	b	c	0
	 *	d	e	f	0
	 *	g	h	i	0
	 *	0	0	0	1
	 */
	public Matrix3D(
		float a,float b,float c,
		float d,float e,float f,
		float g,float h,float i
		)
	{
		this(
			a,b,c,0,
			d,e,f,0,
			g,h,i,0,
			0,0,0,1
			);
	}

	/**
	 * Creates an element-by-element copy of another matrix.
	 *
	 * @param copy the matrix to copy; its elements are read through get()
	 */
	public Matrix3D(Matrix3D copy)
	{
		mat = new float[4][4];
		for (int col = 0; col < 4; col++)
			for (int row = 0; row < 4; row++)
				mat[row][col] = copy.get(col, row);
	}

	/**
	 * Creates the mapping from canonical space to screen space:
	 *
	 *	width/2		0			0			width/2
	 *	0			height/2	0			height/2
	 *	0			0			MAXZ/2		MAXZ/2
	 *	0			0			0			1
	 *
	 * @param r raster whose width and height define the screen extent
	 */
	public Matrix3D(Raster r)
	{
		// Divide by a float literal so that odd dimensions keep their .5;
		// a plain "/ 2" would truncate if these fields are integers
		// (NOTE(review): their declared types are not visible here).
		float halfWidth = r.width / 2f;
		float halfHeight = r.height / 2f;
		float halfDepth = ZRaster.MAXZ / 2f;

		mat = new float[][] {
			{ halfWidth, 0,          0,         halfWidth  },
			{ 0,         halfHeight, 0,         halfHeight },
			{ 0,         0,          halfDepth, halfDepth  },
			{ 0,         0,          0,         1f         }
		};
	}

	// Overwrites every element of mat with the identity matrix.
	private void fillIdentity()
	{
		for (int row = 0; row < 4; row++)
			for (int col = 0; col < 4; col++)
				mat[row][col] = (row == col) ? 1f : 0f;
	}

	//
	// General interface methods
	//

	/**
	 * Sets element mat[j][i].
	 *
	 * @param i     column index
	 * @param j     row index
	 * @param value new element value
	 */
	public void set(int i, int j, float value)
	{
		mat[j][i] = value;
	}

	/**
	 * Returns element mat[j][i].
	 *
	 * @param i column index
	 * @param j row index
	 */
	public float get(int i, int j)
	{
		return mat[j][i];
	}

	// Dot product of row j with the homogeneous point (x, y, z, 1).
	private float pointRowMultiply(int j, Point3D point)
	{
		float[] row = mat[j];
		return row[0]*point.x + row[1]*point.y + row[2]*point.z + row[3];
	}

	// Multiplies the point (with implicit w = 1) by this matrix and performs
	// the homogeneous divide by the resulting w.
	private Point3D pointMultiply(Point3D point)
	{
		float x = pointRowMultiply(0, point);
		float y = pointRowMultiply(1, point);
		float z = pointRowMultiply(2, point);
		float w = pointRowMultiply(3, point);
		return new Point3D(x/w, y/w, z/w);
	}

	/**
	 * Transforms a sub-range of points with the current matrix:
	 *
	 *    for (i = 0; i < length; i++)
	 *        out[start+i] = this * in[start+i]
	 *
	 * @param in     source points
	 * @param out    destination array; receives newly created points
	 * @param start  index of the first point to transform
	 * @param length number of points to transform
	 */
	public void transform(Point3D in[], Point3D out[], int start, int length)
	{
		int end = start + length;
		for (int index = start; index < end; index++)
			out[index] = pointMultiply(in[index]);
	}

	/**
	 * Transforms the first length vertices of in into out (always starting
	 * at index 0).
	 *
	 * @param in     source vertices
	 * @param out    destination array; receives newly created vertices
	 * @param length number of vertices to transform
	 */
	public void transform(Vertex3D in[], Vertex3D out[], int length)
	{
		for (int index = 0; index < length; index++)
		{
			out[index] = transform(in[index]);
		}
	}

	/**
	 * Transforms one vertex.
	 *
	 * The position is multiplied by the full 4x4 matrix and divided by the
	 * resulting w, so the returned vertex always has w = 1. If the vertex
	 * carries a normal, the normal is multiplied by the inverse transpose
	 * of the upper-left 3x3 block, which keeps it perpendicular to the
	 * transformed surface even under non-uniform scale or skew.
	 *
	 * If the upper-left 3x3 block is singular (determinant 0) the normal
	 * components come out non-finite, because the cofactors are divided by
	 * the determinant without a check.
	 *
	 * @param in the vertex to transform (not modified)
	 * @return a new transformed vertex
	 */
	public Vertex3D transform(Vertex3D in)
	{
		Vertex3D out = new Vertex3D();
		Point3D position = pointMultiply(new Point3D(in.x, in.y, in.z));
		out.x = position.x;
		out.y = position.y;
		out.z = position.z;
		out.w = 1;
		if (in.hasNormal)
		{
			Point3D normal = normalMatrix().pointMultiply(new Point3D(in.nx, in.ny, in.nz));
			out.setNormal(normal.x, normal.y, normal.z);
		}
		return out;
	}

	// Builds the matrix used for normals: the inverse transpose of the
	// upper-left 3x3 block A. Since inverse(A) = transpose(C) / det(A),
	// where C is the cofactor matrix, the inverse transpose is simply
	// C / det(A) -- no further transposition is needed.
	private Matrix3D normalMatrix()
	{
		float m00 = mat[0][0];	float m01 = mat[0][1];	float m02 = mat[0][2];
		float m10 = mat[1][0];	float m11 = mat[1][1];	float m12 = mat[1][2];
		float m20 = mat[2][0];	float m21 = mat[2][1];	float m22 = mat[2][2];

		// Cofactors: signed 2x2 minors, c[r][c] = (-1)^(r+c) * minor(r, c)
		float c00 =   m11*m22 - m12*m21;
		float c01 = -(m10*m22 - m12*m20);
		float c02 =   m10*m21 - m11*m20;
		float c10 = -(m01*m22 - m02*m21);
		float c11 =   m00*m22 - m02*m20;
		float c12 = -(m00*m21 - m01*m20);
		float c20 =   m01*m12 - m02*m11;
		float c21 = -(m00*m12 - m02*m10);
		float c22 =   m00*m11 - m01*m10;

		// Laplace expansion along the first row
		float det = m00*c00 + m01*c01 + m02*c02;

		return new Matrix3D(
			c00/det, c01/det, c02/det,
			c10/det, c11/det, c12/det,
			c20/det, c21/det, c22/det);
	}

	/**
	 * Post-multiplies this matrix by src: this = this * src.
	 *
	 * The product is built in a scratch array first, so passing this
	 * matrix itself as src is safe.
	 *
	 * @param src right-hand operand (not modified)
	 */
	public final void compose(Matrix3D src)
	{
		float[][] product = new float[4][4];
		for (int col = 0; col < 4; col++)
			for (int row = 0; row < 4; row++)
		{
			// get(column, row): row of this times column of src
			product[row][col] =
				this.get(0, row) * src.get(col, 0) +
				this.get(1, row) * src.get(col, 1) +
				this.get(2, row) * src.get(col, 2) +
				this.get(3, row) * src.get(col, 3);
		}

		// Copy into the existing storage rather than swapping the array
		for (int row = 0; row < 4; row++)
			for (int col = 0; col < 4; col++)
				mat[row][col] = product[row][col];
	}

	/**
	 * Resets this matrix to the identity: this = identity.
	 */
	public void loadIdentity()
	{
		fillIdentity();
	}

	/**
	 * Appends a translation: this = this * T, where
	 *
	 *	1	0	0	tx
	 *	0	1	0	ty
	 *	0	0	1	tz
	 *	0	0	0	1
	 */
	public void translate(float tx, float ty, float tz)
	{
		Matrix3D translation = new Matrix3D(
			1, 0, 0, tx,
			0, 1, 0, ty,
			0, 0, 1, tz,
			0, 0, 0, 1
			);
		compose(translation);
	}

	/**
	 * Appends a scale: this = this * S, where
	 *
	 *	sx	0	0	0
	 *	0	sy	0	0
	 *	0	0	sz	0
	 *	0	0	0	1
	 */
	public void scale(float sx, float sy, float sz)
	{
		Matrix3D scaling = new Matrix3D(
			sx, 0, 0, 0,
			0, sy, 0, 0,
			0, 0, sz, 0,
			0, 0, 0, 1
			);
		compose(scaling);
	}

	/**
	 * Appends a skew (shear): this = this * K, where
	 *
	 *	1	kxy	kxz	0
	 *	0	1	kyz	0
	 *	0	0	1	0
	 *	0	0	0	1
	 */
	public void skew(float kxy, float kxz, float kyz)
	{
		Matrix3D shear = new Matrix3D(
			1, kxy, kxz, 0,
			0, 1, kyz, 0,
			0, 0, 1, 0,
			0, 0, 0, 1
			);
		compose(shear);
	}

	/**
	 * Appends a rotation about the axis (ax, ay, az): this = this * R.
	 *
	 * R = SYMMETRIC*(1-cos(angle)) + SKEW*sin(angle) + I*cos(angle), with
	 *
	 *	SYMMETRIC:				SKEW:
	 *	ax*ax	ax*ay	ax*az		0	-az	ay
	 *	ax*ay	ay*ay	ay*az		az	0	-ax
	 *	ax*az	ay*az	az*az		-ay	ax	0
	 *
	 * The axis is normalized here, so it need not be unit length; a
	 * zero-length axis divides by zero and produces NaN entries.
	 *
	 * @param angle rotation angle in radians (passed to Math.cos/Math.sin)
	 */
	public void rotate(float ax, float ay, float az, float angle)
	{
		// Normalize the axis
		float axisLength = (float) Math.sqrt(ax*ax + ay*ay + az*az);
		ax /= axisLength;
		ay /= axisLength;
		az /= axisLength;

		float cosA = (float)Math.cos(angle);
		float sinA = (float)Math.sin(angle);
		float oneMinusCos = 1.0f - cosA;

		float r00 = ax*ax*oneMinusCos + cosA;
		float r01 = ax*ay*oneMinusCos - az*sinA;
		float r02 = ax*az*oneMinusCos + ay*sinA;

		float r10 = ax*ay*oneMinusCos + az*sinA;
		float r11 = ay*ay*oneMinusCos + cosA;
		float r12 = ay*az*oneMinusCos - ax*sinA;

		float r20 = ax*az*oneMinusCos - ay*sinA;
		float r21 = ay*az*oneMinusCos + ax*sinA;
		float r22 = az*az*oneMinusCos + cosA;

		compose(new Matrix3D(
			r00, r01, r02,
			r10, r11, r12,
			r20, r21, r22
			));
	}

	/**
	 * Appends a viewing transform: this = this * LOOKAT.
	 *
	 *	L = AT - EYE,	Ln = L normalized
	 *	R = L x UP,	Rn = R normalized
	 *	U = R x L,	Un = U normalized
	 *
	 *	LOOKAT:
	 *		Rn		-Rn dot EYE
	 *		Un		-Un dot EYE
	 *		-Ln		Ln dot EYE
	 *	0	0	0		1
	 *
	 * @param eyex,eyey,eyez position of the eye
	 * @param atx,aty,atz    point being looked at
	 * @param upx,upy,upz    up direction
	 */
	public void lookAt(float eyex, float eyey, float eyez,
		float atx,  float aty,  float atz,
		float upx,  float upy,  float upz)
	{
		Point3D eye = new Point3D(eyex, eyey, eyez);
		Point3D up = new Point3D(upx, upy, upz);

		Point3D look = new Point3D(atx - eyex, aty - eyey, atz - eyez);
		Point3D right = cross(look, up);
		Point3D trueUp = cross(right, look);

		Point3D lookN = normalize(look);
		Point3D rightN = normalize(right);
		Point3D upN = normalize(trueUp);

		compose(new Matrix3D(
			rightN.x, rightN.y, rightN.z, -dot(rightN, eye),
			upN.x, upN.y, upN.z, -dot(upN, eye),
			-lookN.x, -lookN.y, -lookN.z, dot(lookN, eye),
			0, 0, 0, 1));
	}

	/**
	 * Returns a new point holding A scaled to unit length. A zero-length
	 * input divides by zero and yields NaN components.
	 */
	public static Point3D normalize(Point3D A)
	{
		float length = (float)Math.sqrt(A.x*A.x + A.y*A.y + A.z*A.z);
		return new Point3D(A.x/length, A.y/length, A.z/length);
	}

	/**
	 * Returns the dot product A . B.
	 */
	public static float dot(Point3D A, Point3D B)
	{
		return A.x*B.x + A.y*B.y + A.z*B.z;
	}

	/**
	 * Returns a new point holding the cross product A x B.
	 */
	public static Point3D cross(Point3D A, Point3D B)
	{
		return new Point3D(
			A.y*B.z - A.z*B.y,
			A.z*B.x - A.x*B.z,
			A.x*B.y - A.y*B.x);
	}

	/**
	 * Appends a frustum projection: this = this * frustum.
	 * This is the same transform as perspective() and simply delegates to it.
	 */
	public void frustum(float left, float right,
		float bottom, float top,
		float near, float far)
	{
		perspective(left, right, bottom, top, near, far);
	}

	//
	// Assume the following projection transformations
	// transform points into the canonical viewing space
	//

	/**
	 * Appends a perspective projection: this = this * PERSP, where
	 *
	 *	(2*near)/(right-left)	0						-(right+left)/(right-left)	0
	 *	0						(2*near)/(bottom-top)	-(bottom+top)/(bottom-top)	0
	 *	0						0						(far+near)/(far-near)		(-2*far*near)/(far-near)
	 *	0						0						1							0
	 *
	 * No check is made for right == left, bottom == top or far == near;
	 * those divide by zero.
	 */
	public void perspective(float left, float right,
		float bottom, float top,
		float near, float far)
	{
		float spanX = right - left;
		float spanY = bottom - top;
		float spanZ = far - near;

		compose(new Matrix3D(
			(2f*near)/spanX,	0,					-(right+left)/spanX,	0,
			0,					(2f*near)/spanY,	-(bottom+top)/spanY,	0,
			0,					0,					(far+near)/spanZ,		(-2f*far*near)/spanZ,
			0,					0,					1f,						0
			));
	}

	/**
	 * Appends an orthographic projection: this = this * ORTHO, where
	 *
	 *	2/(right-left)	0				0				-(right+left)/(right-left)
	 *	0				2/(bottom-top)	0				-(bottom+top)/(bottom-top)
	 *	0				0				2/(far-near)	-(far+near)/(far-near)
	 *	0				0				0				1
	 *
	 * No check is made for right == left, bottom == top or far == near;
	 * those divide by zero.
	 */
	public void orthographic(float left, float right,
		float bottom, float top,
		float near, float far)
	{
		float spanX = right - left;
		float spanY = bottom - top;
		float spanZ = far - near;

		compose(new Matrix3D(
			2f/spanX,	0,			0,			-(right+left)/spanX,
			0,			2f/spanY,	0,			-(bottom+top)/spanY,
			0,			0,			2f/spanZ,	-(far+near)/spanZ,
			0,			0,			0,			1f
			));
	}
}
