import Raster;
import Point3D;
import Vertex3D;
import Vector3D;
import java.math.*;

/**
 * A 4x4 homogeneous transformation matrix of floats.
 *
 * <p>Elements are stored as {@code matrix[row][column]} and the matrix is applied to
 * column vectors, i.e. a point p is transformed as {@code M * p}. The methods that build
 * transforms ({@link #translate}, {@link #scale}, {@link #skew}, {@link #rotate},
 * {@link #lookAt}, {@link #perspective}, {@link #orthographic}, {@link #compose})
 * post-multiply: {@code this = this * T}.
 */
public class Matrix3D {

	/** The elements, indexed {@code matrix[row][column]}. */
	float matrix[][];

	// Constructors

	/** Creates an identity transform. */
	public Matrix3D() {
		matrix = new float[4][4];
		loadIdentity();
	}

	/**
	 * Creates an independent copy of another matrix (no storage is shared).
	 *
	 * @param copy the matrix whose elements are copied
	 */
	public Matrix3D(Matrix3D copy) {
		matrix = new float[4][4];
		for (int row = 0; row < 4; row++) {
			System.arraycopy(copy.matrix[row], 0, matrix[row], 0, 4);
		}
	}

	/**
	 * Creates the mapping from canonical space to the screen space of a raster
	 * (z is mapped to [0, 1]); see {@link #loadScreenSpace(Raster)}.
	 *
	 * @param r the target raster
	 */
	public Matrix3D(Raster r) {
		matrix = new float[4][4];
		loadScreenSpace(r);
	}

	// General interface methods

	/**
	 * Sets one element.
	 *
	 * @param i     row index (0-3)
	 * @param j     column index (0-3)
	 * @param value new value of element (i, j)
	 */
	public void set(int i, int j, float value) {
		matrix[i][j] = value;
	}

	/**
	 * Returns one element.
	 *
	 * @param i row index (0-3)
	 * @param j column index (0-3)
	 * @return element (i, j)
	 */
	public float get(int i, int j) {
		return matrix[i][j];
	}

	/**
	 * Transforms a range of points: for k in [start, start + length),
	 * {@code out[k] = this * in[k]}.
	 *
	 * <p>Each point is treated as (x, y, z, 1) and the result is divided by the
	 * resulting homogeneous w (there is no guard against w == 0). {@code in} and
	 * {@code out} may be the same array: each point's coordinates are read before
	 * any of its outputs are written.
	 *
	 * @param in     source points
	 * @param out    destination points; must already be allocated, their x/y/z are overwritten
	 * @param start  index of the first point to transform
	 * @param length number of points to transform
	 */
	public void transform(Point3D in[], Point3D out[], int start, int length) {
		int end = start + length;
		for (int k = start; k < end; k++) {
			// Copy inputs first so an in-place call (in[k] == out[k]) is not corrupted.
			float x = in[k].x;
			float y = in[k].y;
			float z = in[k].z;
			float w = matrix[3][0] * x + matrix[3][1] * y + matrix[3][2] * z + matrix[3][3];
			out[k].x = (matrix[0][0] * x + matrix[0][1] * y + matrix[0][2] * z + matrix[0][3]) / w;
			out[k].y = (matrix[1][0] * x + matrix[1][1] * y + matrix[1][2] * z + matrix[1][3]) / w;
			out[k].z = (matrix[2][0] * x + matrix[2][1] * y + matrix[2][2] * z + matrix[2][3]) / w;
		}
	}

	/**
	 * Transforms the first {@code length} points; equivalent to
	 * {@code transform(in, out, 0, length)}.
	 *
	 * @param in     source points
	 * @param out    destination points; must already be allocated
	 * @param length number of points to transform
	 */
	public void transform(Point3D in[], Point3D out[], int length) {
		transform(in, out, 0, length);
	}

