diff --git a/shaders/waterfall.frag b/shaders/waterfall.frag
index 3244e1710286b70dbf76c27ecbc9f4b66dd5f6d7..704f418fe9de29abc6c1d3efae7a1e787e73c66d 100644
--- a/shaders/waterfall.frag
+++ b/shaders/waterfall.frag
@@ -6,8 +6,6 @@ const float MINIMUM_HIT_DISTANCE = 0.0001;
 const float MAXIMUM_TRACE_DISTANCE = 1000.0;
 
 uniform int num_balls;
-uniform float screenWidth;
-uniform float screenHeight;
 
 in vec3 world_pos;
 uniform vec3 camera_pos;
@@ -19,7 +17,7 @@ struct Ball {
 };
 
 layout(std140) uniform BallBuffer {
-    Ball balls[30];  // TODO: Need to manually update size
+  Ball balls[30];  // TODO: Need to manually update size
 };
 
 float distance_from_sphere(vec3 pos, Ball ball) {
@@ -28,25 +26,25 @@ float distance_from_sphere(vec3 pos, Ball ball) {
 
 // Function taken from: https://iquilezles.org/articles/distfunctions/
 float opSmoothUnion(float d1, float d2, float k) {
-    float h = clamp(0.5 + 0.5*(d2-d1)/k, 0.0, 1.0);
-    return mix(d2, d1, h) - k*h*(1.0-h);
+  float h = clamp(0.5 + 0.5*(d2-d1)/k, 0.0, 1.0);
+  return mix(d2, d1, h) - k*h*(1.0-h);
 }
 
 float SDF(vec3 position) {
-    float min_dist = distance_from_sphere(position, balls[0]);
+  float min_dist = distance_from_sphere(position, balls[0]);
 
-    // TODO: iterate over uniform num_balls
-    for (int i = 1; i < 30; ++i) {
-      float dist = distance_from_sphere(position, balls[i]);
+  // TODO: iterate over uniform num_balls
+  for (int i = 1; i < 30; ++i) {
+    float dist = distance_from_sphere(position, balls[i]);
 
-      min_dist = opSmoothUnion(min_dist, dist, 0.2);
-    }
+    min_dist = opSmoothUnion(min_dist, dist, 0.1);
+  }
 
-    return min_dist;
+  return min_dist;
 }
 
 vec3 calculate_normal(vec3 position) {
-  const vec3 small_step = vec3(0.001, 0, 0);
+  const vec3 small_step = vec3(0.0001, 0, 0);
   return normalize(
     vec3(
       SDF(position + small_step.xyy) - SDF(position - small_step.xyy),
@@ -57,20 +55,28 @@ vec3 calculate_normal(vec3 position) {
 }
 
 vec3 ray_march(vec3 ro, vec3 rd) {
-
   float total_distance_traveled = 0.0;
 
   for (int i = 0; i < NUMBER_OF_STEPS; ++i) {
-    
     vec3 current_position = ro + total_distance_traveled * rd;
 
     float min_dist = SDF(current_position);
 
     // hit
     if (min_dist < MINIMUM_HIT_DISTANCE) {
-      const vec3 color = vec3(0.1, 0.2, 0.7);
+      const vec3 color = vec3(0.2, 0.3, 0.8);
+
+      const vec3 light_position = vec3(0, 10, 0);
       vec3 normal = normalize(calculate_normal(current_position));
-      return color;
+      vec3 direction_to_light = normalize(light_position - current_position);
+      float diffuse_intensity = max(0.0, dot(normal, direction_to_light));
+      
+      if (diffuse_intensity < 0.99) {
+        discard;
+      }
+
+      //return color + diffuse_intensity * vec3(1, 1, 1) * 0.5;
+      return vec3(diffuse_intensity, diffuse_intensity, diffuse_intensity);
     }
 
     if (total_distance_traveled > MAXIMUM_TRACE_DISTANCE) {
@@ -86,6 +92,4 @@ void main(void)
   vec3 dir = normalize(world_pos - camera_pos);
   vec3 shaded_color = ray_march(world_pos, dir);
 	out_Color = vec4(shaded_color, 1.0);
-	//out_Color = vec4(gl_FragCoord.x, gl_FragCoord.y, gl_FragCoord.z, 1);
-	//out_Color = vec4(shade, shade, shade, 0.0);
 }