﻿//-------------------------------------------------------------------------------
// File : main.cpp
// Desc : Application Main Entry Point.
// Copyright(c) Project Asura. All rights reserved.
//-------------------------------------------------------------------------------


//-------------------------------------------------------------------------------
// Includes
//-------------------------------------------------------------------------------
#include <cstdio>
#include <crtdbg.h>
#include <cassert>
#include <cmath>
#include <d3d11.h>
#include <d3dcompiler.h>


//-------------------------------------------------------------------------------
// Linker
//-------------------------------------------------------------------------------
#pragma comment( lib, "d3d11.lib" )
#pragma comment( lib, "d3dcompiler.lib" )
#pragma comment( lib, "dxguid.lib" )
#pragma comment( lib, "winmm.lib" )
#pragma comment( lib, "comctl32.lib" )


//-------------------------------------------------------------------------------
// Macros
//-------------------------------------------------------------------------------
#ifndef ASDX_RELEASE
// Releases a COM interface pointer (if non-null) and resets it to nullptr.
// Wrapped in do { } while (0) so that "ASDX_RELEASE(p);" is a single statement and can
// be used safely as the body of an unbraced if / else.
#define ASDX_RELEASE(p) do { if (p) { (p)->Release(); (p) = nullptr; } } while (0)
#endif//ASDX_RELEASE

#ifndef ILOG
// Prints a printf_s-style formatted message followed by a newline.
// The format argument must be a string literal because it is concatenated with "\n".
#define ILOG(x, ...)    printf_s( x"\n", ##__VA_ARGS__ )
#endif//ILOG


//-------------------------------------------------------------------------------
// Constant Values
//-------------------------------------------------------------------------------
// Number of BufType elements held by each input and output buffer.
constexpr UINT NUM_ELEMENTS = 1024;


//-------------------------------------------------------------------------------
// Forward Declarations
//-------------------------------------------------------------------------------
HRESULT         CreateComputeDevice     ( ID3D11Device**, ID3D11DeviceContext** );
HRESULT         CreateComputeShader     ( const wchar_t*, const char*, ID3D11Device*, ID3D11ComputeShader** );
HRESULT         CreateStructuredBuffer  ( ID3D11Device*, UINT, UINT, void*, ID3D11Buffer** );
HRESULT         CreateRawBuffer         ( ID3D11Device*, UINT, void*, ID3D11Buffer** );
HRESULT         CreateBufferSRV         ( ID3D11Device*, ID3D11Buffer*, ID3D11ShaderResourceView** );
HRESULT         CreateBufferUAV         ( ID3D11Device*, ID3D11Buffer*, ID3D11UnorderedAccessView** );
ID3D11Buffer*   CreateAndCopyToBuffer   ( ID3D11Device*, ID3D11DeviceContext*, ID3D11Buffer* );
void            RunComputeShader( 
                    ID3D11DeviceContext*,
                    ID3D11ComputeShader*,
                    UINT,
                    ID3D11ShaderResourceView**,
                    ID3D11Buffer*,
                    void*,
                    DWORD,
                    ID3D11UnorderedAccessView*,
                    UINT,
                    UINT,
                    UINT );

////////////////////////////////////////////////////////////////////////////////////
// BufType structure
////////////////////////////////////////////////////////////////////////////////////
// Element type of the structured buffers (main passes sizeof(BufType) as the stride).
// NOTE(review): the layout presumably has to match the element struct declared in
// BasicCompute.hlsl exactly — confirm against the shader before changing any field.
struct BufType
{
    int     s32;    // integer field; main checks the GPU result against in0.s32 + in1.s32
    float   f32;    // float field; main checks the GPU result against in0.f32 * in1.f32
};

//----------------------------------------------------------------------------------
// Global Variables.
//----------------------------------------------------------------------------------
// Device, immediate context and the compute shader built from BasicCompute.hlsl.
ID3D11Device*               g_pDevice       { nullptr };
ID3D11DeviceContext*        g_pContext      { nullptr };
ID3D11ComputeShader*        g_pCS           { nullptr };

// Two input structured buffers and the output structured buffer.
ID3D11Buffer*               g_pBuf0         { nullptr };
ID3D11Buffer*               g_pBuf1         { nullptr };
ID3D11Buffer*               g_pBufResult    { nullptr };