	/**
	 * Transforms the first {@code length} vertices, {@code out[i] = this * in[i]},
	 * using each vertex's own homogeneous w, and transforms normals when present.
	 *
	 * <p>After the position is written, {@code out[i].normalize()} is called
	 * (NOTE(review): presumably the homogeneous divide by w - confirm in Vertex3D).
	 * If {@code out[i].hasNormal}, the input normal is multiplied by the upper-left
	 * 3x3 block and handed to {@code setNormal}. That is exact only when the 3x3 block
	 * is a rotation, optionally with uniform scale; non-uniform scale or skew would
	 * need the inverse transpose instead.
	 *
	 * <p>{@code in} and {@code out} may be the same array: every input component is
	 * read into a local before the corresponding output fields are written.
	 *
	 * @param in     source vertices
	 * @param out    destination vertices; must already be allocated
	 * @param length number of vertices to transform
	 */
	public void transform(Vertex3D in[], Vertex3D out[], int length) {
		for (int i = 0; i < length; i++) {
			Vertex3D src = in[i];
			Vertex3D dst = out[i];

			// Read all inputs before writing so src == dst works. The casts are no-ops
			// if Vertex3D's fields are float (their declared type is not visible here).
			float x = (float) src.x;
			float y = (float) src.y;
			float z = (float) src.z;
			float w = (float) src.w;

			dst.x = matrix[0][0] * x + matrix[0][1] * y + matrix[0][2] * z + matrix[0][3] * w;
			dst.y = matrix[1][0] * x + matrix[1][1] * y + matrix[1][2] * z + matrix[1][3] * w;
			dst.z = matrix[2][0] * x + matrix[2][1] * y + matrix[2][2] * z + matrix[2][3] * w;
			dst.w = matrix[3][0] * x + matrix[3][1] * y + matrix[3][2] * z + matrix[3][3] * w;
			dst.normalize();

			// Transform the normal with the upper 3x3 block, then let setNormal store it.
			if (dst.hasNormal) {
				float nx = (float) src.nx;
				float ny = (float) src.ny;
				float nz = (float) src.nz;
				dst.nx = matrix[0][0] * nx + matrix[0][1] * ny + matrix[0][2] * nz;
				dst.ny = matrix[1][0] * nx + matrix[1][1] * ny + matrix[1][2] * nz;
				dst.nz = matrix[2][0] * nx + matrix[2][1] * ny + matrix[2][2] * nz;
				dst.setNormal(dst.nx, dst.ny, dst.nz);
			}
		}
	}

	/**
	 * Applies the upper 3x4 part of the matrix (upper-left 3x3 plus translation) to a
	 * vector treated as the point (x, y, z, 1). The bottom row is ignored, so no
	 * perspective divide is performed.
	 *
	 * @param in the vector to transform (not modified)
	 * @return a new transformed vector
	 */
	public Vector3D transformall(Vector3D in) {
		Vector3D v = new Vector3D();
		v.x = matrix[0][0] * in.x + matrix[0][1] * in.y + matrix[0][2] * in.z + matrix[0][3];
		v.y = matrix[1][0] * in.x + matrix[1][1] * in.y + matrix[1][2] * in.z + matrix[1][3];
		v.z = matrix[2][0] * in.x + matrix[2][1] * in.y + matrix[2][2] * in.z + matrix[2][3];
		return v;
	}

	/**
	 * Applies only the upper-left 3x3 block (no translation, no divide), as for a
	 * direction vector.
	 *
	 * @param in the vector to transform (not modified)
	 * @return a new transformed vector
	 */
	public Vector3D transform(Vector3D in) {
		Vector3D v = new Vector3D();
		v.x = matrix[0][0] * in.x + matrix[0][1] * in.y + matrix[0][2] * in.z;
		v.y = matrix[1][0] * in.x + matrix[1][1] * in.y + matrix[1][2] * in.z;
		v.z = matrix[2][0] * in.x + matrix[2][1] * in.y + matrix[2][2] * in.z;
		return v;
	}

	/**
	 * Post-multiplies by another matrix: {@code this = this * src}.
	 *
	 * @param src the right-hand factor (not modified)
	 */
	public final void compose(Matrix3D src) {
		matrixMultiply4x4(matrix, src.matrix);
	}

	/** Resets this matrix to the identity. */
	public void loadIdentity() {
		loadIdentity(matrix, 4);
	}

