import java.awt.Color;
import java.util.Vector;
import java.util.Enumeration;
import Ray;

/**
 * Material of a ray-traced object: Phong illumination (ambient, diffuse,
 * specular) with shadow rays, plus recursive reflection and refraction.
 */
public class Surface {
    public float ir, ig, ib;        // surface's intrinsic color
    public float ka, kd, ks, ns;    // phong model: ambient, diffuse, specular weights and specular exponent
    public float kt, kr, nt;        // transmission/reflection weights (stored scaled by 1/255) and refraction ratio
    // Distance secondary rays are pushed along their direction so they start just off the surface.
    private static final float TINY = 0.001f;
    private static final float I255 = 0.00392156f;  // 1/255

    /**
     * Creates a surface description.
     *
     * @param rval  red component of the intrinsic color
     * @param gval  green component of the intrinsic color
     * @param bval  blue component of the intrinsic color
     * @param a     ambient coefficient
     * @param d     diffuse coefficient
     * @param s     specular coefficient
     * @param n     specular (Phong) exponent
     * @param r     reflection coefficient; stored multiplied by 1/255 because Shade()
     *              multiplies it by 0-255 integer channels from java.awt.Color
     * @param t     transmission coefficient; scaled by 1/255 for the same reason
     * @param index refraction ratio used by Shade() (see the note there)
     */
    public Surface(float rval, float gval, float bval, float a, float d, float s, float n, float r, float t, float index) {
        ir = rval; ig = gval; ib = bval;
        ka = a; kd = d; ks = s; ns = n;
        kr = r*I255; kt = t*I255; nt = index;
    }