// Input views of g_pBuf0 / g_pBuf1 and the output view of g_pBufResult.
ID3D11ShaderResourceView*   g_pBufSRV0      { nullptr };
ID3D11ShaderResourceView*   g_pBufSRV1      { nullptr };
ID3D11UnorderedAccessView*  g_pBufUAV       { nullptr };

// CPU-side input data: uploaded to g_pBuf0 / g_pBuf1 and reused for verification.
BufType g_Buf0[ NUM_ELEMENTS ] {};
BufType g_Buf1[ NUM_ELEMENTS ] {};


/////////////////////////////////////////////////////////////////////////////////////
// Functions
/////////////////////////////////////////////////////////////////////////////////////

//-----------------------------------------------------------------------------------
//      メインエントリーポイントです.
///-----------------------------------------------------------------------------------
int main()
{
#if defined(DEBUG) || defined(_DEBUG)
    // メモリリークチェック.
    _CrtSetDbgFlag( _CRTDBG_ALLOC_MEM_DF | _CRTDBG_LEAK_CHECK_DF );
#endif//defined(DEBUG) || defined(_DEBUG)

    HRESULT hr = S_OK;

    // デバイス作成.
    hr = CreateComputeDevice( &g_pDevice, &g_pContext );
    if ( FAILED( hr ) )
    { return 1; }


    // コンピュートシェーダを作成.
    hr = CreateComputeShader( L"BasicCompute.hlsl", "CSFunc", g_pDevice, &g_pCS );
    assert( SUCCEEDED( hr ) );


    // バッファにデータを格納.
    for( int i=0; i<NUM_ELEMENTS; ++i )
    {
        g_Buf0[ i ].s32 = i;
        g_Buf0[ i ].f32 = static_cast<float>( i ) * 0.25f;

        g_Buf1[ i ].s32 = i;
        g_Buf1[ i ].f32 = static_cast<float>( i ) * 0.75f;
    }


    // 構造化バッファを生成.
    hr = CreateStructuredBuffer( g_pDevice, sizeof(BufType), NUM_ELEMENTS, &g_Buf0[ 0 ], &g_pBuf0 );
    assert( SUCCEEDED( hr ) );
    hr = CreateStructuredBuffer( g_pDevice, sizeof(BufType), NUM_ELEMENTS, &g_Buf1[ 0 ], &g_pBuf1 );
    assert( SUCCEEDED( hr ) );
    hr = CreateStructuredBuffer( g_pDevice, sizeof(BufType), NUM_ELEMENTS, nullptr, &g_pBufResult );
    assert( SUCCEEDED( hr ) );


    // 入出力用のビューを生成.
    hr = CreateBufferSRV( g_pDevice, g_pBuf0, &g_pBufSRV0 );
    assert( SUCCEEDED( hr ) );
    hr = CreateBufferSRV( g_pDevice, g_pBuf1, &g_pBufSRV1 );
    assert( SUCCEEDED( hr ) );
    hr = CreateBufferUAV( g_pDevice, g_pBufResult, &g_pBufUAV );
    assert( SUCCEEDED( hr ) );


    // コンピュートシェーダを走らせる.
    ID3D11ShaderResourceView* pSRVs[ 2 ] = { g_pBufSRV0, g_pBufSRV1 };
    RunComputeShader( g_pContext, g_pCS, 2, pSRVs, nullptr, nullptr, 0, g_pBufUAV, 32, 1, 1 );


    // GPU からの計算結果を読み戻して，CPUで計算した結果と同じであるか検証する.
    {
        // バッファを生成とコピー.
        ID3D11Buffer* pBufDbg = CreateAndCopyToBuffer( g_pDevice, g_pContext, g_pBufResult );

        D3D11_MAPPED_SUBRESOURCE subRes;
        BufType* pBufType;

        // マップ.
        hr = g_pContext->Map( pBufDbg, 0, D3D11_MAP_READ, 0, &subRes );
        assert( SUCCEEDED( hr ) );

        pBufType = (BufType*)subRes.pData;

        ILOG( "Verifying against CPU result..." );
        bool isSuccess = true;

        for( int i=0; i<NUM_ELEMENTS; ++i )
        {
            // CPUで演算.
            int   value0 = g_Buf0[i].s32 + g_Buf1[i].s32;
            float value1 = g_Buf0[i].f32 * g_Buf1[i].f32;

            // GPUの演算結果とCPUの演算結果を比較.
            if ( ( pBufType[i].s32 != value0 )
              || ( pBufType[i].f32 != value1 ) )
            {
                ILOG( "Failure." );
                ILOG( "  index = %d", i );
                ILOG( "  cpu value0 = %d, value1 = %f", value0, value1 );
                ILOG( "  gpu value0 = %d, value1 = %f", pBufType[i].s32, pBufType[i].f32 );
                isSuccess = false;
                break;
            }
        }

        // CPUとGPUの結果がすべて一致したら成功のログを出力.
        if ( isSuccess )
        { ILOG( "Succeded!!" ); }

        // アンマップ.
        g_pContext->Unmap( pBufDbg, 0 );

        // 解放処理.
        ASDX_RELEASE( pBufDbg );
    }


    // 終了処理.
    ASDX_RELEASE( g_pBufSRV0 );
    ASDX_RELEASE( g_pBufSRV1 );
    ASDX_RELEASE( g_pBufUAV );
    ASDX_RELEASE( g_pBuf0 );
    ASDX_RELEASE( g_pBuf1 );
    ASDX_RELEASE( g_pBufResult );
    ASDX_RELEASE( g_pCS );
    ASDX_RELEASE( g_pContext );
    ASDX_RELEASE( g_pDevice );

    return 0;
}