	/**
	 * Writes the 4x4 identity into {@code m}.
	 *
	 * @param m a matrix of at least 4x4
	 */
	public void loadIdentity(float m[][]) {
		loadIdentity(m, 4);
	}

	/**
	 * Writes the size x size identity into the top-left corner of {@code m}.
	 *
	 * @param m    a matrix of at least size x size
	 * @param size number of rows/columns to fill
	 */
	public void loadIdentity(float m[][], int size) {
		for (int i = 0; i < size; i++) {
			for (int j = 0; j < size; j++) {
				m[i][j] = (i == j) ? 1 : 0;
			}
		}
	}

	/**
	 * Replaces this matrix (with a freshly allocated array) by the viewport mapping
	 * from canonical space to the screen space of {@code r}:
	 * x from [-1, 1] to [0, width], y from [-1, 1] to [height, 0] (inverted),
	 * and z from [-1, 1] to [0, 1].
	 *
	 * @param r the target raster
	 */
	public void loadScreenSpace(Raster r) {
		// Divide by 2f: with r.width/2 an odd integer size was truncated
		// (width 5 gave a scale of 2 instead of 2.5).
		float halfWidth = r.width / 2f;
		float halfHeight = r.height / 2f;
		matrix = new float[4][4];
		matrix[0][0] = halfWidth;
		matrix[0][3] = halfWidth;
		matrix[1][1] = -halfHeight; // invert y coordinate for screen
		matrix[1][3] = halfHeight;
		matrix[2][2] = 0.5f;        // scale and shift z so -1..1 maps to 0..1
		matrix[2][3] = 0.5f;
		matrix[3][3] = 1;
	}

	/**
	 * Post-multiplies by a translation: {@code this = this * T(tx, ty, tz)}.
	 */
	public void translate(float tx, float ty, float tz) {
		float t[][] = new float[4][4];
		loadIdentity(t);
		t[0][3] = tx;
		t[1][3] = ty;
		t[2][3] = tz;
		matrixMultiply4x4(matrix, t);
	}

	/**
	 * Post-multiplies by an axis-aligned scale: {@code this = this * S(sx, sy, sz)}.
	 */
	public void scale(float sx, float sy, float sz) {
		float s[][] = new float[4][4];
		s[0][0] = sx;
		s[1][1] = sy;
		s[2][2] = sz;
		s[3][3] = 1;
		matrixMultiply4x4(matrix, s);
	}

	/**
	 * Post-multiplies by a shear: x' = x + kxy*y + kxz*z, y' = y + kyz*z, z' = z.
	 */
	public void skew(float kxy, float kxz, float kyz) {
		float k[][] = new float[4][4];
		loadIdentity(k);
		k[0][1] = kxy;
		k[0][2] = kxz;
		k[1][2] = kyz;
		matrixMultiply4x4(matrix, k);
	}

	/**
	 * Post-multiplies by a rotation of {@code angle} radians about the axis
	 * (ax, ay, az) through the origin, using Rodrigues' formula
	 * {@code R = cos(a) I + (1 - cos(a)) v v^T + sin(a) [v]x}.
	 *
	 * <p>The axis is normalized internally. A zero axis is left as zero by
	 * {@link #normalize(float[])}, so the upper 3x3 becomes {@code cos(angle) * I}
	 * rather than a rotation; callers should pass a non-zero axis.
	 */
	public void rotate(float ax, float ay, float az, float angle) {
		float cos = (float) Math.cos(angle);
		float sin = (float) Math.sin(angle);

		float v[] = {ax, ay, az};
		normalize(v);

		float symmetric[][] = new float[3][3]; // v v^T
		float skewPart[][] = new float[3][3];  // cross-product matrix [v]x
		float identity[][] = new float[3][3];
		symmetricMatrix(symmetric, v);
		skewMatrix(skewPart, v);
		loadIdentity(identity, 3);

		scaleMatrix_mxn(symmetric, 3, 3, 1 - cos);
		scaleMatrix_mxn(skewPart, 3, 3, sin);
		scaleMatrix_mxn(identity, 3, 3, cos);
		addMatrix_mxn(symmetric, skewPart, 3, 3);
		addMatrix_mxn(symmetric, identity, 3, 3);

		// Embed the 3x3 rotation in a 4x4 homogeneous matrix.
		float r[][] = new float[4][4];
		for (int i = 0; i < 3; i++) {
			System.arraycopy(symmetric[i], 0, r[i], 0, 3);
		}
		r[3][3] = 1;
		matrixMultiply4x4(matrix, r);
	}

