#include "stdafx.hpp"

#include "XMLWriter.hpp"

#include <msxml2.h>
#include <shlwapi.h>

#include <comutil.h>

#include <assert.h>

#include <deque>
#include <stdexcept>
#include <string>
#include <vector>

#include "tstringUty.hpp"

namespace
{
	class SAXAttributesAdaptor : public ISAXAttributes
	{
	private:
		DWORD ref_;

		const SAXAttributes& attr_;

		std::wstring tmp_;
		std::wstring tmp2_;
		std::wstring tmp3_;

	public:
		SAXAttributesAdaptor( const SAXAttributes& v_attr )
			: ref_( 0 )
			, attr_( v_attr )
		{
		}

		virtual ~SAXAttributesAdaptor()
		{
			assert( ref_ == 0 && "QƃJEg[ł͂܂B" );
		}

		HRESULT __stdcall QueryInterface( const IID& iid, void** ppvObj )
		{
			if( iid == IID_IUnknown || iid == __uuidof( ISAXAttributes ) ) {
				*ppvObj = this;
				AddRef();
				return NOERROR;
			}
			return E_NOINTERFACE;
		}

		ULONG __stdcall AddRef(void)
		{
			return ++ref_;
		}

		ULONG __stdcall Release(void)
		{
			if( --ref_ == 0 ) {
				delete this;
				return 0;
			}
			return ref_;
		}

		HRESULT __stdcall getLength(int* v_pLen )
		{
			assert( v_pLen != NULL && "NULL͎wł܂B" );

			*v_pLen = attr_.getLength();

			return S_OK;
		}

		HRESULT __stdcall getURI(int v_index, const wchar_t** v_pNS, int* v_pNSLen )
		{
			assert( v_pNS != NULL && v_pNSLen != NULL && "NULL͎wł܂B" );

			tmp_ = tstringuty::getWStdString( attr_.getURI( v_index ) );
			*v_pNS = tmp_.c_str();
			*v_pNSLen = (int) tmp_.length();

			return S_OK;
		}

		HRESULT __stdcall getLocalName(int v_index, const wchar_t** v_pLocalName, int* v_pLocalNameLen )
		{
			assert( v_pLocalName != NULL && v_pLocalNameLen != NULL && "NULL͎wł܂B" );

			tmp_ = tstringuty::getWStdString( attr_.getLocalName( v_index ) );
			*v_pLocalName = tmp_.c_str();
			*v_pLocalNameLen = (int) tmp_.length();

			return S_OK;
		}

		HRESULT __stdcall getQName(int v_index, const wchar_t** v_pQName, int* v_pQNameLen )
		{
			assert( v_pQName != NULL && v_pQNameLen != NULL && "NULL͎wł܂B" );

			tmp_ = tstringuty::getWStdString( attr_.getQName( v_index ) );
			*v_pQName = tmp_.c_str();
			*v_pQNameLen = (int) tmp_.length();

			return S_OK;
		}

		HRESULT __stdcall getName(
			int v_index,
			const wchar_t** v_pNS,
			int* v_pNSLen,
			const wchar_t** v_pLocalName,
			int* v_pLocalNameLen,
			const wchar_t** v_pQName,
			int * v_pQNameLen
			)
		{
			tstring ns;
			tstring localName;
			tstring qName;
			attr_.getName( v_index, &ns, &localName, &qName );
			
			if( v_pNS != NULL && v_pNSLen != NULL ) {
				tmp_ = tstringuty::getWStdString( ns );
				*v_pNS = tmp_.c_str();
				*v_pNSLen = (int) tmp_.length();
			}
			if( v_pLocalName != NULL && v_pLocalNameLen != NULL ) {
				tmp2_ = tstringuty::getWStdString( localName );
				*v_pLocalName = tmp2_.c_str();
				*v_pLocalNameLen = (int) tmp2_.length();
			}
			if( v_pQName != NULL && v_pQNameLen != NULL ) {
				tmp3_  = tstringuty::getWStdString( qName );
				*v_pQName = tmp3_.c_str();
				*v_pQNameLen = (int) tmp3_.length();
			}

			return S_OK;
		}

