#include "ShaderCommon.h"
#include <metal_stdlib>
#include <metal_graphics>
#include <metal_geometric>

using namespace metal;

/// A single directional light: one direction plus separate RGB colors for the
/// ambient, diffuse and specular contributions.
struct Light {
    float3 direction;
    float3 ambientColor, diffuseColor, specularColor;
};

/// Surface reflectance parameters consumed by basic_fragment (bound at fragment
/// buffer(1)): per-term RGB colors and the exponent for the specular lobe.
/// NOTE(review): member order/size must match the host-side struct written into
/// that buffer (each MSL float3 occupies 16 bytes) — confirm on the CPU side.
struct Material {
    float3 ambientColor, diffuseColor, specularColor;
    float  specularPower; // exponent applied to saturate(N·H)
};

/// Values written by basic_vertex and read through [[stage_in]] by basic_fragment.
struct VertexOut {
    float4 position [[position]];   // clip-space position (projection * eye-space position)
    float3 eye, normal;             // negated eye-space position, and transformed normal
    float2 texCoord;                // passed through unchanged from the vertex
};

// The one light used for all shading. Initializers follow Light's member order.
// `direction` is dotted with the surface normal for the diffuse term, so it is
// the direction toward the light (presumably in eye space, like the normal).
constant Light g_light = {
    { 0.13, 0.72, 0.68 },   // direction — approximately, not exactly, unit length
    { 0.75, 0.75, 0.75 },   // ambientColor
    { 0.9,  0.9,  0.9  },   // diffuseColor
    { 1.0,  1.0,  1.0  }    // specularColor
};

/// Skinned vertex stage using linear blend skinning over four bone influences.
///
/// Each vertex is deformed by the weighted sum of its four bone transforms,
/// moved into eye space with `uniforms.modelViewMatrix`, and projected with
/// `uniforms.projectionMatrix`.
///
/// @param vertex_array Per-vertex attributes; reads position, normal, texCoord,
///                     boneIndices and boneWeights of element `vid`.
/// @param uniforms     Reads modelViewMatrix, projectionMatrix and normalMatrix.
/// @param matrices     Bone palette, indexed by `in.boneIndices[i]`.
/// @param vid          Index of the vertex being processed.
/// @return Clip-space position, negated eye-space position, transformed normal,
///         and the unchanged texture coordinate.
vertex VertexOut basic_vertex(const device VertexIn* vertex_array [[ buffer(0) ]],
                              const device Uniforms& uniforms     [[ buffer(1) ]],
                              const device float4x4* matrices     [[ buffer(2) ]],
                              unsigned int vid                    [[ vertex_id ]]) {
    VertexIn in = vertex_array[vid];

    // Skin in model space first. The blend is linear, so
    //   sum_i w_i * (MV * B_i * v) == MV * (sum_i w_i * (B_i * v)).
    // Writing `MV * B_i * v` evaluates (MV * B_i) * v, i.e. a full 4x4
    // matrix-matrix product per bone per vertex; applying MV once to the
    // blended result needs only matrix-vector products.
    const float4 v = float4(in.position, 1);
    const float4 n = float4(in.normal, 0); // w must be zero so bone translation is ignored.
    float4 skinnedPosition = float4(0);
    float3 skinnedNormal   = float3(0);
    for (int i = 0; i < 4; i++) {
        const float4x4 bone   = matrices[in.boneIndices[i]];
        const float    weight = in.boneWeights[i];
        skinnedPosition += (bone * v) * weight;
        skinnedNormal   += (bone * n).xyz * weight;
    }
    // NOTE(review): transforming normals by the bone matrix itself (rather than
    // its inverse transpose) is only exact if bones carry no non-uniform scale.
    const float3 normal = normalize(skinnedNormal);

    // Eye-space position; its negation points from the surface toward the
    // eye-space origin and is used for the specular half-vector downstream.
    const float4 position = uniforms.modelViewMatrix * skinnedPosition;

    VertexOut out;
    out.position = uniforms.projectionMatrix * position;
    out.eye      = -position.xyz;
    out.normal   = uniforms.normalMatrix * normal;
    out.texCoord = in.texCoord;

    return out;
}

/// Blinn-Phong fragment stage with a single directional light (g_light),
/// modulated by a texture sample.
///
/// @param in        Interpolated outputs of basic_vertex.
/// @param uniforms  Bound at buffer(0) but not read here; kept so the binding
///                  interface is unchanged.
/// @param material  Surface reflectance (ambient/diffuse/specular colors and
///                  specular exponent).
/// @param tex2D     Texture sampled at `in.texCoord`.
/// @param sampler2D Sampler used for `tex2D`.
/// @return float4(ambient + diffuse + specular, 1) multiplied component-wise by
///         the texture sample, so the output alpha comes from the texture.
fragment float4 basic_fragment(VertexOut         in        [[ stage_in   ]],
                               constant Uniforms &uniforms [[ buffer(0)  ]],
                               constant Material &material [[ buffer(1)  ]],
                               texture2d<float>  tex2D     [[ texture(0) ]],
                               sampler           sampler2D [[ sampler(0) ]]) {
    // g_light.direction is only approximately unit length (|d| ~= 0.9988); the
    // diffuse cosine and the half-vector below both assume a unit vector.
    const float3 lightDirection   = normalize(g_light.direction);

    const float3 ambientTerm      = g_light.ambientColor * material.ambientColor;

    const float3 normal           = normalize(in.normal);
    const float  diffuseIntensity = saturate(dot(normal, lightDirection));
    const float3 diffuseTerm      = g_light.diffuseColor * material.diffuseColor * diffuseIntensity;

    // Only surfaces that face the light receive a specular highlight.
    float3 specularTerm(0);
    if (diffuseIntensity > 0) {
        const float3 eyeDirection   = normalize(in.eye);
        const float3 halfway        = normalize(lightDirection + eyeDirection);
        const float  specularFactor = pow(saturate(dot(normal, halfway)), material.specularPower);
        specularTerm                = g_light.specularColor * material.specularColor * specularFactor;
    }

    const float4 texColor = tex2D.sample(sampler2D, in.texCoord);

    // If the depth value is not output by the fragment function,
    // the depth value generated by the rasterizer is output to the depth attachment.
    return float4(ambientTerm + diffuseTerm + specularTerm, 1) * texColor;
}
