/*************************************************************************************************/
/*!
   	@file		ComBase.h
	@author 	Fanzo
*/
/*************************************************************************************************/
#pragma		once

///////////////////////////////////////////////////////////////////////////////////////////////////
//include files
#include	<windows.h>

#pragma pack( push , 8 )		//set align

namespace icubic
{

///////////////////////////////////////////////////////////////////////////////////////////////////
// preprocessor deifne
#define		cb_comcall( rv )	virtual rv __stdcall

#define		cb_com_begin()									\
public:														\
cb_comcall( HRESULT ) QueryInterface						\
		(													\
		const GUID&	riid,									\
		void**		object									\
		)													\
{

#define		cb_com_query( t_guid , t_interface )			\
	if( IsEqualIID( riid , t_guid ) )						\
	{														\
		AddRef();											\
		*object	= static_cast<t_interface*>( this );		\
		return S_OK;										\
	}

#define		cb_com_end( t_super_class )						\
	return t_super_class::QueryInterface( riid , object );	\
}															\
cb_comcall( ULONG ) AddRef()								\
{															\
	return ComObject::AddRef();								\
}															\
cb_comcall( ULONG ) Release()								\
{															\
	return ComObject::Release();							\
}

///////////////////////////////////////////////////////////////////////////////////////////////////
// type define

///////////////////////////////////////////////////////////////////////////////////////////////////
// shared functions
//=================================================================================================
cb_inline
bool CompareRegistedComServer
		(
		HINSTANCE		t_hinst , 
		const GUID&		t_clsid ,
		const wchar_t*	t_threadmodel
		)
{
	wstring		dllpath;
	{
		if( false == GetModulePath( t_hinst , &dllpath ) )
			return false;
		std::transform( dllpath.begin() , dllpath.end() , dllpath.begin() , std::tolower );
	}
	wstring		clsid;
	{
		if( false == ToString( t_clsid , &clsid ) )
			return false;
		std::transform( clsid.begin() , clsid.end() , clsid.begin() , std::tolower );
	}
	wstring		threadmodel	= t_threadmodel;
	std::transform( threadmodel.begin() , threadmodel.end() , threadmodel.begin() , std::tolower );
	
	{
		RegKey		reg;
		if( false == reg.Open( HKEY_CLASSES_ROOT , ( L"CLSID\\" + clsid + L"\\InprocServer32" ).c_str() , RegKey::Read ) )
			return false;
		{
			wstring		value;
			if( false == reg.GetStringValue( 0 , &value ) )
				return false;
			std::transform( value.begin() , value.end() , value.begin() , std::tolower );
			if( dllpath != value )
				return false;
		}
		{
			wstring		value;
			if( false == reg.GetStringValue( L"ThreadingModel" , &value ) )
				return false;
			std::transform( value.begin() , value.end() , value.begin() , std::tolower );
			if( threadmodel != value )
				return false;
		}
	}
	return true;
}
//=================================================================================================
cb_inline
bool UnregistComServer
		(
		const GUID&		t_clsid
		)
{
	wstring		clsid;
	{
		if( false == ToString( t_clsid , &clsid ) )
			return false;
		std::transform( clsid.begin() , clsid.end() , clsid.begin() , std::tolower );
	}
	{
		RegKey		reg;
		if( true == reg.Open( HKEY_CLASSES_ROOT , L"CLSID" , RegKey::Write ) )
			reg.DeleteSubkey( clsid.c_str() );
	}
	return true;
}
//=================================================================================================
cb_inline
bool RegistComServer
		(
		HINSTANCE		t_hinst , 
		const GUID&		t_clsid ,
		const wchar_t*	t_threadmodel
		)
{
	if( true == CompareRegistedComServer( t_hinst , t_clsid , t_threadmodel ) )
		return true;
	UnregistComServer( t_clsid );

	wstring		dllpath;
	{
		if( false == GetModulePath( t_hinst , &dllpath ) )
			return false;
		std::transform( dllpath.begin() , dllpath.end() , dllpath.begin() , std::tolower );
	}
	
	wstring		clsid;
	{
		if( false == ToString( t_clsid , &clsid ) )
			return false;
		std::transform( clsid.begin() , clsid.end() , clsid.begin() , std::tolower );
	}	
	{
		RegKey		reg;
		if( false == reg.Open( HKEY_CLASSES_ROOT , ( L"CLSID\\" + clsid ).c_str() , RegKey::Write ) )
			return false;
	}
	{
		RegKey		reg;
		if( false == reg.Open( HKEY_CLASSES_ROOT , ( L"CLSID\\" + clsid + L"\\InprocServer32" ).c_str() , RegKey::Write ) )
			return false;
		if( false == reg.SetStringValue( 0 , dllpath ) )
			return false;
		if( false == reg.SetStringValue( L"ThreadingModel" , t_threadmodel ) )
			return false;
	}
	return true;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
// classes define

/**************************************************************************************************
"ComInstance" class 
**************************************************************************************************/
template<class t_class>
class ComInstance
{
	cb_copy_impossible( ComInstance );
	
// variable member
private:
	t_class*	m_ptr;
	
// public functions
public:
//=================================================================================================
ComInstance() : m_ptr( new t_class() )
{
}
//=================================================================================================
~ComInstance()
{
	m_ptr->Release();
}
//=================================================================================================
t_class* operator->()const
{
	return m_ptr;
}
//=================================================================================================
t_class& operator*()const
{
	return *m_ptr;
}
};
/**************************************************************************************************
"ComObject" class 
**************************************************************************************************/
class ComObject : public IUnknown
{
private:
	enum Cmd
	{
		Get , 
		Inc , 
		Dec , 
	};	

// variable member
private:
	int32		m_refcount;
	
// private functions
private:
//=================================================================================================
static
int32 ActiveObjectCount
		(
		Cmd		cmd
		)
{
	static 
	int32	m_active_count	= 0;
	if( cmd == Inc )
		return atomic_increment( &m_active_count );
	else if( cmd == Dec )
		return atomic_decrement( &m_active_count );
	else
		return atomic_get( &m_active_count );
}
// "IUnknown" interface functions
protected:
//=================================================================================================
cb_comcall( HRESULT )
QueryInterface
		(
		const GUID&	riid, 
		void**		object
		)
{
	if( IsEqualIID( riid, IID_IUnknown ) )
	{
		AddRef();
		*object	= static_cast<IUnknown*>(this);
		return S_OK;
	}
	return E_NOINTERFACE;
}
//=================================================================================================
cb_comcall( ULONG )
AddRef()
{
	return cb_atomic_increment( &m_refcount );
}
//=================================================================================================
cb_comcall( ULONG )
Release()
{
	LONG	r = cb_atomic_decrement( &m_refcount );
	if( r == 0 )
		delete this;
	return r;
}

// public functions
public:
//=================================================================================================
ComObject() : m_refcount( 1 )
{
	ActiveObjectCount( Inc );
}
//=================================================================================================
virtual
~ComObject()
{
	ActiveObjectCount( Dec );
}
//=================================================================================================
static
bool ObjectsActive()
{
	return ActiveObjectCount( Get ) > 0 ? true : false;
}
};

/**************************************************************************************************
"ComFactory" class 
**************************************************************************************************/
class ComFactory : 
		public IClassFactory , 
		public ComObject
{
	cb_com_begin()
	cb_com_query( IID_IClassFactory , IClassFactory )
	cb_com_end( ComObject )

private:
	enum Cmd
	{
		Get , 
		Inc , 
		Dec , 
	};	
// variable member
private:

// private functions
private:
//=================================================================================================
static
int32 LockServerCount
		(
		Cmd			cmd
		)
{
	static
	int32		m_lock	= 0;
	if( cmd == Inc )
		return atomic_increment( &m_lock );
	else if( cmd == Dec )
		return atomic_decrement( &m_lock );
	else
		return atomic_get( &m_lock );
}
// "IClassFactory" interface functions
private:
//=================================================================================================
cb_comcall( HRESULT ) 
CreateInstance
		(
		LPUNKNOWN	pUnkOuter , 
		const GUID&	riid , 
		void**		pv
		)
{
	return CreateComObject( pUnkOuter , riid , pv );
}
//=================================================================================================
cb_comcall( HRESULT ) 
LockServer
		(
		BOOL	lock_f
		)
{
	LockServerCount( lock_f == FALSE ? Dec : Inc );
    return NOERROR;
}
// protected functions
protected:
//=================================================================================================
virtual
HRESULT CreateComObject
		(
		IUnknown*	outer , 
		const GUID&	riid , 
		void**		pv
		)
{
	return E_FAIL;
}

// public functions
public:
//=================================================================================================
ComFactory()
{
}
//=================================================================================================
virtual
~ComFactory()
{
}
//=================================================================================================
static
bool IsLocked()
{
	return LockServerCount( Get ) > 0 ? true : false;
}
};
/**************************************************************************************************
"ComModule" class 
**************************************************************************************************/
class ComModule
{
	enum Cmd
	{
		Get , 
		Set , 
	};
// variable member
protected:
	
// private functions
private:
//=================================================================================================
static
HINSTANCE ModuleHandle
		(
		Cmd			cmd		= Get , 
		HINSTANCE	hinst	= NULL
		)
{
	static
	HINSTANCE		m_hinst = NULL;
	if( cmd == Set )
		m_hinst = hinst;
	return m_hinst;
}
// override functions
protected:
//=================================================================================================
virtual
HRESULT GetClassObject
		(
		const GUID&		rclsid, 
		const GUID&		riid, 
		void**			ppv
		)
{
    return E_FAIL;
}
//=================================================================================================
virtual
HRESULT RegisterServer()
{
	return E_FAIL;
}
//=================================================================================================
virtual
HRESULT UnregisterServer()
{
	return E_FAIL;
}
// protected functions
protected:
//=================================================================================================
bool RegistServer
		(
		const GUID&		t_clsid ,
		const wchar_t*	t_threadmodel
		)
{
	return RegistComServer( ModuleHandle() , t_clsid , t_threadmodel );
}
//=================================================================================================
bool UnregistServer
		(
		const GUID&		t_clsid
		)
{
	return UnregistComServer( t_clsid );
}
// public functions
public:
//=================================================================================================
ComModule()
{
}
//=================================================================================================
BOOL DllMain
		(
		HINSTANCE	hinst ,
		DWORD		reason,
		void*		pv
		)
{
	if( reason == DLL_PROCESS_ATTACH )
	{
		DisableThreadLibraryCalls( ModuleHandle( Set , hinst ) );
	}
	else if( reason == DLL_THREAD_DETACH )
	{
	}
    return  TRUE;
}
//=================================================================================================
HRESULT DllCanUnloadNow()
{
   if( true == ComFactory::IsLocked() || true == ComObject::ObjectsActive() )
		return S_FALSE;
	return S_OK;
}
//=================================================================================================
HRESULT DllGetClassObject
		(
		const GUID&		rclsid, 
		const GUID&		riid, 
		void**			ppv
		)
{
    return GetClassObject( rclsid , riid , ppv );
}
//=================================================================================================
HRESULT DllRegisterServer()
{
    return RegisterServer();
}
//=================================================================================================
HRESULT DllUnregisterServer()
{
	return UnregisterServer();
}
//=================================================================================================
static
HINSTANCE GetModuleHandle()
{
	return ModuleHandle();
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
// global variable define

///////////////////////////////////////////////////////////////////////////////////////////////////
// global functions define

};	//namespace

//using namespace icubic;		

#pragma pack( pop )			//release align
