#include "DirectionalLight.fxh"
#include "HeightMapHelper.fxh"
HEIGHTMAP_REGISTER(0)
	
// Object -> world -> view -> projection transforms (effect parameters set by the
// host). Positions are transformed as row vectors, i.e. mul(position, Matrix).
float4x4 World;
float4x4 View;
float4x4 Projection;

// Colour texture for terrain above the water transition band; PixelShaderFunction
// uses it in full at world height >= TRANSITION_RANGE. Bound to t1/s1 —
// NOTE(review): slot 0 is presumably reserved by HEIGHTMAP_REGISTER(0); confirm.
texture TerrainTexture : register(t1);
sampler2D TerrainSampler : register(s1) = sampler_state
{
	Texture = <TerrainTexture>;
};

// Colour texture for submerged terrain; PixelShaderFunction uses it in full at
// world height <= 0 and blends it towards TerrainTexture up to TRANSITION_RANGE.
// Sampled with the same texture coordinates as TerrainTexture. Bound to t2/s2.
texture UnderWaterTerrainTexture : register(t2);
sampler2D UnderWaterTerrainSampler : register(s2) = sampler_state
{
	Texture = <UnderWaterTerrainTexture>;
};

// Per-vertex input of the terrain mesh.
struct VertexShaderInput
{
	// Object-space position; its y is replaced by the height-map value in the VS.
    float4 Position : POSITION0;
	// NOTE(review): not read by VertexShaderFunction in this file — the terrain UV
	// is derived from the world position instead.
	float2 TexCoord : TEXCOORD0;
};

// Data passed from VertexShaderFunction to PixelShaderFunction (interpolated
// across each triangle by the rasterizer).
struct VertexShaderOutput
{
	// Clip-space position (world * view * projection).
    float4 Position : POSITION0;
	// World-space position; the pixel shader uses .y to pick/blend textures.
	float3 WorldPosition : TEXCOORD0;
	// Terrain colour-texture coordinate from WorldPositionToTerrainTextureCoord.
	float2 TexCoord : TEXCOORD1;
	// Height-map surface normal; no longer unit length after interpolation.
	float3 Normal : NORMAL0;
};

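// WorldPositionToHeightMapTexCoord(worldPosition) converts a world-space position into a height-map texture coordinate; it is not defined in this file (presumably provided by HeightMapHelper.fxh).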

// Estimates the terrain surface normal at a world-space position from central
// differences of the height map.
//
// The horizontal components are (height at -step) - (height at +step) along the
// texture's u and v axes. The vertical component is twice the ratio of texel
// size to TERRAIN_HEIGHT_SCALE, matching the two-texel span of the differences.
// Returns a unit-length vector.
float3 CalculateNormalAt(float3 worldPosition)
{
	float2 uv = WorldPositionToHeightMapTexCoord(worldPosition);
	float2 stepU = float2(HeightMapTexelSize.x, 0);
	float2 stepV = float2(0, HeightMapTexelSize.y);

	// Explicit LOD 0: this runs in the vertex shader, where mip selection from
	// screen-space derivatives is unavailable.
	float heightMinusU = tex2Dlod(HeightMapSampler, float4(uv - stepU, 0, 0)).r;
	float heightPlusU  = tex2Dlod(HeightMapSampler, float4(uv + stepU, 0, 0)).r;
	float heightMinusV = tex2Dlod(HeightMapSampler, float4(uv - stepV, 0, 0)).r;
	float heightPlusV  = tex2Dlod(HeightMapSampler, float4(uv + stepV, 0, 0)).r;

	float upComponent = 2 * (HeightMapTexelSize.x / TERRAIN_HEIGHT_SCALE);

	return normalize(float3(heightMinusU - heightPlusU, upComponent, heightMinusV - heightPlusV));
}

// Transforms one terrain vertex: moves it to world space, overwrites its height
// with the height-map value, and produces the clip-space position, terrain
// texture coordinate, surface normal and world position for the pixel shader.
VertexShaderOutput VertexShaderFunction(VertexShaderInput input)
{
	VertexShaderOutput output;

	float4 worldPosition = mul(input.Position, World);
	// Height comes from the height map, not from the incoming vertex.
	// NOTE(review): GetWorldHeight is given the object-space position rather than
	// worldPosition — confirm this is intended.
	worldPosition.y = GetWorldHeight(input.Position);

	float4 viewPosition = mul(worldPosition, View);
	output.Position = mul(viewPosition, Projection);

	output.TexCoord = WorldPositionToTerrainTextureCoord(worldPosition);
	// Explicit .xyz instead of implicit float4 -> float3 truncation (warning X3206).
	output.Normal = CalculateNormalAt(worldPosition.xyz);
	output.WorldPosition = worldPosition.xyz;
	return output;
}

// World-space height band over which the underwater texture blends into the
// regular terrain texture: y <= 0 is fully underwater texture, y >=
// TRANSITION_RANGE is fully terrain texture.
#define TRANSITION_RANGE 0.1

// Shades a terrain pixel. It blends the underwater and regular terrain textures
// by world height, then applies simple directional lighting.
//
// The interpolated normal is re-normalized before lighting, because the
// rasterizer's linear interpolation does not keep normals unit length.
// Lighting scales RGB only, so the texture's alpha is not darkened by shading.
float4 PixelShaderFunction(VertexShaderOutput input) : COLOR0
{
	float3 normal = normalize(input.Normal);

	// Extremely basic lighting.
	float intensity = GetDirectionalLightIntensity(normal);
	// Half ambient: remap so -1 -> 0 and 1 -> 1.
	// NOTE(review): confirm the range GetDirectionalLightIntensity returns.
	intensity = (intensity + 1) / 2;

	float4 upper = tex2D(TerrainSampler, input.TexCoord);
	float4 lower = tex2D(UnderWaterTerrainSampler, input.TexCoord);

	// Linearly blend from the underwater texture at y = 0 to the terrain
	// texture at y = TRANSITION_RANGE; clamped outside that band.
	float4 blended = lerp(lower, upper, saturate(input.WorldPosition.y / TRANSITION_RANGE));

	return float4(blended.rgb * intensity, blended.a);
}

// Single-pass terrain technique. Shader Model 3.0 is used for the vertex shader
// because it samples the height map (tex2Dlod in CalculateNormalAt).
technique Technique1
{
    pass Pass1
    {
        VertexShader = compile vs_3_0 VertexShaderFunction();
        PixelShader = compile ps_3_0 PixelShaderFunction();
    }
}
