diff --git a/shaders/surface.frag b/shaders/surface.frag
index 7d92382799e0b4e4370ba9fc4c86ebcb5e8c3c01..dfb45f62aa7fb0e932b3220a4ba098064dcdc607 100644
--- a/shaders/surface.frag
+++ b/shaders/surface.frag
@@ -1,11 +1,40 @@
 #version 150
 
-in float shade;
+in vec3 world_pos;
+in vec3 normal;
 
 out vec4 out_Color;
 
+uniform sampler2D sky;
+uniform vec3 camera_pos;
+
+const float PI = 3.1415926535897f;
+const float R0 = 0.04;
+
+vec2 sphere_uv(vec3 direction)
+{
+    float u = 0.5 + atan(direction.z, direction.x) / (2 * PI);
+    float v = 0.5 + asin(direction.y) / PI;
+
+    return vec2(u, v);
+}
+
+float fresnel(vec3 normal, vec3 view) {
+    float cos_theta = dot(normal, -view);
+    float x = 1.0 - cos_theta;
+    return mix(R0, 1.0, x * x * x * x * x);
+}
+
 void main(void)
 {
-	out_Color=vec4(shade,shade,shade,1.0);
+    vec3 view = normalize(world_pos - camera_pos);
+    vec3 nnormal = normalize(normal);
+    vec3 reflected = reflect(view, nnormal);
+    float R = fresnel(nnormal, view);
+    vec3 water_color = vec3(0, 0.3, 0.5);
+
+    vec3 sky_color = texture(sky, sphere_uv(reflected)).rgb;
+
+    out_Color = vec4(mix(water_color, sky_color, R), 1.0);
 }
 
diff --git a/shaders/surface.vert b/shaders/surface.vert
index a3adbe90a7a42e871e52704cbed2f275d47f6f1e..571f4bdbe7a1b7f3b62eb8d4d6217ea2fe59e967 100644
--- a/shaders/surface.vert
+++ b/shaders/surface.vert
@@ -7,11 +7,13 @@ in  vec2  in_TexCoord;
 uniform mat4 projectionMatrix;
 uniform mat4 modelToWorldToView;
 
-out float shade;
+out vec3 world_pos;
+out vec3 normal;
 
 void main(void)
 {
-	shade = (mat3(modelToWorldToView)*in_Normal).z; // Fake shading
-	gl_Position=projectionMatrix*modelToWorldToView*vec4(in_Position, 1.0);
+    world_pos = in_Position;
+    normal = in_Normal;
+	gl_Position = projectionMatrix * modelToWorldToView * vec4(in_Position, 1.0);
 }
 
diff --git a/src/main.cpp b/src/main.cpp
index 813928e1878e31e472b8ae987531bc5ecd650684..4da65e0ce0bb59ac05849274ece29c82d11a8620 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -56,9 +56,6 @@ struct Scene {
     LoadTGATextureSimple("textures/sky4k.tga", &skybox_tex);
     glBindTexture(GL_TEXTURE_2D, skybox_tex);
     printError("bind texture");
-    // glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S,
-    // GL_MIRRORED_REPEAT); glTexParameteri(GL_TEXTURE_2D,
-    // GL_TEXTURE_WRAP_T, GL_MIRRORED_REPEAT);
     glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR);
     printError("tex parameteri");
   }
@@ -89,6 +86,29 @@ struct Scene {
     }
   }
 
+  void draw_surface()
+  {
+      surface.use();
+      GLuint program = surface.program;
+      glUniformMatrix4fv(glGetUniformLocation(program, "projectionMatrix"), 1, GL_TRUE, proj_matrix.m);
+      glUniformMatrix4fv(glGetUniformLocation(program, "modelToWorldToView"), 1, GL_TRUE, view_matrix.m);
+      glUniform1i(glGetUniformLocation(program, "sky"), 0);
+      glUniform3f(glGetUniformLocation(program, "camera_pos"), pos.x, pos.y, pos.z);
+
+      surface.draw();
+  }
+
+
+  void draw_ground()
+  {
+    ground.use();
+    GLuint program = ground.program;
+    glUniformMatrix4fv(glGetUniformLocation(program, "projectionMatrix"), 1, GL_TRUE, proj_matrix.m);
+    glUniformMatrix4fv(glGetUniformLocation(program, "modelToWorldToView"), 1, GL_TRUE, view_matrix.m);
+
+    ground.draw();
+  }
+
   void update_view_matrix() {
     vec3 up{0.0, 1.0, 0.0};
     mat4 rot = Rx(pitch) * Ry(yaw);
@@ -115,27 +135,6 @@ struct Scene {
     glEnable(GL_CULL_FACE);
   }
 
-  void draw_surface() {
-    surface.use();
-    GLuint program = surface.program;
-    glUniformMatrix4fv(glGetUniformLocation(program, "projectionMatrix"), 1,
-                       GL_TRUE, proj_matrix.m);
-    glUniformMatrix4fv(glGetUniformLocation(program, "modelToWorldToView"), 1,
-                       GL_TRUE, view_matrix.m);
-
-    surface.draw();
-  }
-
-  void draw_ground() {
-    ground.use();
-    GLuint program = ground.program;
-    glUniformMatrix4fv(glGetUniformLocation(program, "projectionMatrix"), 1,
-                       GL_TRUE, proj_matrix.m);
-    glUniformMatrix4fv(glGetUniformLocation(program, "modelToWorldToView"), 1,
-                       GL_TRUE, view_matrix.m);
-
-    ground.draw();
-  }
 
   void draw_waterfall() {
     waterfall.use();