		HRESULT __stdcall getIndexFromName(
			const wchar_t* v_ns,
			int v_nsLen,
			const wchar_t* v_localName,
			int v_localNameLen,
			int* v_pIndex
			)
		{
			assert( v_ns != NULL && v_localName != NULL && "NULL͎wł܂B" );
			assert( v_pIndex != NULL && "NULL͎wł܂B" );

			const std::wstring ns( v_ns, v_nsLen );
			const std::wstring localName( v_localName, v_localNameLen );
			
			const int idx = attr_.getIndexFromName( tstringuty::getTString( ns ), tstringuty::getTString( localName ) );
			if( idx >= 0 ) {
				*v_pIndex = idx;
				return S_OK;
			}
			return E_INVALIDARG;
		}

		HRESULT __stdcall getIndexFromQName(
			const wchar_t* v_qName,
			int v_qNameLen,
			int* v_pLen
			)
		{
			assert( v_qName != NULL && "NULL͎wł܂B" );
			assert( v_pLen != NULL && "NULL͎wł܂B" );

			const std::wstring qName( v_qName, v_qNameLen );
			
			const int idx = attr_.getIndexFromQName( tstringuty::getTString( qName ) );
			if( idx >= 0 ) {
				*v_pLen = idx;
				return S_OK;
			}
			return E_INVALIDARG;
		}

		HRESULT __stdcall getType(int v_index, const wchar_t** v_pType, int* v_pTypeLen )
		{
			assert( v_index >= 0 && "CfbNXsłB" );
			assert( v_pType != NULL && v_pTypeLen != NULL && "NULL͎wł܂B" );

			tmp_ = tstringuty::getWStdString( attr_.getType( v_index ) );

			*v_pType = tmp_.c_str();
			*v_pTypeLen = (int) tmp_.length();

			return S_OK;
		}

		HRESULT __stdcall getTypeFromName(
			const wchar_t* v_ns,
			int v_nsLen,
			const wchar_t* v_localName,
			int v_localNameLen,
			const wchar_t** v_pType,
			int* v_pTypeLen
			)
		{
			assert( v_ns != NULL && v_localName != NULL && "NULL͎wł܂B" );
			assert( v_pType != NULL && v_pTypeLen != NULL && "NULL͎wł܂B" );

			const std::wstring ns( v_ns, v_nsLen );
			const std::wstring localName( v_localName, v_localNameLen );

			tmp_ = tstringuty::getWStdString(
					attr_.getTypeFromName(
						tstringuty::getTString( ns ),
						tstringuty::getTString( localName )
					) );

			*v_pType = tmp_.c_str();
			*v_pTypeLen = (int) tmp_.length();

			return S_OK;
		}

		HRESULT __stdcall getTypeFromQName(
			const wchar_t* v_qName,
			int v_qNameLen,
			const wchar_t** v_pType,
			int* v_pTypeLen
			)
		{
			assert( v_qName != NULL && "NULL͎wł܂B" );
			assert( v_pType != NULL && v_pTypeLen != NULL && "NULL͎wł܂B" );

			const std::wstring qName( v_qName, v_qNameLen );
			
			tmp_ = tstringuty::getWStdString(
					attr_.getTypeFromQName(
						tstringuty::getTString( qName )
					) );

			*v_pType = tmp_.c_str();
			*v_pTypeLen = (int) tmp_.length();

			return S_OK;
		}

		HRESULT __stdcall getValue(int v_index, const wchar_t** v_value, int* v_valueLen )
		{
			assert( v_index >= 0 && "CfbNXsłB" );
			assert( v_value != NULL && v_valueLen != NULL && "NULL͎wł܂B" );

			tmp_ = tstringuty::getWStdString( attr_.getValue( v_index ) );

			*v_value = tmp_.c_str();
			*v_valueLen = (int) tmp_.length();

			return S_OK;
		}