	/**
	 * Post-multiplies by a viewing transform that moves the eye to the origin and
	 * looks down -z.
	 *
	 * <p>With l = at - eye, the basis is r = l x up and u = r x l (all normalized).
	 * The rows of the new matrix are r, u and -l, with translations -r.eye, -u.eye
	 * and l.eye, so the eye maps to the origin and the view direction to -z.
	 * If {@code up} is parallel to the view direction, r is zero and the result is
	 * degenerate.
	 */
	public void lookAt(float eyex, float eyey, float eyez,
					   float atx,  float aty,  float atz,
					   float upx,  float upy,  float upz) {
		float eye[] = {eyex, eyey, eyez};
		float up[] = {upx, upy, upz};

		float l[] = {atx, aty, atz};
		subtractVector(l, eye);        // look direction: at - eye

		float r[] = {l[0], l[1], l[2]};
		cross(r, up);                  // right: l x up

		float u[] = {r[0], r[1], r[2]};
		cross(u, l);                   // true up: r x l

		normalize(l);
		normalize(r);
		normalize(u);

		float view[][] = new float[4][4];

		view[0][0] = r[0];
		view[0][1] = r[1];
		view[0][2] = r[2];
		scaleVector(r, -1);
		view[0][3] = dot(r, eye);      // -r . eye

		view[1][0] = u[0];
		view[1][1] = u[1];
		view[1][2] = u[2];
		scaleVector(u, -1);
		view[1][3] = dot(u, eye);      // -u . eye

		view[2][3] = dot(l, eye);      // l . eye == -(-l) . eye
		scaleVector(l, -1);
		view[2][0] = l[0];
		view[2][1] = l[1];
		view[2][2] = l[2];

		view[3][3] = 1;

		matrixMultiply4x4(matrix, view);
	}

	// Projection transformations: map points into the canonical viewing space.

	/**
	 * Post-multiplies by a perspective projection looking down -z.
	 *
	 * <p>{@code near} and {@code far} are the z coordinates of the clipping planes
	 * on the -z axis, so near &gt; far (e.g. -1 and -100). left/right/bottom/top are the
	 * window extents at z = near. After the divide by w (= -z), z = near maps to -1,
	 * z = far maps to +1, and the window maps to [-1, 1] in x and y.
	 */
	public void perspective(float left, float right,
							float bottom, float top,
							float near, float far) {
		float p[][] = new float[4][4];
		p[0][0] = -(2 * near) / (right - left);
		p[0][2] = (right + left) / (right - left);
		p[1][1] = -(2 * near) / (top - bottom);
		p[1][2] = (top + bottom) / (top - bottom);
		p[2][2] = (far + near) / (near - far);
		p[2][3] = -(2 * far * near) / (near - far);
		p[3][2] = -1;
		matrixMultiply4x4(matrix, p);
	}

	/**
	 * Post-multiplies by an orthographic projection: a scale and translation mapping
	 * x in [left, right], y in [bottom, top] and z in [near, far] to [-1, 1]
	 * (z = near maps to -1, z = far to +1). As with {@link #perspective}, near and
	 * far are z coordinates, so near &gt; far when looking down -z.
	 */
	public void orthographic(float left, float right,
							 float bottom, float top,
							 float near, float far) {
		float o[][] = new float[4][4];
		o[0][0] = 2 / (right - left);
		o[0][3] = -(right + left) / (right - left);
		o[1][1] = 2 / (top - bottom);
		o[1][3] = -(top + bottom) / (top - bottom);
		o[2][2] = 2 / (far - near);
		o[2][3] = -(far + near) / (far - near);
		o[3][3] = 1;
		matrixMultiply4x4(matrix, o);
	}

