diff --git a/shaders/waterfall.frag b/shaders/waterfall.frag
index 662f4ee8751e9ba9552f50b3fce5def33c71e787..184f365b1ed4ea377e5046c8578d99f4399499e0 100644
--- a/shaders/waterfall.frag
+++ b/shaders/waterfall.frag
@@ -5,11 +5,16 @@ const int NUMBER_OF_STEPS = 48;
 const float MINIMUM_HIT_DISTANCE = 0.01;
 const float MAXIMUM_TRACE_DISTANCE = 100.0;
 
+// Constants related to reflection/refraction
+const float R0 = 0.04;
+const float WATER_REFRACTION = 0.75;
+const float PI = 3.1415926535897f;
+
 in vec3 world_pos;
 
+uniform sampler2D sky;
 uniform int num_balls;
 uniform vec3 camera_pos;
-uniform vec3 sun;
 
 out vec4 out_Color;
 
@@ -45,7 +50,7 @@ float SDF(vec3 position) {
   return min_dist;
 }
 
-vec3 calculate_normal(vec3 position) {
+vec3 approximate_normal(vec3 position) {
   const vec3 small_step = vec3(0.0001, 0, 0);
   return normalize(
     vec3(
@@ -56,31 +61,50 @@ vec3 calculate_normal(vec3 position) {
   );
 }
 
-vec3 ray_march(vec3 ro, vec3 rd) {
+float fresnel(vec3 normal, vec3 view) {
+    float cos_theta = dot(normal, -view);
+    float x = 1.0 - cos_theta;
+    return mix(R0, 1.0, x * x * x * x * x);
+}
+
+vec2 sphere_uv(vec3 direction)
+{
+    float u = 0.5 + atan(direction.z, direction.x) / (2 * PI);
+    float v = 0.5 + asin(direction.y) / PI;
+
+    return vec2(u, v);
+}
+
+vec3 ray_march(vec3 world_pos, vec3 dir) {
   float total_distance_traveled = 0.0;
 
   for (int i = 0; i < NUMBER_OF_STEPS; ++i) {
-    vec3 current_position = ro + total_distance_traveled * rd;
+    vec3 current_position = world_pos + total_distance_traveled * dir;
 
     float min_dist = SDF(current_position);
 
     // hit
     if (min_dist < MINIMUM_HIT_DISTANCE) {
-      const vec3 color = vec3(0.1, 0.1, 0.6);
+      //const vec3 color = vec3(0.1, 0.1, 0.6);
 
-      vec3 normal = normalize(calculate_normal(current_position));
-      vec3 direction_to_light = normalize(sun - current_position);
-      float diffuse_intensity = max(0.0, dot(normal, direction_to_light));
+      vec3 normal = normalize(approximate_normal(current_position));
+      vec3 reflected = reflect(dir, normal);
+      vec3 sky_color = texture(sky, sphere_uv(reflected)).rgb;
+      //vec3 direction_to_light = normalize(sun - current_position);
+      //float diffuse_intensity = max(0.0, dot(normal, direction_to_light));
       
-      return color + diffuse_intensity * vec3(1, 1, 1) * 0.7;
+      return sky_color;
     }
 
+    // miss
     if (total_distance_traveled > MAXIMUM_TRACE_DISTANCE) {
       discard;
     }
 
     total_distance_traveled += min_dist;
   }
+
+  // Iterated more than NUMBER_OF_STEPS, miss
   discard;
 }
 
diff --git a/src/main.cpp b/src/main.cpp
index ea289b54f8833bea2228010701348635e42b9e4d..e75b340b02539cf092048730d96cebcbf5e18423 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -2,13 +2,13 @@
 // Experimental C++ version 2022. Almost no code changes.
 
 #include "object.h"
-#include "waterfall.h"
 #include "surface.h"
+#include "waterfall.h"
 
 #include <GL/gl.h>
+#include <GL/glext.h>
 #include <cstdlib>
 #include <sys/types.h>
-#include <GL/glext.h>
 #define MAIN
 #include "GL_utilities.h"
 #include "LittleOBJLoader.h"
@@ -18,8 +18,6 @@
 // uses framework OpenGL
 // uses framework Cocoa
 
-
-
 struct Scene {
   Object ground;
   Surface surface;
@@ -101,33 +99,33 @@ struct Scene {
     }
   }
 
-  void draw_ground_fbo()
-  {
-      useFBO(ground_fbo, nullptr, nullptr);
-      ground.use();
-      GLuint program = ground.program;
+  void draw_ground_fbo() {
+    useFBO(ground_fbo, nullptr, nullptr);
+    ground.use();
+    GLuint program = ground.program;
 
-      glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);
+    glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);
 
-      glUniformMatrix4fv(glGetUniformLocation(program, "projectionMatrix"), 1, GL_TRUE, orth_matrix.m);
-      glUniformMatrix4fv(glGetUniformLocation(program, "modelToWorldToView"), 1, GL_TRUE, top_view_matrix.m);
+    glUniformMatrix4fv(glGetUniformLocation(program, "projectionMatrix"), 1,
+                       GL_TRUE, orth_matrix.m);
+    glUniformMatrix4fv(glGetUniformLocation(program, "modelToWorldToView"), 1,
+                       GL_TRUE, top_view_matrix.m);
 
-      glActiveTexture(GL_TEXTURE0);
-      glBindTexture(GL_TEXTURE_2D, grass_tex);
+    glActiveTexture(GL_TEXTURE0);
+    glBindTexture(GL_TEXTURE_2D, grass_tex);
 
-      glActiveTexture(GL_TEXTURE1);
-      glBindTexture(GL_TEXTURE_2D, dirt_tex);
+    glActiveTexture(GL_TEXTURE1);
+    glBindTexture(GL_TEXTURE_2D, dirt_tex);
 
-      glUniform1i(glGetUniformLocation(program, "grass"), 0);
-      glUniform1i(glGetUniformLocation(program, "dirt"), 1);
+    glUniform1i(glGetUniformLocation(program, "grass"), 0);
+    glUniform1i(glGetUniformLocation(program, "dirt"), 1);
 
-      ground.draw();
+    ground.draw();
 
-      useFBO(nullptr, nullptr, nullptr);
+    useFBO(nullptr, nullptr, nullptr);
   }
 
-  void draw_surface()
-  {
+  void draw_surface() {
     surface.use();
     GLuint program = surface.program;
     int elapsed_millis = glutGet(GLUT_ELAPSED_TIME);
@@ -142,12 +140,15 @@ struct Scene {
     glActiveTexture(GL_TEXTURE2);
     glBindTexture(GL_TEXTURE_2D, ground_fbo->depth);
 
-    glUniformMatrix4fv(glGetUniformLocation(program, "projectionMatrix"), 1, GL_TRUE, proj_matrix.m);
-    glUniformMatrix4fv(glGetUniformLocation(program, "modelToWorldToView"), 1, GL_TRUE, view_matrix.m);
+    glUniformMatrix4fv(glGetUniformLocation(program, "projectionMatrix"), 1,
+                       GL_TRUE, proj_matrix.m);
+    glUniformMatrix4fv(glGetUniformLocation(program, "modelToWorldToView"), 1,
+                       GL_TRUE, view_matrix.m);
     glUniform1i(glGetUniformLocation(program, "sky"), 0);
     glUniform1i(glGetUniformLocation(program, "ground"), 1);
     glUniform1i(glGetUniformLocation(program, "ground_depth"), 2);
-    glUniform3f(glGetUniformLocation(program, "camera_pos"), pos.x, pos.y, pos.z);
+    glUniform3f(glGetUniformLocation(program, "camera_pos"), pos.x, pos.y,
+                pos.z);
     glUniform1f(glGetUniformLocation(program, "time"), time);
 
     surface.draw();
@@ -190,17 +191,21 @@ struct Scene {
                        GL_TRUE, view_matrix.m);
     glUniform3f(glGetUniformLocation(program, "camera_pos"), pos.x, pos.y,
                 pos.z);
-    glUniform3f(glGetUniformLocation(program, "sun"), sun.x, sun.y, sun.z);
+
+    // Skybox
+    glActiveTexture(GL_TEXTURE0);
+    glBindTexture(GL_TEXTURE_2D, skybox_tex);
 
     waterfall.draw();
     waterfall.move_waterfall_balls();
   }
-  void draw_ground()
-  {
+  void draw_ground() {
     ground.use();
     GLuint program = ground.program;
-    glUniformMatrix4fv(glGetUniformLocation(program, "projectionMatrix"), 1, GL_TRUE, proj_matrix.m);
-    glUniformMatrix4fv(glGetUniformLocation(program, "modelToWorldToView"), 1, GL_TRUE, view_matrix.m);
+    glUniformMatrix4fv(glGetUniformLocation(program, "projectionMatrix"), 1,
+                       GL_TRUE, proj_matrix.m);
+    glUniformMatrix4fv(glGetUniformLocation(program, "modelToWorldToView"), 1,
+                       GL_TRUE, view_matrix.m);
     glActiveTexture(GL_TEXTURE0);
     glBindTexture(GL_TEXTURE_2D, grass_tex);
 
@@ -213,9 +218,7 @@ struct Scene {
     ground.draw();
   }
 
-
-  void draw()
-  {
+  void draw() {
     do_keyboard_input();
     update_view_matrix();
 
@@ -263,13 +266,12 @@ void display(void) {
   glutSwapBuffers();
 }
 
-void on_mouse_move(int x, int y)
-{
+void on_mouse_move(int x, int y) {
   const float sensitivity = 0.0025f;
   if (mouse_y == -1 && mouse_x == -1) {
-      mouse_x = x;
-      mouse_y = y;
-      return;
+    mouse_x = x;
+    mouse_y = y;
+    return;
   }
 
   int dx = x - mouse_x;