		HRESULT __stdcall getValueFromName(
			const wchar_t* v_ns,
			int v_nsLen,
			const wchar_t* v_localName,
			int v_localNameLen,
			const wchar_t ** v_pValue,
			int* v_pValueLen
			)
		{
			assert( v_ns != NULL && v_localName != NULL && "NULL͎wł܂B" );
			assert( v_pValue != NULL && v_pValueLen != NULL && "NULL͎wł܂B" );

			const std::wstring ns( v_ns, v_nsLen );
			const std::wstring localName( v_localName, v_localNameLen );

			tmp_ = tstringuty::getWStdString(
					attr_.getValueFromName(
						tstringuty::getTString( ns ),
						tstringuty::getTString( localName )
					) );

			*v_pValue = tmp_.c_str();
			*v_pValueLen = (int) tmp_.length();

			return S_OK;
		}

		HRESULT __stdcall getValueFromQName(
			const wchar_t* v_qName,
			int v_qNameLen,
			const wchar_t** v_pValue,
			int* v_pValueLen
			)
		{
			assert( v_pValue != NULL && v_pValueLen != NULL && "NULL͎wł܂B" );
			assert( v_qName != NULL && "NULL͎wł܂B" );

			const std::wstring qName( v_qName, v_qNameLen );

			tmp_  = tstringuty::getWStdString(
					attr_.getValueFromQName(
						tstringuty::getTString( qName )
					) );

			*v_pValue = tmp_.c_str();
			*v_pValueLen = (int) tmp_.length();

			return S_OK;
		}

	};

	/// SAXAttributesHolder backed by an MSXML SAXAttributes COM object.
	///
	/// Mutators go through IMXAttributes and readers through ISAXAttributes;
	/// both interface pointers refer to the same COM object and are released
	/// in the destructor. HRESULT failures are reported through assert only.
	class DefaultSAXAttributesHolder : public SAXAttributesHolder
	{
	private:
		DefaultSAXAttributesHolder( const DefaultSAXAttributesHolder& ); //!< not implemented (non-copyable)
		DefaultSAXAttributesHolder& operator=( const DefaultSAXAttributesHolder& ); //!< not implemented (non-copyable)

	private:
		IMXAttributes* pMXAttr_;  //!< write interface (owned reference)
		ISAXAttributes* pAttr_;   //!< read interface (owned reference)

		/// Converts a (pointer, length) pair produced by MSXML into a tstring.
		/// A NULL pointer or non-positive length yields an empty string.
		static tstring toTString( const wchar_t* v_str, int v_len )
		{
			if( v_str == NULL || v_len <= 0 ) {
				return tstring();
			}
			const std::wstring tmp( v_str, v_len );
			return tstringuty::getTString( tmp );
		}

	public:

		/// Creates the MSXML SAXAttributes object.
		/// @throws std::runtime_error if the object or its ISAXAttributes
		///         interface cannot be obtained.
		DefaultSAXAttributesHolder()
			: pMXAttr_( NULL )
			, pAttr_( NULL )
		{
			if( FAILED( CoCreateInstance(
				__uuidof( SAXAttributes ),
				NULL,
				CLSCTX_ALL,
				__uuidof( IMXAttributes ),
				(void**) &pMXAttr_
				) ) )
			{
				throw std::runtime_error( "MSXML MXAttribute Creation failed." );
			}
			if( FAILED( pMXAttr_->QueryInterface(
				__uuidof( ISAXAttributes ),
				(void**) &pAttr_
				) ) )
			{
				pMXAttr_->Release();
				throw std::runtime_error( "MSXML MXAttribute Creation failed." );
			}
		}

		virtual ~DefaultSAXAttributesHolder()
		{
			pAttr_->Release();
			pMXAttr_->Release();
		}

		virtual void clear()
		{
			const HRESULT hr = pMXAttr_->clear();
			assert( SUCCEEDED( hr ) && "clearɎs܂B" );
		}

