diff --git a/shaders/surface.frag b/shaders/surface.frag
index 18f3dfa2131de80ac0ef82a7bef8ae6bfd9a0e21..bc02f077e2a43b1d504c7c269146399a2d00f986 100644
--- a/shaders/surface.frag
+++ b/shaders/surface.frag
@@ -31,7 +31,7 @@ void main(void)
     vec3 nnormal = normalize(normal);
     vec3 reflected = reflect(view, nnormal);
     float R = fresnel(nnormal, view);
-    vec3 water_color = vec3(0, 0.05, 0.1);
+    vec3 water_color = vec3(0.6, 0.6, 0.9);
 
     vec3 sky_color = texture(sky, sphere_uv(reflected)).rgb;
 
diff --git a/shaders/surface.vert b/shaders/surface.vert
index 4e09f949d6a6875c23f955805f2ba96ec2c916cb..62dc97f891555a2929d0b0b55e43e93d23ef0060 100644
--- a/shaders/surface.vert
+++ b/shaders/surface.vert
@@ -11,9 +11,47 @@ uniform float time;
 out vec3 world_pos;
 out vec3 normal;
 
+struct Wave {
+    float L; // Wavelength (distance between waves in world space)
+    float A; // Amplitude (height from surface to wave crest)
+    float S; // Speed (distance the crest moves per second)
+    vec2 D; // Direction (horizontal vector perpendicular to the wave front)
+};
+
+const int NUM_WAVES = 3;
+Wave waves[NUM_WAVES];
+
+void init_waves() {
+    waves[0] = Wave(0.2, 0.01, 0.1, vec2(0.5, 0.5)); 
+    waves[1] = Wave(0.3, 0.02, 0.15, vec2(1.0, 0.0)); 
+    waves[2] = Wave(0.1, 0.015, 0.08, vec2(-0.7, 0.3)); 
+}
+
+/* Get the offset for a single wave */
+float get_single_offset(Wave wave, float x, float z, float t) {
+    float w = 2.0/wave.L; // Frequency
+    float phase_const = wave.S * w; // Phase constant
+
+    return wave.A * sin(dot(wave.D, vec2(x, z)) * w + t * phase_const);
+}
+
+/* Get the offset for all the combined waves */
+float get_wave_offset(float x, float z, float t) {
+    float sum = 0.0;
+
+    for (int i = 0; i < NUM_WAVES; i++) {
+        sum += get_single_offset(waves[i], x, z, t);
+    }
+
+    return sum;
+}
+
 void main(void)
 {
-    vec3 offset_pos = vec3(in_Position.x, in_Position.y + 0.05 * sin(10.0 * in_Position.x + time), in_Position.z);
+    init_waves();
+    //vec3 offset_pos = vec3(in_Position.x, in_Position.y + 0.05 * sin(10.0 * in_Position.x + time), in_Position.z);
+    float offset = get_wave_offset(in_Position.x, in_Position.z, time);
+    vec3 offset_pos = vec3(in_Position.x, in_Position.y + offset, in_Position.z);
     world_pos = offset_pos;
     normal = normalize(vec3(-0.05 * cos(10.0 * in_Position.x + time), 1.0, 0.0));
 
diff --git a/src/main.cpp b/src/main.cpp
index 83b853fc4ea4a439caee896a1a0112a32bb84b6d..5ce9d9737f0bd8945db41d84f9c055448c7ef5b8 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -94,6 +94,12 @@ struct Scene {
         if (glutKeyIsDown('f')) {
             pos -= speed * up;
         }
+        if (glutKeyIsDown('e')) {
+            pos += speed * up;
+        }
+        if (glutKeyIsDown('q')) {
+            pos -= speed * up;
+        }
     }
 
     void update_view_matrix()