diff --git a/shaders/ground.frag b/shaders/ground.frag
index 590425aa1a091d4cadda29d398f7c9f742bcdad0..f093d0509e69c2fd8b2b6e83f5b6db43956b9395 100644
--- a/shaders/ground.frag
+++ b/shaders/ground.frag
@@ -17,6 +17,7 @@ const float tex_scale = 4.0;
 
 const float PI = 3.1415926535897f;
 const float ETA = 1.333;
+const float R0 = 0.0204;
 const float ATTENUATION = 0.145f;
 
 const vec2 CAUSTICS_OFFSETS[] = vec2[](
@@ -95,6 +96,12 @@ vec2 sphere_uv(vec3 direction)
     return vec2(u, v);
 }
 
+float fresnel(vec3 normal, vec3 view) {
+    float cos_theta = dot(normal, -view);
+    float x = 1.0 - cos_theta;
+    return mix(R0, 1.0, x * x * x * x * x);
+}
+
 float unlerp(float lo, float hi, float t)
 {
     return clamp((t - lo) / (hi - lo), 0.0, 1.0);
@@ -115,37 +122,35 @@ vec3 wave_normal_at(vec3 pos) {
 vec3 caustics(vec3 nnormal)
 {
     vec3 light = vec3(0);
-    int count = 1;
+    //int count = 1;
     for (int i = 0; i < CAUSTICS_OFFSETS.length(); ++i) {
         vec2 cur_offset = CAUSTICS_OFFSETS[i];
         vec3 offset_pos = position + vec3(cur_offset.x, 0, cur_offset.y);
 
         vec3 wave_pos = vec3(offset_pos.x, depth_at(offset_pos), offset_pos.z);
-        vec3 wave_normal = wave_normal_at(offset_pos);
+        vec3 wave_normal = normalize(wave_normal_at(offset_pos));
 
         vec3 ray = wave_pos - position;
         float ray_length = length(ray);
         vec3 incident = ray / ray_length;
         
-        vec3 refracted = refract(incident, wave_normal, ETA);
+        vec3 refracted = refract(incident, -wave_normal, ETA);
         float shade = exp(-ray_length * ATTENUATION);
-        
+        float F = 1.0 - fresnel(wave_normal, -refracted); // Fraction of refracted light
+
         if (incident.y > 0 && refracted != vec3(0)) {
-            light += dot(nnormal, incident) * texture(skybox, sphere_uv(refracted)).rgb * shade;
-            //light += refracted;
-            count += 1;
+            light += dot(nnormal, incident) * texture(skybox, sphere_uv(refracted)).rgb * shade * F;
         } 
     }
 
     // ambient light
-    return 0.05 + 2.0 * light / float(CAUSTICS_OFFSETS.length());
+    return vec3(0.05) + light / float(CAUSTICS_OFFSETS.length());
 }
 
 
 void main(void)
 {
     vec3 nnormal = normalize(normal);
-    //float light = max(0.0, dot(nnormal, sun));
 
     vec3 albedo = mix(texture(dirt, tex_coord * tex_scale).rgb, texture(grass, tex_coord * tex_scale).rgb, unlerp(-0.2, 0.2, position.y));
 
@@ -158,6 +163,5 @@ void main(void)
     }
 
     out_Color = vec4(albedo * light, 1.0);
-    //out_Color = texture(surf_normal, tex_coord);
 }
 
diff --git a/shaders/surface.frag b/shaders/surface.frag
index cd97f4c09f277f7297e852a194f575c0a64c35e9..65d2b8d92720e3a08b2cc0be3ebc1556ae5ed55b 100644
--- a/shaders/surface.frag
+++ b/shaders/surface.frag
@@ -18,9 +18,12 @@ const float R0 = 0.0204;
 const float ATTENUATION = 0.145f;
 //const float ATTENUATION = 0.5f;
 
-const float MIN_T = 0.05;
-const float MAX_T = 10.0;
-const float DT = 0.1;
+const float IOR = 0.75;
+const float INV_IOR = 1.333;
+
+const float MIN_T = 0.02;
+const float MAX_T = 7.5;
+const float DT = 0.05;
 
 out vec4 out_Color;
 
@@ -35,7 +38,7 @@ vec2 sphere_uv(vec3 direction)
 
 float fresnel(vec3 normal, vec3 view) {
     float cos_theta = dot(normal, -view);
-    float x = 1.0 - cos_theta;
+    float x = clamp(1.0 - cos_theta, 0, 1);
     return mix(R0, 1.0, x * x * x * x * x);
 }
 
@@ -109,6 +112,9 @@ vec3 reflection_ray_second_bounce(vec3 pos, vec3 dir) {
                 vec3 sky_color = texture(sky, sphere_uv(reflected)).rgb;
                 vec3 refracted = refract(dir, nnormal, 0.75);
                 vec3 water_color = refraction_ray(world_pos, refracted);
+
+                return mix(water_color, sky_color, R);
+                //return map(dot(dir, nnormal));
             }
         }
 
@@ -120,8 +126,9 @@ vec3 reflection_ray_second_bounce(vec3 pos, vec3 dir) {
     return texture(sky, sphere_uv(dir)).rgb;
 }
 
