import java.awt.Color;
import java.util.Vector;

/**
 * Open-ended cylinder primitive: only the lateral surface is intersected and
 * shaded; there are no end caps.
 *
 * <p>The axis runs from {@code base} to {@code top}. Intersection is performed in an
 * object space built by {@code Matrix3D.rotate_to_z_axis}. That transform presumably
 * moves {@code center} to the origin and aligns the axis with z (NOTE(review): confirm
 * in Matrix3D). In that space the surface is {@code x^2 + y^2 = radius^2}, clipped to
 * {@code -mag <= z <= mag}.
 */
public class Cylinder implements Renderable {
    /** Surface shader that applies the illumination model at hit points. */
    Surface surface;
    /** One end point of the axis (world space). */
    Vector3D top;
    /** The other end point of the axis (world space). */
    Vector3D base;
    /** Midpoint of {@code base} and {@code top} (world space). */
    Vector3D center;
    /** Cylinder radius. */
    float radius;
    /** {@code radius * radius}, cached for the intersection test. */
    float radSqr;
    /** Transform applied to ray points before the object-space intersection test. */
    Matrix3D m3d;
    /** Half the axis length (distance from {@code center} to {@code top}). */
    float mag;

    /**
     * Builds a cylinder between two axis end points.
     *
     * @param surf     surface shader for this object
     * @param base_var one end of the axis
     * @param top_var  the other end of the axis
     * @param rad      radius
     * @throws IllegalArgumentException if {@code base_var} and {@code top_var} coincide
     *         (zero-length axis), which would make the object-space transform NaN
     */
    public Cylinder(Surface surf, Vector3D base_var, Vector3D top_var, float rad) {
        surface = surf;
        top     = top_var;
        base    = base_var;
        radius  = rad;
        radSqr  = radius*radius;

        center = new Vector3D((top.x - base.x)/2 + base.x,
                              (top.y - base.y)/2 + base.y,
                              (top.z - base.z)/2 + base.z);

        // Half-axis from the center to the top; its length bounds the object-space z range.
        Vector3D v = new Vector3D(top.x - center.x,
                                  top.y - center.y,
                                  top.z - center.z);
        mag = v.magnitude();
        // A zero-length (or NaN) axis would turn every direction cosine below into NaN
        // and poison m3d, producing NaN "hits" later; reject it up front.
        if (!(mag > 0f)) {
            throw new IllegalArgumentException(
                "cylinder top and base must differ: top=" + top + " base=" + base);
        }

        // Direction cosines of the axis.
        float a = v.x/mag;
        float b = v.y/mag;
        float c = v.z/mag;
        // NOTE(review): d is 0 when the axis is parallel to x; confirm that
        // Matrix3D.rotate_to_z_axis handles that case.
        float d = (float) Math.sqrt((double) (b*b + c*c));

        m3d = new Matrix3D();
        m3d.rotate_to_z_axis(mag, a, b, c, d, center.x, center.y, center.z);
    }

