diff --git a/shaders/waterfall.frag b/shaders/waterfall.frag
index 184f365b1ed4ea377e5046c8578d99f4399499e0..03b855d20ab53e3e354eaa13b759196c5a46cbc8 100644
--- a/shaders/waterfall.frag
+++ b/shaders/waterfall.frag
@@ -2,13 +2,16 @@
 
 const float step_size = 0.1;
 const int NUMBER_OF_STEPS = 48;
-const float MINIMUM_HIT_DISTANCE = 0.01;
+const int NUMBER_OF_BALL_STEPS = 16;
+const float START_STEP = 0.1/30.0; // Half the size of the smallest ball radius
+const float MINIMUM_HIT_DISTANCE = 0.002;
 const float MAXIMUM_TRACE_DISTANCE = 100.0;
 
 // Constants related to reflection/refraction
 const float R0 = 0.04;
 const float WATER_REFRACTION = 0.75;
 const float PI = 3.1415926535897f;
+const int REFRACTION_ITER = 3;
 
 in vec3 world_pos;
 
@@ -41,7 +44,7 @@ float SDF(vec3 position) {
   float min_dist = distance_from_sphere(position, balls[0]);
 
   // TODO: iterate over uniform num_balls
-  for (int i = 1; i < 100; ++i) {
+  for (int i = 1; i < 30; ++i) {
     float dist = distance_from_sphere(position, balls[i]);
 
     min_dist = opSmoothUnion(min_dist, dist, 0.2);
@@ -62,17 +65,50 @@ vec3 approximate_normal(vec3 position) {
 }
 
 float fresnel(vec3 normal, vec3 view) {
-    float cos_theta = dot(normal, -view);
-    float x = 1.0 - cos_theta;
-    return mix(R0, 1.0, x * x * x * x * x);
+  float cos_theta = dot(normal, -view);
+  float x = 1.0 - cos_theta;
+  return mix(R0, 1.0, x * x * x * x * x);
 }
 
 vec2 sphere_uv(vec3 direction)
 {
-    float u = 0.5 + atan(direction.z, direction.x) / (2 * PI);
-    float v = 0.5 + asin(direction.y) / PI;
+  float u = 0.5 + atan(direction.z, direction.x) / (2 * PI);
+  float v = 0.5 + asin(direction.y) / PI;
 
-    return vec2(u, v);
+  return vec2(u, v);
+}
+
+vec3 reflection(vec3 pos, vec3 dir) {
+  // Calculate reflection from skybox
+  vec3 normal = normalize(approximate_normal(pos));
+  vec3 reflected = reflect(dir, normal);
+  vec3 sky_color = texture(sky, sphere_uv(reflected)).rgb;
+  return sky_color;
+}
+
+// TODO: implement this
+vec3 refraction(vec3 world_pos, vec3 dir) {
+
+  float total_distance_traveled = START_STEP;
+
+  for (int i = 0; i < NUMBER_OF_BALL_STEPS; ++i) {
+    vec3 current_position = world_pos + total_distance_traveled * dir;
+
+    float min_dist = -SDF(current_position);
+
+    // hit
+    if (min_dist < MINIMUM_HIT_DISTANCE) {
+      return reflection(current_position, dir);
+    }
+
+    // miss
+    if (total_distance_traveled > MAXIMUM_TRACE_DISTANCE) {
+      discard;
+    }
+
+    total_distance_traveled += min_dist;
+  }
+  discard;
 }
 
 vec3 ray_march(vec3 world_pos, vec3 dir) {
@@ -87,13 +123,21 @@ vec3 ray_march(vec3 world_pos, vec3 dir) {
     if (min_dist < MINIMUM_HIT_DISTANCE) {
       //const vec3 color = vec3(0.1, 0.1, 0.6);
 
-      vec3 normal = normalize(approximate_normal(current_position));
-      vec3 reflected = reflect(dir, normal);
-      vec3 sky_color = texture(sky, sphere_uv(reflected)).rgb;
-      //vec3 direction_to_light = normalize(sun - current_position);
-      //float diffuse_intensity = max(0.0, dot(normal, direction_to_light));
+      vec3 reflec = reflection(current_position, dir);
+      vec3 refrac = refraction(current_position, dir);
+
+      vec3 normal = approximate_normal(current_position);
+
+      // What percentage of out_Color should be from reflection/refraction
+      float change_me = fresnel(normal, dir);
+
+      // TODO: Add refraction from ball/skybox
+      vec3 refracted = refraction(current_position, dir);
+
+      vec3 out_Color = mix(reflec, refrac, change_me);
       
-      return sky_color;
+      return out_Color;
+      //return reflec;
     }
 
     // miss
diff --git a/src/waterfall.cpp b/src/waterfall.cpp
index 4342d8d5fedb36525e39814f7bf1953ab96d2764..3f2e5c3fa114ad1de67e037c2fa47f133eaf57cf 100644
--- a/src/waterfall.cpp
+++ b/src/waterfall.cpp
@@ -85,8 +85,8 @@ float lerp(float a, float b, float t) { return a + t * (b - a); }
 void Waterfall::gen_balls() {
   for (int i{0}; i < NUM_BALLS; ++i) {
 
-    // pseudo random value between 0.01 and 0.1
-    float radius = ((rand() / (float)RAND_MAX) * 0.9 + 0.1) / 50.0;
+    // pseudo random value between 0.1/30.0 and 1.0/30.0
+    float radius = ((rand() / (float)RAND_MAX) * 0.9 + 0.1) / 10.0;
 
     // pseudo random value between min-radius and max-radius
     float pos_x = lerp(x.first, x.second, (rand() / (float)RAND_MAX));
@@ -95,7 +95,7 @@ void Waterfall::gen_balls() {
     vec4 pos = vec4(pos_x, pos_y, pos_z, 0);
 
     // pseudo random value between -0.001 and -0.01
-    float velocity_y = -((rand() / (float)RAND_MAX) * 0.9 + 0.1) / 50.0;
+    float velocity_y = -((rand() / (float)RAND_MAX) * 0.9 + 0.1) / 100.0;
     vec4 velocity = vec4(0, velocity_y, 0, 0);
 
     balls[i] = Waterfall::Ball(pos, radius);