import java.awt.Color;
import java.util.*;
import Ray;
import Renderable;
import Surface;
import Vector3D;


/**
 * Cylinder.java
 *   A renderable, capped (solid) cylinder designed for ray tracers.
 *   Coded for MIT 6.837 Project 5, Fall 98
 *   Implementation by Jonathan Lie  12/1998
 *
 * The cylinder is described by the midpoint of its axis (center), a unit
 * vector along its length (axis), its total length (height) and its radius.
 * It occupies every point whose coordinate along the axis, measured from
 * center, lies within +/- height/2 and whose distance from the axis is at
 * most radius.  The two flat end caps are part of the surface.
 */
class Cylinder implements Renderable
{
  /*
   * A ray whose squared direction component perpendicular to the axis is at
   * most this fraction of its squared length is treated as exactly parallel
   * to the axis.  This avoids dividing by a vanishing quadratic coefficient.
   */
  private static final float PARALLEL_EPS = 1e-12f;

  Surface surface;  // surface shader handed the geometry in Shade()
  Vector3D center;  // midpoint of the axis
  Vector3D axis;    // unit direction parallel to the length of the object
  float height;     // full length along the axis (caps sit at +/- height/2)
  float radius;
  float radSqr;     // radius*radius, cached for intersect()

  /**
   * Builds a cylinder.
   *
   * @param s surface shader for this object
   * @param c midpoint of the cylinder's axis
   * @param a direction of the axis.  A normalized copy is stored, so the
   *          caller's vector is not modified.
   * @param h full length of the cylinder along its axis
   * @param r radius of the cylinder
   */
  public Cylinder(Surface s, Vector3D c, Vector3D a, float h, float r) {
    surface = s;
    center = c;
    // Copy before normalizing so shared scene data passed in by the caller
    // is not changed behind its back.
    axis = new Vector3D(a);
    axis.normalize();
    height = h;
    radius = r;
    radSqr = r*r;
  }

  /** Creates a cylinder with all fields left null/zero; set them before use. */
  public Cylinder() {}

  /**
   * Intersects a ray with this capped cylinder.
   *
   * The ray is clipped against two regions, and the results are combined:
   *   1. the slab between the end-cap planes, |axial coordinate| <= height/2;
   *   2. the infinite cylinder, distance from the axis <= radius.
   * The ray is inside the solid for t in [tNear, tFar], the overlap of the
   * two parameter intervals.  The reported hit is the entry point tNear when
   * it lies ahead of the origin.  Otherwise (the origin is inside the solid)
   * it is the exit point tFar.
   *
   * @param ray ray to test.  t is measured in multiples of ray.direction,
   *            which does not need to be normalized.
   * @return true if a hit with 0 < t < ray.t was found, in which case ray.t
   *         and ray.object are updated; false otherwise, with ray untouched
   */
  public boolean intersect(Ray ray) {
    Vector3D d = ray.direction;

    // Ray origin relative to the cylinder's center.
    float ox = ray.origin.x - center.x;
    float oy = ray.origin.y - center.y;
    float oz = ray.origin.z - center.z;

    // Split origin and direction into axial and perpendicular parts.
    float a0 = axis.dot(ox, oy, oz);   // axial coordinate of the origin
    float da = axis.dot(d);            // axial speed of the ray
    float wx = ox - a0*axis.x;         // origin's offset from the axis
    float wy = oy - a0*axis.y;
    float wz = oz - a0*axis.z;
    float px = d.x - da*axis.x;        // direction's perpendicular part
    float py = d.y - da*axis.y;
    float pz = d.z - da*axis.z;

    float halfHeight = height/2;
    float tNear = Float.NEGATIVE_INFINITY;
    float tFar = Float.POSITIVE_INFINITY;

    // 1. Slab between the end caps.
    if (da == 0) {
      // The ray runs parallel to the cap planes.  It is either always
      // between them or never.  This check also avoids the 0/0 NaN below.
      if (Math.abs(a0) >= halfHeight)
        return false;
    } else {
      float tBottom = (-halfHeight - a0) / da;
      float tTop = (halfHeight - a0) / da;
      tNear = Math.min(tBottom, tTop);
      tFar = Math.max(tBottom, tTop);
    }

    // 2. Infinite cylinder: solve |w + t*p|^2 = radius^2, i.e.
    //    qa*t^2 + 2*qb*t + qc = 0.
    float qa = px*px + py*py + pz*pz;
    float qb = wx*px + wy*py + wz*pz;          // half the linear coefficient
    float qc = wx*wx + wy*wy + wz*wz - radSqr;
    if (qa <= PARALLEL_EPS * d.dot(d)) {
      // The ray is (numerically) parallel to the axis.  It is either always
      // within the radius or never.
      if (qc >= 0)
        return false;
    } else {
      float disc = qb*qb - qa*qc;
      if (disc <= 0)
        return false;  // misses the infinite cylinder, or only grazes it
      float root = (float) Math.sqrt(disc);
      tNear = Math.max(tNear, (-qb - root) / qa);
      tFar = Math.min(tFar, (-qb + root) / qa);
    }

    if (tNear >= tFar)
      return false;    // the two intervals do not overlap

    float t = (tNear > 0) ? tNear : tFar;
    if (t <= 0 || t >= ray.t)
      return false;    // behind the origin, or not closer than a known hit

    ray.t = t;
    ray.object = this;
    return true;
  }

  /**
   * Shades the hit point stored in ray.  An object shader does not do much
   * beyond supplying the geometry a surface shader needs:
   *
   *   1. the point of intersection (p)
   *   2. a unit-length surface normal (n), pointing out of the solid
   *   3. the reversed ray direction (v).  This is unit length only when
   *      ray.direction is.  NOTE(review): presumably Ray normalizes its
   *      direction; confirm.
   *
   * It then delegates to surface.Shade.
   */
  public Color Shade(Ray ray, Vector lights, Vector objects, Color bgnd, int depth) {
    float px = ray.origin.x + ray.t*ray.direction.x;
    float py = ray.origin.y + ray.t*ray.direction.y;
    float pz = ray.origin.z + ray.t*ray.direction.z;

    Vector3D p = new Vector3D(px, py, pz);
    Vector3D v = new Vector3D(-ray.direction.x, -ray.direction.y, -ray.direction.z);
    Vector3D n;

    // Axial coordinate of p, and p's perpendicular offset from the axis.
    float p_proj = axis.dot(px-center.x, py-center.y, pz-center.z);
    float rx = px - (center.x + p_proj*axis.x);
    float ry = py - (center.y + p_proj*axis.y);
    float rz = pz - (center.z + p_proj*axis.z);
    float radial = (float) Math.sqrt(rx*rx + ry*ry + rz*rz);

    // Use the normal of whichever surface p lies closest to.  This is
    // scale-independent, unlike a fixed absolute tolerance.
    float capGap = Math.abs(height/2 - Math.abs(p_proj));
    float sideGap = Math.abs(radius - radial);
    if (capGap <= sideGap || radial == 0) {
      if (p_proj > 0)
        n = new Vector3D(axis);
      else
        n = new Vector3D(-axis.x, -axis.y, -axis.z);
    } else {
      // Divide by the measured radial distance so n is exactly unit length.
      n = new Vector3D(rx / radial, ry / radial, rz / radial);
    }

    return surface.Shade(p, n, v, lights, objects, bgnd, depth);
  }

  public String toString() {
    return ("cylinder "+center+", "+axis+", "+height+", "+radius);
  }

}