-vec3 reflection_ray_first_bounce(vec3 pos, vec3 dir) {
+vec3 inside_reflection_ray_second_bounce(vec3 pos, vec3 dir) {
     float lh = 0.0f;
+    float lh2 = 0.0f;
     float ly = 0.0f;
 
     for (float t = MIN_T; t < MAX_T; t += DT) {
@@ -135,37 +142,44 @@ vec3 reflection_ray_first_bounce(vec3 pos, vec3 dir) {
         float wave_y_at = wave_depth_at(cur_pos);
         float y_at = max(ground_y_at, wave_y_at);
 
-        if (cur_pos.y < y_at) {
+        if (cur_pos.y < ground_y_at) {
             float res_t = t - DT + DT * (lh - ly) / (cur_pos.y - ly - y_at + lh);
             vec3 hit_pos = pos + res_t * dir;
+            float shade = exp(-res_t * ATTENUATION);
 
-            if (ground_y_at > wave_y_at) {
-                return texture(ground, tex_at(hit_pos)).rgb;
-            } else {
-                vec3 nnormal = normalize(wave_normal_at(hit_pos));
-                vec3 reflected = reflect(dir, nnormal);
-                float R = fresnel(nnormal, dir);
-                // vec3 water_color = vec3(0, 0.05, 0.1);
+            return shade * texture(ground, tex_at(hit_pos)).rgb;
+        } else if (cur_pos.y > wave_y_at) {
+            float res_t = t - DT + DT * (lh2 - ly) / (cur_pos.y - ly - y_at + lh2);
+            vec3 hit_pos = pos + res_t * dir;
 
-                vec3 sky_color = reflection_ray_second_bounce(cur_pos, reflected);
-                //vec3 sky_color = texture(sky, sphere_uv(reflected)).rgb;
-                vec3 refracted = refract(dir, nnormal, 0.75);
-                vec3 water_color = refraction_ray(cur_pos, refracted);
-                return mix(water_color, sky_color, R);
+            vec3 nnormal = normalize(wave_normal_at(hit_pos));
+            vec3 reflected = reflect(dir, nnormal);
+            vec3 refracted = refract(dir, -nnormal, 1.333);
+            float R = fresnel(nnormal, -refracted);
+            if (dot(refracted, refracted) < 0.1) {
+                R = 1.0;
             }
+            // vec3 water_color = vec3(0, 0.05, 0.1);
+
+            //vec3 sky_color = reflection_ray_second_bounce(cur_pos, refracted);
+            vec3 sky_color = texture(sky, sphere_uv(reflected)).rgb;
+            vec3 water_color = refraction_ray(cur_pos, reflected);
+            float shade = exp(-res_t * ATTENUATION);
+
+            return shade * mix(sky_color, water_color, R);
         }
 
 
-        lh = y_at;
+        lh = ground_y_at;
+        lh2 = wave_y_at;
         ly = cur_pos.y;
     }
 
-    return texture(sky, sphere_uv(dir)).rgb;
+    return vec3(0);
 }
 
-vec3 inside_reflection_ray_second_bounce(vec3 pos, vec3 dir) {
+vec3 reflection_ray_first_bounce(vec3 pos, vec3 dir) {
     float lh = 0.0f;
-    float lh2 = 0.0f;
     float ly = 0.0f;
 
     for (float t = MIN_T; t < MAX_T; t += DT) {
@@ -179,39 +193,35 @@ vec3 inside_reflection_ray_second_bounce(vec3 pos, vec3 dir) {
         float wave_y_at = wave_depth_at(cur_pos);
         float y_at = max(ground_y_at, wave_y_at);
 
-        if (cur_pos.y < ground_y_at) {
+        if (cur_pos.y < y_at) {
             float res_t = t - DT + DT * (lh - ly) / (cur_pos.y - ly - y_at + lh);
             vec3 hit_pos = pos + res_t * dir;
 
-            return texture(ground, tex_at(hit_pos)).rgb;
-        } else if (cur_pos.y > wave_y_at) {
-            float res_t = t - DT + DT * (lh2 - ly) / (cur_pos.y - ly - y_at + lh2);
-            vec3 hit_pos = pos + res_t * dir;
+            if (ground_y_at > wave_y_at) {
+                return texture(ground, tex_at(hit_pos)).rgb;
+            } else {
+                vec3 nnormal = normalize(wave_normal_at(hit_pos));
+                vec3 reflected = reflect(dir, nnormal);
+                float R = fresnel(nnormal, dir);
+                // vec3 water_color = vec3(0, 0.05, 0.1);
 
-            vec3 nnormal = normalize(wave_normal_at(hit_pos));
-            vec3 reflected = reflect(dir, nnormal);
-            vec3 refracted = refract(dir, -nnormal, 1.333);
-            float R = fresnel(nnormal, -refracted);
-            if (dot(refracted, refracted) < 0.1) {
-                R = 1.0;
+                vec3 sky_color = reflection_ray_second_bounce(cur_pos, reflected);
+                //vec3 sky_color = texture(sky, sphere_uv(reflected)).rgb;
+                vec3 refracted = refract(dir, nnormal, 0.75);
+                vec3 water_color = inside_reflection_ray_second_bounce(cur_pos, refracted);
+                return mix(water_color, sky_color, R);
             }
-            // vec3 water_color = vec3(0, 0.05, 0.1);
-
-            //vec3 sky_color = reflection_ray_second_bounce(cur_pos, refracted);
-            vec3 sky_color = texture(sky, sphere_uv(reflected)).rgb;
-            vec3 water_color = refraction_ray(cur_pos, reflected);
-            return mix(sky_color, water_color, R);
         }
 
 
-        lh = ground_y_at;
-        lh2 = wave_y_at;
+        lh = y_at;
         ly = cur_pos.y;
     }
 
     return texture(sky, sphere_uv(dir)).rgb;
 }
 
+
 vec3 inside_reflection_ray_first_bounce(vec3 pos, vec3 dir) {
     float lh = 0.0f;
     float lh2 = 0.0f;
@@ -231,8 +241,9 @@ vec3 inside_reflection_ray_first_bounce(vec3 pos, vec3 dir) {
         if (cur_pos.y < ground_y_at) {
             float res_t = t - DT + DT * (lh - ly) / (cur_pos.y - ly - y_at + lh);
             vec3 hit_pos = pos + res_t * dir;
+            float shade = exp(-res_t * ATTENUATION);
 
-            return texture(ground, tex_at(hit_pos)).rgb;
+            return shade * texture(ground, tex_at(hit_pos)).rgb;
         } else if (cur_pos.y > wave_y_at) {
             float res_t = t - DT + DT * (lh2 - ly) / (cur_pos.y - ly - y_at + lh2);
             vec3 hit_pos = pos + res_t * dir;
@@ -249,7 +260,9 @@ vec3 inside_reflection_ray_first_bounce(vec3 pos, vec3 dir) {
             vec3 sky_color = reflection_ray_second_bounce(cur_pos, refracted);
             //vec3 sky_color = texture(sky, sphere_uv(reflected)).rgb;
             vec3 water_color = inside_reflection_ray_second_bounce(cur_pos, reflected);
-            return mix(sky_color, water_color, R);
+
+            float shade = exp(-res_t * ATTENUATION);
+            return shade * mix(sky_color, water_color, R);
         }
 
 
@@ -258,17 +271,9 @@ vec3 inside_reflection_ray_first_bounce(vec3 pos, vec3 dir) {
         ly = cur_pos.y;
     }
 
-    return texture(sky, sphere_uv(dir)).rgb;
+    return vec3(0);
 }
 
-vec3 map(float v)
-{
-    if (v > 1.0 || v < 0.0) {
-        return vec3(1, 0, 1);
-    } else {
-        return vec3(v);
-    }
-}
 
 void main(void)
 {
@@ -281,7 +286,7 @@ void main(void)
         float R = fresnel(nnormal, view);
 
         vec3 sky_color = reflection_ray_first_bounce(world_pos + vec3(0, 0.001, 0), reflected);
-        vec3 water_color = refraction_ray(world_pos, refracted);
+        vec3 water_color = inside_reflection_ray_first_bounce(world_pos, refracted);
 
         color = mix(water_color, sky_color, R);
     } else {
@@ -292,10 +297,11 @@ void main(void)
             R = 1.0;
         }
 
-        vec3 sky_color = reflection_ray_second_bounce(world_pos + vec3(0, 0.001, 0), refracted);
+        float shade = exp(-length(world_pos - camera_pos) * ATTENUATION);
+        vec3 sky_color = reflection_ray_first_bounce(world_pos + vec3(0, 0.001, 0), refracted);
         vec3 water_color = inside_reflection_ray_first_bounce(world_pos - vec3(0, 0.001, 0), reflected);
 
-        color = mix(sky_color, water_color, R);
+        color = shade * mix(sky_color, water_color, R);
     }
     out_Color = vec4(color, 1.0);
 }
diff --git a/src/main.cpp b/src/main.cpp
index e3487458e52c0d5310ef73962f9c12dc82198e31..60e672a5167fe167d825d2033238d6cbf4ee34bf 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -139,8 +139,8 @@ struct Scene {
   }
 
   void draw_surface_fbo() {
-    glUseProgram(env_map_program);
     useFBO(surf_fbo, nullptr, nullptr);
+    glUseProgram(env_map_program);
 
     glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);
 
@@ -159,10 +159,11 @@ struct Scene {
   }
 
   void draw_ground_fbo() {
-    useFBO(ground_fbo, nullptr, nullptr);
     ground.use();
     GLuint program = ground.program;
 
+    useFBO(ground_fbo, nullptr, nullptr);
+
     glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);
 
     glUniformMatrix4fv(glGetUniformLocation(program, "projectionMatrix"), 1,
@@ -311,6 +312,7 @@ struct Scene {
     waterfall.draw();
     waterfall.move_waterfall_balls();
   }
+
   void draw_ground() {
     ground.use();
     GLuint program = ground.program;