		virtual void addAttribute(
			const tstring& v_ns,
			const tstring& v_localName,
			const tstring& v_qName,
			const tstring& v_type,
			const tstring& v_value )
		{
			const _bstr_t ns( v_ns.c_str() );
			const _bstr_t localName( v_localName.c_str() );
			const _bstr_t qName( v_qName.c_str() );
			const _bstr_t type( v_type.c_str() );
			const _bstr_t value( v_value.c_str() );

			const HRESULT hr = pMXAttr_->addAttribute( ns, localName, qName, type, value );
			assert( SUCCEEDED( hr ) && "addAttributeɎs܂B" );
		}

		virtual void removeAttribute( int v_index )
		{
			const HRESULT hr = pMXAttr_->removeAttribute( v_index );
			assert( SUCCEEDED( hr ) && "removeAttributeɎs܂B" );
		}

		virtual void setAttribute(
			int v_index,
			const tstring& v_ns,
			const tstring& v_localName,
			const tstring& v_qName,
			const tstring& v_type,
			const tstring& v_value )
		{
			const _bstr_t ns( v_ns.c_str() );
			const _bstr_t localName( v_localName.c_str() );
			const _bstr_t qName( v_qName.c_str() );
			const _bstr_t type( v_type.c_str() );
			const _bstr_t value( v_value.c_str() );

			const HRESULT hr = pMXAttr_->setAttribute( v_index, ns, localName, qName, type, value );
			assert( SUCCEEDED( hr ) && "setAttributeɎs܂B" );
		}

		///

		virtual int getLength() const
		{
			int len = 0;
			if( SUCCEEDED( pAttr_->getLength( &len ) ) ) {
				return len;
			}
			assert( false && "sG[" );
			return 0;
		}

		virtual tstring getURI( int v_index ) const
		{
			const wchar_t* ns = NULL;
			int len = 0;
			if( SUCCEEDED( pAttr_->getURI( v_index, &ns, &len ) ) ) {
				return toTString( ns, len );
			}
			assert( false && "sG[" );
			return _TEXT("");
		}

		virtual tstring getLocalName( int v_index ) const
		{
			const wchar_t* localName = NULL;
			int len = 0;
			if( SUCCEEDED( pAttr_->getLocalName( v_index, &localName, &len ) ) ) {
				return toTString( localName, len );
			}
			assert( false && "sG[" );
			return _TEXT("");
		}

		virtual tstring getQName( int v_index ) const
		{
			const wchar_t* qName = NULL;
			int len = 0;
			if( SUCCEEDED( pAttr_->getQName( v_index, &qName, &len ) ) ) {
				return toTString( qName, len );
			}
			assert( false && "sG[" );
			return _TEXT("");
		}

		/// Stores the namespace URI, local name and qualified name of the
		/// attribute at v_index into each non-NULL output. On failure the
		/// outputs are left untouched and (debug builds) an assert fires.
		virtual void getName( int v_index, tstring* v_pNS, tstring* v_pLocalName, tstring* v_pQName ) const
		{
			const wchar_t* ns = NULL;
			const wchar_t* localName = NULL;
			const wchar_t* qName = NULL;
			int nsLen = 0;
			int localNameLen = 0;
			int qNameLen = 0;

			const HRESULT hr = pAttr_->getName(
				v_index,
				&ns,
				&nsLen,
				&localName,
				&localNameLen,
				&qName,
				&qNameLen
				);
			if( FAILED( hr ) ) {
				// Only the failure path asserts; success used to fall through to it.
				assert( false && "sG[" );
				return;
			}

			if( v_pNS != NULL ) {
				*v_pNS = toTString( ns, nsLen );
			}
			if( v_pLocalName != NULL ) {
				*v_pLocalName = toTString( localName, localNameLen );
			}
			if( v_pQName != NULL ) {
				*v_pQName = toTString( qName, qNameLen );
			}
		}

		/// @return the attribute index, or -1 when MSXML reports failure.
		virtual int getIndexFromName( const tstring& v_ns, const tstring& v_localName ) const
		{
			const std::wstring ns( tstringuty::getWStdString( v_ns ) );
			const std::wstring localName( tstringuty::getWStdString( v_localName ) );

			int idx = -1;
			if( SUCCEEDED( pAttr_->getIndexFromName(
				ns.c_str(),
				(int) ns.length(),
				localName.c_str(),
				(int) localName.length(),
				&idx
				) ) )
			{
				return idx;
			}
			return -1;
		}