//---------------------------------------------------------------------------------------------
//      Creates the D3D11 hardware device and its immediate context.
//
//      ppDevice  : receives the device (reset to nullptr first).
//      ppContext : receives the immediate context (reset to nullptr first).
//      Returns the HRESULT of D3D11CreateDevice.
//
//      Feature levels 11_0, 10_1 and 10_0 are requested, highest first. This program
//      assumes the machine can run compute shaders, so no support check is done here
//      (compute support is optional on 10.x-level hardware).
//---------------------------------------------------------------------------------------------
HRESULT CreateComputeDevice( ID3D11Device** ppDevice, ID3D11DeviceContext** ppContext )
{
    *ppDevice  = nullptr;
    *ppContext = nullptr;

    // Feature levels to try, in order of preference.
    static const D3D_FEATURE_LEVEL kRequestedLevels[] =
    {
        D3D_FEATURE_LEVEL_11_0,
        D3D_FEATURE_LEVEL_10_1,
        D3D_FEATURE_LEVEL_10_0,
    };
    const UINT levelCount = static_cast<UINT>( sizeof(kRequestedLevels) / sizeof(kRequestedLevels[0]) );

    // The device is not shared across threads in this program; debug builds also
    // enable the D3D debug layer.
    UINT flags = D3D11_CREATE_DEVICE_SINGLETHREADED;
#if defined(DEBUG) || defined(_DEBUG)
    flags |= D3D11_CREATE_DEVICE_DEBUG;
#endif//defined(DEBUG) || defined(_DEBUG)

    D3D_FEATURE_LEVEL obtainedLevel;

    return D3D11CreateDevice(
        nullptr,                    // default adapter
        D3D_DRIVER_TYPE_HARDWARE,
        nullptr,                    // no software rasterizer module
        flags,
        kRequestedLevels,
        levelCount,
        D3D11_SDK_VERSION,
        ppDevice,
        &obtainedLevel,
        ppContext );
}


