diff --git a/shaders/surface.vert b/shaders/surface.vert
index ebdee6bf6ccdf7cd8ed6f97b5bf3ded918bb9e0b..1296d38ee99cc223e46773830e254513f259d7b3 100644
--- a/shaders/surface.vert
+++ b/shaders/surface.vert
@@ -18,33 +18,46 @@ struct Wave {
     vec2 D; // Direction (horizontal vector perpendicular to the wave front)
 };
 
-const int NUM_WAVES = 4;
+const int NUM_WAVES = 5;
 Wave waves[NUM_WAVES];
 
 void init_waves() {
-    waves[0] = Wave(1, 0.11, 0.157, vec2(1.0, 0.0)); 
-    waves[1] = Wave(0.5, 0.03, 0.312, vec2(1.0, 0.2)); 
-    waves[2] = Wave(0.41, 0.024, 0.451, vec2(1.0, -0.13)); 
-    waves[3] = Wave(0.9, 0.011, 0.314, vec2(0.0, 1.0)); 
+    waves[0] = Wave(0.3, 0.002, 0.157, vec2(1.0, 0.0)); 
+    waves[1] = Wave(0.5, 0.003, 0.312, vec2(1.0, 0.2)); 
+    waves[2] = Wave(0.41, 0.0024, 0.451, vec2(1.0, -0.13)); 
+    waves[3] = Wave(0.9, 0.0011, 0.314, vec2(0.0, 1.0)); 
+    waves[4] = Wave(0.33, 0.0025, 0.310, vec2(-1.0, 0.0)); 
 }
 
 /* Get the offset for a single wave */
-float get_single_offset(Wave wave, float x, float z, float t) {
+float get_gerstner1(Wave wave, float x, float z, float t) {
     float w = 2.0/wave.L; // Frequency
     float phase_const = wave.S * w; // Phase constant
 
-    return wave.A * sin(dot(normalize(wave.D), vec2(x, z)) * w + t * phase_const);
+    return wave.A * sin(dot(normalize(wave.D) * w, vec2(x, z)) + t * phase_const);
+}
+
+float get_gerstner2(Wave wave, float dir, float x, float z, float t) {
+    float w = 2.0/wave.L; // Frequency
+    float phase_const = wave.S * w; // Phase constant
+    float Q = 0.3/(w*wave.A);
+    
+    return Q * wave.A * dir * cos(dot(normalize(wave.D) * w, vec2(x, z)) + t * phase_const);
 }
 
 /* Get the offset for all the combined waves */
-float get_wave_offset(float x, float z, float t) {
-    float sum = 0.0;
+vec3 get_waves(const vec3 pos, const float t) {
+    vec3 offset = vec3(0.0, 0.0, 0.0);
 
     for (int i = 0; i < NUM_WAVES; i++) {
-        sum += get_single_offset(waves[i], x, z, t);
+        offset.y += get_gerstner1(waves[i], pos.x, pos.z, t);
+
+        offset.x += get_gerstner2(waves[i], waves[i].D.x, pos.x, pos.z, t);
+        offset.z += get_gerstner2(waves[i], waves[i].D.y, pos.x, pos.z, t);
     }
+    
+    return offset + vec3(pos.x, 0.0, pos.z);
 
-    return sum;
 }
 
 /* Get the normal for a wave */
@@ -52,11 +65,16 @@ vec3 compute_normal(vec3 pos, float t) {
     float delta = 0.001; 
 
     // Calculate the offsets in x and z directions
-    float offset_x = get_wave_offset(pos.x + delta, pos.z, t) - get_wave_offset(pos.x, pos.z, t);
-    float offset_z = get_wave_offset(pos.x, pos.z + delta, t) - get_wave_offset(pos.x, pos.z, t);
+    vec3 offset1 = vec3(pos.x + delta, pos.y, pos.z);
+    vec3 offset2 = vec3(pos.x - delta, pos.y, pos.z);
+    float offset_x = (get_waves(offset1, t) - get_waves(offset2, t)).x;
+
+    offset1 = vec3(pos.x, pos.y, pos.z + delta);
+    offset2 = vec3(pos.x, pos.y, pos.z - delta);
+    float offset_z = (get_waves(offset1, t) - get_waves(offset2, t)).z;
 
     // Construct the normal using partial derivatives
-    float scale = 200.0; // TODO: Not here
+    float scale = 20.0; // TODO: Not here
     return normalize(vec3(-offset_x * scale, 1.0, -offset_z * scale));
 }
 
@@ -64,10 +82,10 @@ void main(void)
 {
     init_waves();
     //vec3 offset_pos = vec3(in_Position.x, in_Position.y + 0.05 * sin(10.0 * in_Position.x + time), in_Position.z);
-    float offset = get_wave_offset(in_Position.x, in_Position.z, time);
-    vec3 offset_pos = vec3(in_Position.x, in_Position.y + offset, in_Position.z);
+    vec3 offset_pos = get_waves(in_Position, time);
     world_pos = offset_pos;
     normal = compute_normal(in_Position, time);
+    // normal = in_Normal;
 
 	gl_Position = projectionMatrix * modelToWorldToView * vec4(offset_pos, 1.0);
 }