    /**
     * Computes the color seen at a point on this surface.
     *
     * @param p       the point of intersection
     * @param n       a unit-length surface normal (not modified by this method)
     * @param v       a unit-length vector towards the ray's origin
     * @param lights  the scene's lights; every element must be a Light
     * @param objects the scene's objects, traced by shadow, reflected and refracted rays
     * @param bgnd    color used when a reflected or refracted ray hits nothing
     * @param inside  true when the ray travels inside the object: the normal is
     *                flipped and the reciprocal refraction ratio is used
     * @return the shaded color, each channel clamped to [0, 1]
     */
    public Color Shade(Vector3D p, Vector3D n, Vector3D v, Vector lights, Vector objects, Color bgnd,
                       boolean inside) {
        // Flip into a new vector so the caller's normal is left untouched.
        Vector3D normal = inside ? new Vector3D(-n.x, -n.y, -n.z) : n;

        float r = 0;
        float g = 0;
        float b = 0;

        Enumeration lightSources = lights.elements();
        while (lightSources.hasMoreElements()) {
            Light light = (Light) lightSources.nextElement();
            if (light.lightType == Light.AMBIENT) {
                r += ka*ir*light.ir;
                g += ka*ig*light.ig;
                b += ka*ib*light.ib;
                continue;
            }

            Vector3D l;
            if (light.lightType == Light.POINT) {
                // l is the unit direction from the point of intersection towards the light
                l = new Vector3D(light.lvec.x - p.x, light.lvec.y - p.y, light.lvec.z - p.z);
                l.normalize();
            } else { // DIRECTIONAL: negate lvec so l also points towards the light
                l = new Vector3D(-light.lvec.x, -light.lvec.y, -light.lvec.z);
            }

            // A blocked light removes only its own contribution; the remaining
            // lights must still be evaluated, hence 'continue' rather than 'break'.
            // NOTE(review): for point lights an object lying beyond the light also
            // counts as a blocker here; confirm whether Ray exposes the hit distance.
            Vector3D poffset = new Vector3D(p.x + TINY*l.x, p.y + TINY*l.y, p.z + TINY*l.z);
            Ray shadowRay = new Ray(poffset, l);
            if (shadowRay.trace(objects)) {
                continue;
            }

            float lambert = Vector3D.dot(normal, l);
            if (lambert > 0) {
                if (kd > 0) {
                    float diffuse = kd*lambert;
                    r += diffuse*ir*light.ir;
                    g += diffuse*ig*light.ig;
                    b += diffuse*ib*light.ib;
                }
                if (ks > 0) {
                    // Mirror l about the normal (2(n.l)n - l) and compare with v.
                    float twoLambert = 2*lambert;
                    float spec = v.dot(twoLambert*normal.x - l.x, twoLambert*normal.y - l.y, twoLambert*normal.z - l.z);
                    if (spec > 0) {
                        spec = ks*((float) Math.pow((double) spec, (double) ns));
                        r += spec*light.ir;
                        g += spec*light.ig;
                        b += spec*light.ib;
                    }
                }
            }
        }

        // Reflection: reflect = 2(v.n)n - v
        if (kr > 0) {
            float t = v.dot(normal); // magnitude of v along the normal
            if (t > 0) {
                t *= 2;
                Vector3D reflect = new Vector3D(t*normal.x - v.x, t*normal.y - v.y, t*normal.z - v.z);
                Vector3D poffset = new Vector3D(p.x + TINY*reflect.x, p.y + TINY*reflect.y, p.z + TINY*reflect.z);
                Ray reflectedRay = new Ray(poffset, reflect);
                if (reflectedRay.trace(objects)) {
                    if (inside) {
                        // Internal reflections are counted in the shared Ray.depth
                        // counter; once it reaches RayTrace.MAXDEPTH the counter is reset
                        // and black is returned, discarding this point's other terms.
                        // NOTE(review): presumably a guard against unbounded internal
                        // bouncing; confirm how Ray.depth is managed elsewhere.
                        Ray.depth++;
                        if (Ray.depth >= RayTrace.MAXDEPTH) {
                            Ray.depth = 0;
                            return new Color(0, 0, 0);
                        }
                    }
                    Color rcolor = reflectedRay.Shade(lights, objects, bgnd);
                    r += kr*rcolor.getRed();
                    g += kr*rcolor.getGreen();
                    b += kr*rcolor.getBlue();
                } else {
                    r += kr*bgnd.getRed();
                    g += kr*bgnd.getGreen();
                    b += kr*bgnd.getBlue();
                }
            }
        }

        // Refraction (Snell's law, vector form), with eta the refraction ratio:
        //   cos(Ai)    = v.n
        //   cos^2(Ar)  = 1 - eta^2 * (1 - cos^2(Ai))
        //   refract    = -eta*v - (cos(Ar) - eta*cos(Ai)) * n
        if (kt > 0) {
            // NOTE(review): nt is used as-is when entering and inverted when leaving,
            // i.e. it is treated as n_incident / n_transmitted for an entering ray;
            // confirm the scene input supplies nt in that convention.
            float eta = inside ? 1/nt : nt;
            float cosAi = v.dot(normal);
            if (cosAi > 0) {
                float cos2Ar = 1 - eta*eta*(1 - cosAi*cosAi);
                // A negative value means total internal reflection: no transmitted
                // ray exists (sqrt would yield NaN), so no refracted light is added.
                if (cos2Ar >= 0) {
                    float cosAr = (float) Math.sqrt((double) cos2Ar);
                    float k = cosAr - eta*cosAi;
                    Vector3D refract = new Vector3D(-eta*v.x - k*normal.x,
                                                    -eta*v.y - k*normal.y,
                                                    -eta*v.z - k*normal.z);
                    Vector3D poffset = new Vector3D(p.x + TINY*refract.x, p.y + TINY*refract.y, p.z + TINY*refract.z);
                    Ray refractedRay = new Ray(poffset, refract);
                    if (refractedRay.trace(objects)) {
                        Color rcolor = refractedRay.Shade(lights, objects, bgnd);
                        r += kt*rcolor.getRed();
                        g += kt*rcolor.getGreen();
                        b += kt*rcolor.getBlue();
                    } else {
                        r += kt*bgnd.getRed();
                        g += kt*bgnd.getGreen();
                        b += kt*bgnd.getBlue();
                    }
                }
            }
        }

        return new Color(clamp(r), clamp(g), clamp(b));
    }

    /** Clamps a channel into [0, 1], the range java.awt.Color's float constructor accepts. */
    private static float clamp(float c) {
        return Math.max(0f, Math.min(1f, c));
    }
}