//---------------------------------------------------------------------------------------------
//      Compiles a shader from an HLSL file.
//
//      szFileName    : path of the HLSL source file.
//      szEntryPoint  : entry-point function name.
//      szShaderModel : target profile (e.g. "cs_5_0").
//      ppBlobOut     : receives the compiled bytecode.
//      Returns the HRESULT of D3DCompileFromFile.
//
//      Compiler messages (errors, or warnings on success) are written both to the
//      debugger output and to stderr so they are visible when run from a console.
//---------------------------------------------------------------------------------------------
HRESULT CompileShaderFromFile
(
    const wchar_t*  szFileName,
    const char*     szEntryPoint,
    const char*     szShaderModel,
    ID3DBlob**      ppBlobOut
)
{
    // Compile flags.
    DWORD dwShaderFlags = D3DCOMPILE_ENABLE_STRICTNESS;

#if defined(DEBUG) || defined(_DEBUG)
    // Embed debug information in debug builds.
    dwShaderFlags |= D3DCOMPILE_DEBUG;
#endif//defined(DEBUG) || defined(_DEBUG)

#if defined(NDEBUG) || defined(_NDEBUG)
    // Highest optimisation level in release builds.
    dwShaderFlags |= D3DCOMPILE_OPTIMIZATION_LEVEL3;
#endif//defined(NDEBUG) || defined(_NDEBUG)

    ID3DBlob* pErrorBlob = nullptr;

    // Compile the file using the standard file include handler.
    HRESULT hr = D3DCompileFromFile(
        szFileName,
        nullptr,                            // no preprocessor macros
        D3D_COMPILE_STANDARD_FILE_INCLUDE,
        szEntryPoint,
        szShaderModel,
        dwShaderFlags,
        0,
        ppBlobOut,
        &pErrorBlob );

    if ( pErrorBlob != nullptr )
    {
        // The message blob is reported even on success because it then holds warnings.
        // Assumes the message text is NUL-terminated (the original code relied on this too).
        const char* pMessage = static_cast<const char*>( pErrorBlob->GetBufferPointer() );
        OutputDebugStringA( pMessage );
        fprintf( stderr, "%s\n", pMessage );
    }
    else if ( FAILED( hr ) )
    {
        // No compiler message (e.g. the file could not be opened): still report the failure.
        fprintf( stderr, "Shader compilation failed (hr = 0x%08lx): %ls\n",
                 static_cast<unsigned long>( hr ), szFileName );
    }

    ASDX_RELEASE( pErrorBlob );

    return hr;
}


//---------------------------------------------------------------------------------------------
//      Compiles and creates a compute shader.
//
//      srcFile      : HLSL source file.
//      functionName : entry-point function name.
//      pDevice      : device that creates the shader.
//      ppShaderOut  : receives the compute shader.
//      Returns S_OK on success, otherwise the HRESULT of the failing step.
//
//      The target profile follows the device: cs_5_0 on feature level 11_0 or higher,
//      cs_4_0 otherwise (10.x-level hardware only supports the CS 4.x profiles).
//      Compile flags are chosen inside CompileShaderFromFile.
//---------------------------------------------------------------------------------------------
HRESULT CreateComputeShader
(
    const wchar_t*          srcFile,
    const char*             functionName,
    ID3D11Device*           pDevice, 
    ID3D11ComputeShader**   ppShaderOut 
)
{
    const char* profile = ( pDevice->GetFeatureLevel() >= D3D_FEATURE_LEVEL_11_0 )
                        ? "cs_5_0"
                        : "cs_4_0";

    ID3DBlob* pBlob = nullptr;

    // Compile.
    HRESULT hr = CompileShaderFromFile( srcFile, functionName, profile, &pBlob );
    if ( FAILED( hr ) )
    {
        ASDX_RELEASE( pBlob );
        return hr;
    }

    // Create the shader object from the bytecode.
    hr = pDevice->CreateComputeShader( pBlob->GetBufferPointer(), pBlob->GetBufferSize(), nullptr, ppShaderOut );

    // The bytecode blob is no longer needed.
    ASDX_RELEASE( pBlob );

    return hr;
}