	/**
	 * Multiplies two 4x4 matrices and stores the result in the first: {@code a = a * b}.
	 * The product is built in a temporary first, so {@code a} and {@code b} may be the
	 * same array.
	 */
	public void matrixMultiply4x4(float a[][], float b[][]) {
		float product[][] = new float[4][4];
		for (int i = 0; i < 4; i++) {
			for (int j = 0; j < 4; j++) {
				product[i][j] = a[i][0] * b[0][j] + a[i][1] * b[1][j]
							  + a[i][2] * b[2][j] + a[i][3] * b[3][j];
			}
		}
		for (int i = 0; i < 4; i++) {
			System.arraycopy(product[i], 0, a[i], 0, 4);
		}
	}

	/** Adds two m x n matrices element-wise into the first: {@code a = a + b}. */
	public void addMatrix_mxn(float a[][], float b[][], int m, int n) {
		for (int i = 0; i < m; i++) {
			for (int j = 0; j < n; j++) {
				a[i][j] = a[i][j] + b[i][j];
			}
		}
	}

	/** Subtracts two m x n matrices element-wise into the first: {@code a = a - b}. */
	public void subtractMatrix_mxn(float a[][], float b[][], int m, int n) {
		for (int i = 0; i < m; i++) {
			for (int j = 0; j < n; j++) {
				a[i][j] = a[i][j] - b[i][j];
			}
		}
	}

	/** Writes the 3x3 outer product {@code v v^T} of v = [x y z] into {@code s}. */
	public void symmetricMatrix(float s[][], float v[]) {
		for (int i = 0; i < 3; i++) {
			for (int j = 0; j < 3; j++) {
				s[i][j] = v[i] * v[j];
			}
		}
	}

	/**
	 * Writes the off-diagonal entries of the cross-product matrix of v = [x y z]
	 * into {@code s}, so that {@code s * w == v x w}. The diagonal is NOT written;
	 * callers pass a zero-initialized matrix.
	 */
	public void skewMatrix(float s[][], float v[]) {
		s[0][1] = -v[2];
		s[0][2] = v[1];
		s[1][0] = v[2];
		s[1][2] = -v[0];
		s[2][0] = -v[1];
		s[2][1] = v[0];
	}

	/**
	 * Multiplies every element of an m x n matrix by a constant, in place.
	 * (Previously the column loop was bounded by m, which was wrong for non-square sizes.)
	 */
	public void scaleMatrix_mxn(float m[][], int rows, int cols, float scale) {
		for (int i = 0; i < rows; i++) {
			for (int j = 0; j < cols; j++) {
				m[i][j] = m[i][j] * scale;
			}
		}
	}

	/** Scales {@code v} to unit length in place; a zero vector is left unchanged. */
	public void normalize(float v[]) {
		float sumSq = 0;
		for (int i = 0; i < v.length; i++) {
			sumSq += v[i] * v[i];
		}
		float length = (float) Math.sqrt(sumSq);
		if (length != 0) {
			for (int i = 0; i < v.length; i++) {
				v[i] = v[i] / length;
			}
		}
	}

	/** Returns the dot product of a and b over the first {@code a.length} components. */
	public float dot(float a[], float b[]) {
		float sum = 0;
		for (int i = 0; i < a.length; i++) {
			sum += a[i] * b[i];
		}
		return sum;
	}

	/** Replaces the 3-vector {@code a} with the cross product {@code a x b}. */
	public void cross(float a[], float b[]) {
		float cx = a[1] * b[2] - a[2] * b[1];
		float cy = a[2] * b[0] - a[0] * b[2];
		float cz = a[0] * b[1] - a[1] * b[0];
		a[0] = cx;
		a[1] = cy;
		a[2] = cz;
	}

	/** Multiplies every component of {@code a} by {@code scale}, in place. */
	public void scaleVector(float a[], float scale) {
		for (int i = 0; i < a.length; i++) {
			a[i] = a[i] * scale;
		}
	}

	/** Subtracts b from a component-wise, in place: {@code a = a - b}. */
	public void subtractVector(float a[], float b[]) {
		for (int i = 0; i < a.length; i++) {
			a[i] = a[i] - b[i];
		}
	}
}