		/// @return the attribute index, or -1 when MSXML reports failure.
		virtual int getIndexFromQName( const tstring& v_qName ) const
		{
			const std::wstring qName( tstringuty::getWStdString( v_qName ) );
			int idx = -1;
			if( SUCCEEDED( pAttr_->getIndexFromQName(
				qName.c_str(),
				(int) qName.length(),
				&idx
				) ) )
			{
				return idx;
			}
			return -1;
		}

		virtual tstring getType(int v_index ) const
		{
			const wchar_t* typeName = NULL;
			int len = 0;

			if( SUCCEEDED( pAttr_->getType( v_index, &typeName, &len ) ) ) {
				return toTString( typeName, len );
			}

			assert( false && "sG[łB" );
			return _TEXT("");
		}

		/// @return the attribute type, or an empty string if MSXML fails.
		virtual tstring getTypeFromName( const tstring& v_ns, const tstring& v_localName ) const
		{
			const std::wstring ns( tstringuty::getWStdString( v_ns ) );
			const std::wstring localName( tstringuty::getWStdString( v_localName ) );

			const wchar_t* typeName = NULL;
			int len = 0;

			if( SUCCEEDED( pAttr_->getTypeFromName(
				ns.c_str(),
				(int) ns.length(),
				localName.c_str(),
				(int) localName.length(),
				&typeName,
				&len
				) ) )
			{
				return toTString( typeName, len );
			}
			return _TEXT("");
		}

		/// @return the attribute type, or an empty string if MSXML fails.
		virtual tstring getTypeFromQName( const tstring& v_qName ) const
		{
			const std::wstring qName( tstringuty::getWStdString( v_qName ) );

			const wchar_t* typeName = NULL;
			int len = 0;

			if( SUCCEEDED( pAttr_->getTypeFromQName(
				qName.c_str(),
				(int) qName.length(),
				&typeName,
				&len
				) ) )
			{
				return toTString( typeName, len );
			}
			return _TEXT("");
		}

		virtual tstring getValue( int v_index ) const
		{
			const wchar_t* value = NULL;
			int len = 0;

			if( SUCCEEDED( pAttr_->getValue( v_index, &value, &len ) ) ) {
				return toTString( value, len );
			}

			assert( false && "sG[łB" );
			return _TEXT("");
		}

		/// @return the attribute value, or an empty string if MSXML fails.
		virtual tstring getValueFromName( const tstring& v_ns, const tstring& v_localName ) const
		{
			const std::wstring ns( tstringuty::getWStdString( v_ns ) );
			const std::wstring localName( tstringuty::getWStdString( v_localName ) );

			const wchar_t* value = NULL;
			int len = 0;

			if( SUCCEEDED( pAttr_->getValueFromName(
				ns.c_str(),
				(int) ns.length(),
				localName.c_str(),
				(int) localName.length(),
				&value,
				&len
				) ) )
			{
				return toTString( value, len );
			}

			return _TEXT("");
		}

		/// @return the attribute value, or an empty string if MSXML fails.
		virtual tstring getValueFromQName( const tstring& v_qName ) const
		{
			const std::wstring qName( tstringuty::getWStdString( v_qName ) );

			const wchar_t* value = NULL;
			int len = 0;

			if( SUCCEEDED( pAttr_->getValueFromQName(
				qName.c_str(),
				(int) qName.length(),
				&value,
				&len
				) ) )
			{
				return toTString( value, len );
			}

			return _TEXT("");
		}
	};

	/// XMLWriter implemented on top of MSXML's MXXMLWriter.
	///
	/// Output goes to a file stream opened with STGM_WRITE | STGM_CREATE,
	/// using the "csWindows31J" encoding with indentation enabled.
	/// SAX call failures are reported through assert only.
	class DefaultXMLWriter : public XMLWriter
	{
	private:
		DefaultXMLWriter( const DefaultXMLWriter& ); //!< not implemented (non-copyable)
		DefaultXMLWriter& operator=( const DefaultXMLWriter& ); //!< not implemented (non-copyable)