    /**
     * Intersects {@code ray} with the cylinder's lateral surface.
     *
     * <p>In object space the ray is {@code O + D*t}. Substituting it into
     * {@code x^2 + y^2 - r^2 = 0} gives {@code a*t^2 + b*t + c = 0}, where
     * {@code a = Dx^2 + Dy^2}, {@code b = 2(Ox*Dx + Oy*Dy)} and
     * {@code c = Ox^2 + Oy^2 - r^2}.
     *
     * <p>The nearer root is tried first. If it is unusable (behind the origin, beyond
     * {@code ray.t}, or outside the {@code |z| <= mag} wall span), the farther root is
     * tried. That case covers origins inside the tube and rays entering through an open
     * end.
     *
     * <p>NOTE(review): comparing the object-space t against world-space {@code ray.t}
     * assumes m3d is rigid (no scaling) and {@code ray.direction} is unit length.
     *
     * @param ray the ray to test; {@code ray.t} is the current closest hit distance
     * @return true if a closer hit was found, in which case {@code ray.t} and
     *         {@code ray.object} are updated
     */
    public boolean intersect(Ray ray) {
        // Map two points on the ray into object space and rebuild the direction there.
        Vertex3D origin = m3d.transform(new Vertex3D(ray.origin.x,
                                                     ray.origin.y,
                                                     ray.origin.z));
        Vertex3D ahead = m3d.transform(new Vertex3D(ray.origin.x + ray.direction.x,
                                                    ray.origin.y + ray.direction.y,
                                                    ray.origin.z + ray.direction.z));
        Vertex3D dir = new Vertex3D(ahead.x - origin.x,
                                    ahead.y - origin.y,
                                    ahead.z - origin.z);
        dir.normalize();

        float a = dir.x*dir.x + dir.y*dir.y;
        // a == 0 means the ray is parallel to the axis, so it never crosses the wall.
        // Without this check t becomes NaN and slips past the range tests. The negated
        // form also rejects a NaN a.
        if (!(a > 0f)) {
            return false;
        }
        float b = 2*(origin.x*dir.x + origin.y*dir.y);
        float c = origin.x*origin.x + origin.y*origin.y - radSqr;

        float disc = b*b - 4*a*c;
        if (disc < 0) {
            return false;
        }
        float sqdisc = (float) Math.sqrt((double) disc);

        float nearT = (-b - sqdisc)/(2*a);
        float farT  = (-b + sqdisc)/(2*a);

        float t;
        if (isWallHit(ray, origin, dir, nearT)) {
            t = nearT;
        } else if (isWallHit(ray, origin, dir, farT)) {
            t = farT;
        } else {
            return false;
        }

        ray.t = t;
        ray.object = this;
        return true;
    }

    /**
     * Checks whether parameter {@code t} is an acceptable hit on the finite wall.
     * It must satisfy {@code 0 <= t <= ray.t}, and the object-space z of the hit
     * point must lie within {@code [-mag, mag]}. NaN values fail every comparison
     * and are therefore rejected.
     */
    private boolean isWallHit(Ray ray, Vertex3D origin, Vertex3D dir, float t) {
        if (!(t >= 0f && t <= ray.t)) {
            return false;
        }
        float pz = origin.z + t*dir.z;
        return pz >= -mag && pz <= mag;
    }

    /**
     * Supplies the geometric inputs for the surface shader and delegates to it.
     * The inputs are:
     * <ol>
     *   <li>the intersection point {@code p};</li>
     *   <li>a unit surface normal {@code n}, oriented toward the viewer so that
     *       inner-wall hits seen through an open end are lit sensibly;</li>
     *   <li>the vector {@code v} back toward the ray origin (unit length when
     *       {@code ray.direction} is).</li>
     * </ol>
     *
     * @return the colour computed by {@code surface.Shade}
     */
    public Color Shade(Ray ray, Vector lights, Vector objects, Color bgnd) {
        // Hit point from the ray parameter recorded by intersect().
        float px = ray.origin.x + ray.t*ray.direction.x;
        float py = ray.origin.y + ray.t*ray.direction.y;
        float pz = ray.origin.z + ray.t*ray.direction.z;
        Vector3D p = new Vector3D(px, py, pz);

        Vector3D v = new Vector3D(-ray.direction.x, -ray.direction.y, -ray.direction.z);

        // Project the hit point onto the axis; the radial offset from that foot
        // point is the (outward) surface normal.
        Vector3D axis = new Vector3D(base.x - top.x,
                                     base.y - top.y,
                                     base.z - top.z);
        axis.normalize();
        Vector3D toHit = new Vector3D(px - top.x, py - top.y, pz - top.z);
        float along = toHit.dot(axis);

        Vector3D n = new Vector3D(px - (top.x + axis.x*along),
                                  py - (top.y + axis.y*along),
                                  pz - (top.z + axis.z*along));
        n.normalize();

        // The outward normal already faces the viewer on outer-wall hits. On inner-wall
        // hits it points away from the viewer, so flip it.
        if (n.dot(v) < 0f) {
            n.x = -n.x;
            n.y = -n.y;
            n.z = -n.z;
        }

        // The illumination model is applied by the surface's Shade() method.
        // NOTE(review): meaning of the final `false` argument is defined by Surface.Shade.
        return surface.Shade(p, n, v, lights, objects, bgnd, false);
    }

    @Override
    public String toString() {
        return ("cylinder top:"+top+" base:"+base+" radius:"+radius);
    }
}