//---------------------------------------------------------------------------------------------
//      Creates a structured buffer bindable as both SRV and UAV.
//      Note: it cannot be used as a vertex or index buffer.
//
//      pDevice     : device that creates the buffer.
//      elementSize : size of one element in bytes (also the structure stride).
//      count       : number of elements.
//      pInitData   : initial contents (elementSize * count bytes), or nullptr.
//      ppBufferOut : receives the buffer (reset to nullptr first).
//      Returns the HRESULT of CreateBuffer, E_POINTER for a null pDevice / ppBufferOut,
//      or E_INVALIDARG if elementSize * count does not fit in a UINT.
//---------------------------------------------------------------------------------------------
HRESULT CreateStructuredBuffer
(
    ID3D11Device*   pDevice,
    UINT            elementSize,
    UINT            count,
    void*           pInitData,
    ID3D11Buffer**  ppBufferOut
)
{
    if ( ppBufferOut == nullptr )
    { return E_POINTER; }

    (*ppBufferOut) = nullptr;

    if ( pDevice == nullptr )
    { return E_POINTER; }

    // Compute the size in 64 bits so an overflowing product is rejected instead of
    // silently wrapping around to a smaller buffer.
    const UINT64 byteWidth = static_cast<UINT64>( elementSize ) * count;
    if ( byteWidth != static_cast<UINT>( byteWidth ) )
    { return E_INVALIDARG; }

    // Usage / CPUAccessFlags stay 0 (D3D11_USAGE_DEFAULT, no CPU access).
    D3D11_BUFFER_DESC desc = {};
    desc.BindFlags           = D3D11_BIND_UNORDERED_ACCESS | D3D11_BIND_SHADER_RESOURCE;
    desc.ByteWidth           = static_cast<UINT>( byteWidth );
    desc.MiscFlags           = D3D11_RESOURCE_MISC_BUFFER_STRUCTURED;
    desc.StructureByteStride = elementSize;

    // Zero-initialised so the pitch fields (unused for buffers) are not indeterminate.
    D3D11_SUBRESOURCE_DATA initData = {};
    initData.pSysMem = pInitData;

    return pDevice->CreateBuffer( &desc, ( pInitData != nullptr ) ? &initData : nullptr, ppBufferOut );
}


//---------------------------------------------------------------------------------------------
//      Creates a byte-address (raw) buffer.
//      It is also bindable as a vertex or index buffer, as well as SRV / UAV.
//
//      pDevice     : device that creates the buffer.
//      size        : size in bytes. CreateBufferSRV / CreateBufferUAV expose size / 4
//                    32-bit elements, so bytes beyond a multiple of 4 are not reachable
//                    through those views.
//      pInitData   : initial contents (size bytes), or nullptr.
//      ppBufferOut : receives the buffer (reset to nullptr first).
//      Returns the HRESULT of CreateBuffer, or E_POINTER for a null pDevice / ppBufferOut.
//---------------------------------------------------------------------------------------------
HRESULT CreateRawBuffer( ID3D11Device* pDevice, UINT size, void* pInitData, ID3D11Buffer** ppBufferOut )
{
    if ( ppBufferOut == nullptr )
    { return E_POINTER; }

    (*ppBufferOut) = nullptr;

    if ( pDevice == nullptr )
    { return E_POINTER; }

    D3D11_BUFFER_DESC desc = {};
    desc.BindFlags = D3D11_BIND_UNORDERED_ACCESS | D3D11_BIND_SHADER_RESOURCE | D3D11_BIND_INDEX_BUFFER | D3D11_BIND_VERTEX_BUFFER;
    desc.ByteWidth = size;
    desc.MiscFlags = D3D11_RESOURCE_MISC_BUFFER_ALLOW_RAW_VIEWS;

    // Zero-initialised so the pitch fields (unused for buffers) are not indeterminate.
    D3D11_SUBRESOURCE_DATA initData = {};
    initData.pSysMem = pInitData;

    return pDevice->CreateBuffer( &desc, ( pInitData != nullptr ) ? &initData : nullptr, ppBufferOut );
}