	private:
		IMXWriter* pXMLWriter_;                 //!< owned reference
		ISAXContentHandler* pContentHandler_;   //!< owned reference (same COM object)

		/// Releases whichever interfaces are held and NULLs them. Used by the
		/// destructor and by the constructor's failure paths (the destructor
		/// does not run when the constructor throws).
		void releaseInterfaces()
		{
			if( pContentHandler_ != NULL ) {
				pContentHandler_->Release();
				pContentHandler_ = NULL;
			}
			if( pXMLWriter_ != NULL ) {
				pXMLWriter_->Release();
				pXMLWriter_ = NULL;
			}
		}

	public:
		/// Creates the MSXML writer and binds it to v_fileName.
		/// @throws std::runtime_error if the writer, its content handler, the
		///         file stream, or the output binding cannot be created.
		explicit DefaultXMLWriter( const tstring& v_fileName )
			: pXMLWriter_( NULL )
			, pContentHandler_( NULL )
		{
			if( FAILED( ::CoCreateInstance(
				__uuidof(MXXMLWriter),
				NULL,
				CLSCTX_ALL,
				__uuidof(IMXWriter),
				(void **) &pXMLWriter_
				) ) )
			{
				throw std::runtime_error( "MSXML XMLWriter Creation failed." );
			}

			const _bstr_t encodingName( L"csWindows31J" );
			pXMLWriter_->put_encoding( encodingName );
			pXMLWriter_->put_indent( VARIANT_TRUE );

			if( FAILED( pXMLWriter_->QueryInterface(
				__uuidof( ISAXContentHandler ),
				(void**) &pContentHandler_
				) ) )
			{
				releaseInterfaces();
				throw std::runtime_error( "MSXML XMLWriter Creation failed." );
			}

			IStream* pStream = NULL;
			if( FAILED( ::SHCreateStreamOnFile( v_fileName.c_str(), STGM_WRITE | STGM_CREATE, &pStream ) ) ) {
				// pContentHandler_ was leaked here before; release both interfaces.
				releaseInterfaces();
				throw std::runtime_error( "file creation failed." );
			}

			// The VARIANT takes over our stream reference; VariantClear releases
			// it whether or not put_output succeeded (the writer AddRefs its own).
			VARIANT output;
			::VariantInit( &output );
			output.vt = VT_UNKNOWN;
			output.punkVal = pStream;
			const HRESULT hrOutput = pXMLWriter_->put_output( output );
			::VariantClear( &output );
			if( FAILED( hrOutput ) ) {
				releaseInterfaces();
				throw std::runtime_error( "MSXML OutputStream binding failure." );
			}
		}

		virtual ~DefaultXMLWriter()
		{
			releaseInterfaces();
		}

		virtual SAXAttributesHolderPtr createAttributesHolder()
		{
			return SAXAttributesHolderPtr( new DefaultSAXAttributesHolder() );
		}

		virtual void startDocument()
		{
			const HRESULT hr = pContentHandler_->startDocument();
			assert( SUCCEEDED( hr ) && "startDocument failed" );
		}

		virtual void endDocument()
		{
			const HRESULT hr = pContentHandler_->endDocument();
			assert( SUCCEEDED( hr ) && "endDocument failed" );
		}

		virtual void startPrefixMapping(const tstring& v_prefix, const tstring& v_url )
		{
			const std::wstring prefix( tstringuty::getWStdString( v_prefix ) );
			const std::wstring url( tstringuty::getWStdString( v_url ) );

			const HRESULT hr = pContentHandler_->startPrefixMapping(
				prefix.c_str(),
				(int) prefix.length(),
				url.c_str(),
				(int) url.length()
				);
			assert( SUCCEEDED( hr ) && "startPrefixMappingɎs" );
		}

		virtual void endPrefixMapping(const tstring& v_prefix )
		{
			const std::wstring prefix( tstringuty::getWStdString( v_prefix ) );

			const HRESULT hr = pContentHandler_->endPrefixMapping(
				prefix.c_str(),
				(int) prefix.length()
				);
			assert( SUCCEEDED( hr ) && "endPrefixMappingɎs" );
		}

		/// Writes a start tag, exposing pAttr to MSXML through a temporary
		/// SAXAttributesAdaptor.
		virtual void startElement(
			const tstring& v_namespace,
			const tstring& v_localName,
			const tstring& v_QName,
			const SAXAttributes& pAttr
			)
		{
			const std::wstring ns( tstringuty::getWStdString( v_namespace ) );
			const std::wstring localName( tstringuty::getWStdString( v_localName ) );
			const std::wstring qName( tstringuty::getWStdString( v_QName ) );

			// The adaptor's Release() deletes it at refcount zero. Hold our own
			// reference across the call so that an AddRef/Release pair inside
			// MSXML cannot delete it early (which previously led to a double
			// delete). Our Release() then frees it. If MSXML still holds a
			// reference, the adaptor is not freed under it.
			// NOTE(review): a retained adaptor would still refer to pAttr after
			// this call returns.
			SAXAttributesAdaptor* pAttrAdaptor = new SAXAttributesAdaptor( pAttr );
			pAttrAdaptor->AddRef();

			const HRESULT hr = pContentHandler_->startElement(
				ns.c_str(),
				(int) ns.length(),
				localName.c_str(),
				(int) localName.length(),
				qName.c_str(),
				(int) qName.length(),
				pAttrAdaptor
				);

			pAttrAdaptor->Release();
			assert( SUCCEEDED( hr ) && "startElementɎs" );
		}

		virtual void endElement(
			const tstring& v_namespace,
			const tstring& v_localName,
			const tstring& v_QName
			)
		{
			const std::wstring ns( tstringuty::getWStdString( v_namespace ) );
			const std::wstring localName( tstringuty::getWStdString( v_localName ) );
			const std::wstring qName( tstringuty::getWStdString( v_QName ) );

			const HRESULT hr = pContentHandler_->endElement(
				ns.c_str(),
				(int)ns.length(),
				localName.c_str(),
				(int)localName.length(),
				qName.c_str(),
				(int) qName.length()
				);
			assert( SUCCEEDED( hr ) && "endElementɎs" );
		}

		virtual void characters(const tstring& v_chars )
		{
			const std::wstring chars( tstringuty::getWStdString( v_chars ) );
			const HRESULT hr = pContentHandler_->characters(
				chars.c_str(),
				(int) chars.length()
				);
			assert( SUCCEEDED( hr ) && "charactersɎs" );
		}

		virtual void ignorableWhitespace(const tstring& v_chars )
		{
			const std::wstring chars( tstringuty::getWStdString( v_chars ) );
			const HRESULT hr = pContentHandler_->ignorableWhitespace(
				chars.c_str(),
				(int) chars.length()
				);
			assert( SUCCEEDED( hr ) && "ignorableWhitespaceɎs" );
		}

		/// Not supported: asserts in debug builds and does nothing otherwise.
		virtual void processingInstruction( const tstring& , const tstring& )
		{
			assert( false && "T|[gĂ܂B" );
		}

		/// Not supported: asserts in debug builds and does nothing otherwise.
		virtual void skippedEntity( const tstring& )
		{
			assert( false && "T|[gĂ܂B" );
		}

	};
}


/// Default constructor; the body is intentionally empty.
XMLWriterFactory::XMLWriterFactory()
{
	// nothing to initialise here
}

/// Destructor; the body is intentionally empty.
XMLWriterFactory::~XMLWriterFactory()
{
	// nothing to release here
}

/// Creates an MSXML-backed writer that outputs to v_fileName.
/// Any std::runtime_error thrown by DefaultXMLWriter's constructor (COM
/// object, file, or output binding failure) propagates to the caller.
XMLWriterPtr XMLWriterFactory::create( const tstring& v_fileName )
{
	// Keep the concrete pointer type so XMLWriterPtr is constructed from
	// exactly the same pointer type as before.
	DefaultXMLWriter* const pWriter = new DefaultXMLWriter( v_fileName );
	return XMLWriterPtr( pWriter );
}