//---------------------------------------------------------------------------------------------
//      Creates a shader resource view covering a whole buffer.
//
//      Raw buffers get an R32_TYPELESS BUFFEREX view with the RAW flag (one element per
//      4 bytes). Structured buffers get a DXGI_FORMAT_UNKNOWN view (one element per
//      structure). Any other buffer yields E_INVALIDARG and no view is created.
//---------------------------------------------------------------------------------------------
HRESULT CreateBufferSRV( ID3D11Device* pDevice, ID3D11Buffer* pBuffer, ID3D11ShaderResourceView** ppSRVOut )
{
    D3D11_BUFFER_DESC bufferDesc;
    memset( &bufferDesc, 0, sizeof(bufferDesc) );
    pBuffer->GetDesc( &bufferDesc );

    const bool isRaw        = ( bufferDesc.MiscFlags & D3D11_RESOURCE_MISC_BUFFER_ALLOW_RAW_VIEWS ) != 0;
    const bool isStructured = ( bufferDesc.MiscFlags & D3D11_RESOURCE_MISC_BUFFER_STRUCTURED ) != 0;

    // Typed buffers are not handled.
    if ( !isRaw && !isStructured )
    { return E_INVALIDARG; }

    D3D11_SHADER_RESOURCE_VIEW_DESC viewDesc;
    memset( &viewDesc, 0, sizeof(viewDesc) );
    viewDesc.ViewDimension         = D3D11_SRV_DIMENSION_BUFFEREX;
    viewDesc.BufferEx.FirstElement = 0;

    if ( isRaw )
    {
        // Byte-address view made of 32-bit typeless elements.
        viewDesc.Format               = DXGI_FORMAT_R32_TYPELESS;
        viewDesc.BufferEx.Flags       = D3D11_BUFFEREX_SRV_FLAG_RAW;
        viewDesc.BufferEx.NumElements = bufferDesc.ByteWidth / 4;
    }
    else
    {
        // Structured view: the format comes from the structure stride.
        viewDesc.Format               = DXGI_FORMAT_UNKNOWN;
        viewDesc.BufferEx.NumElements = bufferDesc.ByteWidth / bufferDesc.StructureByteStride;
    }

    return pDevice->CreateShaderResourceView( pBuffer, &viewDesc, ppSRVOut );
}

//---------------------------------------------------------------------------------------------
//      Creates an unordered access view covering a whole buffer.
//
//      Raw buffers get an R32_TYPELESS view with the RAW flag (one element per 4 bytes).
//      Structured buffers get a DXGI_FORMAT_UNKNOWN view (one element per structure).
//      Any other buffer yields E_INVALIDARG and no view is created.
//---------------------------------------------------------------------------------------------
HRESULT CreateBufferUAV( ID3D11Device* pDevice, ID3D11Buffer* pBuffer, ID3D11UnorderedAccessView** ppUAVOut )
{
    D3D11_BUFFER_DESC srcDesc;
    memset( &srcDesc, 0, sizeof(srcDesc) );
    pBuffer->GetDesc( &srcDesc );

    D3D11_UNORDERED_ACCESS_VIEW_DESC viewDesc;
    memset( &viewDesc, 0, sizeof(viewDesc) );
    viewDesc.ViewDimension       = D3D11_UAV_DIMENSION_BUFFER;
    viewDesc.Buffer.FirstElement = 0;

    if ( ( srcDesc.MiscFlags & D3D11_RESOURCE_MISC_BUFFER_ALLOW_RAW_VIEWS ) != 0 )
    {
        // Byte-address view made of 32-bit typeless elements.
        viewDesc.Format             = DXGI_FORMAT_R32_TYPELESS;
        viewDesc.Buffer.Flags       = D3D11_BUFFER_UAV_FLAG_RAW;
        viewDesc.Buffer.NumElements = srcDesc.ByteWidth / 4;
        return pDevice->CreateUnorderedAccessView( pBuffer, &viewDesc, ppUAVOut );
    }

    if ( ( srcDesc.MiscFlags & D3D11_RESOURCE_MISC_BUFFER_STRUCTURED ) != 0 )
    {
        // Structured view: the format comes from the structure stride.
        viewDesc.Format             = DXGI_FORMAT_UNKNOWN;
        viewDesc.Buffer.NumElements = srcDesc.ByteWidth / srcDesc.StructureByteStride;
        return pDevice->CreateUnorderedAccessView( pBuffer, &viewDesc, ppUAVOut );
    }

    // Typed buffers are not handled.
    return E_INVALIDARG;
}


//---------------------------------------------------------------------------------------------
//      Creates a CPU-readable staging copy of a buffer.
//
//      The new buffer keeps pBuffer's description except for: D3D11_USAGE_STAGING,
//      CPU read access, no bind flags and no misc flags. If it is created, a copy from
//      pBuffer is queued on pContext. Returns the new buffer (the caller releases it);
//      on creation failure the pointer left by CreateBuffer (initially nullptr) is returned.
//---------------------------------------------------------------------------------------------
ID3D11Buffer* CreateAndCopyToBuffer( ID3D11Device* pDevice, ID3D11DeviceContext* pContext, ID3D11Buffer* pBuffer )
{
    D3D11_BUFFER_DESC stagingDesc;
    memset( &stagingDesc, 0, sizeof(stagingDesc) );
    pBuffer->GetDesc( &stagingDesc );

    stagingDesc.Usage          = D3D11_USAGE_STAGING;
    stagingDesc.CPUAccessFlags = D3D11_CPU_ACCESS_READ;
    stagingDesc.BindFlags      = 0;
    stagingDesc.MiscFlags      = 0;

    ID3D11Buffer* pStaging = nullptr;
    const HRESULT hr = pDevice->CreateBuffer( &stagingDesc, nullptr, &pStaging );
    if ( SUCCEEDED( hr ) )
    {
        // Queue the GPU-side copy into the staging buffer.
        pContext->CopyResource( pStaging, pBuffer );
    }

    return pStaging;
}


//----------------------------------------------------------------------------------------------
//      Runs a compute shader once.
//
//      pContext       : context that records the commands.
//      pComputeShader : shader to run.
//      numViews       : number of views in pSRVs, bound to slots t0 .. t(numViews-1).
//      pSRVs          : input shader resource views.
//      pCBCS          : optional constant buffer bound to b0; nullptr to skip.
//      pCSData        : data copied into pCBCS (may be nullptr when numDataBytes is 0).
//      numDataBytes   : number of bytes copied from pCSData.
//      pUAV           : output view bound to u0.
//      x, y, z        : thread-group counts passed to Dispatch.
//
//      Every binding made here is cleared again after the dispatch.
//----------------------------------------------------------------------------------------------
void RunComputeShader
(
    ID3D11DeviceContext*        pContext,
    ID3D11ComputeShader*        pComputeShader,
    UINT                        numViews,
    ID3D11ShaderResourceView**  pSRVs,
    ID3D11Buffer*               pCBCS,
    void*                       pCSData,
    DWORD                       numDataBytes,
    ID3D11UnorderedAccessView*  pUAV,
    UINT                        x,
    UINT                        y,
    UINT                        z
)
{
    pContext->CSSetShader( pComputeShader, nullptr, 0 );
    pContext->CSSetShaderResources( 0, numViews, pSRVs );
    pContext->CSSetUnorderedAccessViews( 0, 1, &pUAV, nullptr );

    if ( pCBCS )
    {
        // NOTE(review): WRITE_DISCARD assumes pCBCS was created with dynamic usage and
        // CPU write access — confirm at the call site that creates it.
        D3D11_MAPPED_SUBRESOURCE res = {};
        if ( SUCCEEDED( pContext->Map( pCBCS, 0, D3D11_MAP_WRITE_DISCARD, 0, &res ) ) )
        {
            // Only touch res.pData when the map succeeded and there is data to copy.
            if ( ( pCSData != nullptr ) && ( numDataBytes > 0 ) )
            { memcpy( res.pData, pCSData, numDataBytes ); }
            pContext->Unmap( pCBCS, 0 );
        }

        ID3D11Buffer* ppCB[ 1 ] = { pCBCS };
        pContext->CSSetConstantBuffers( 0, 1, ppCB );
    }

    pContext->Dispatch( x, y, z );

    // Clear the compute-stage bindings so the resources are not left bound.
    ID3D11UnorderedAccessView*  pNullUAVs[ 1 ] = { nullptr };
    ID3D11ShaderResourceView*   pNullSRVs[ D3D11_COMMON_SHADER_INPUT_RESOURCE_SLOT_COUNT ] = {};
    ID3D11Buffer*               pNullCBs [ 1 ] = { nullptr };

    // Clear exactly as many SRV slots as were bound above (clamped to the slot count).
    const UINT numNullSRVs = ( numViews < D3D11_COMMON_SHADER_INPUT_RESOURCE_SLOT_COUNT )
                           ? numViews
                           : D3D11_COMMON_SHADER_INPUT_RESOURCE_SLOT_COUNT;

    pContext->CSSetShader( nullptr, nullptr, 0 );
    pContext->CSSetUnorderedAccessViews( 0, 1, pNullUAVs, nullptr );
    if ( numNullSRVs > 0 )
    { pContext->CSSetShaderResources( 0, numNullSRVs, pNullSRVs ); }
    pContext->CSSetConstantBuffers( 0, 1, pNullCBs );
}