/*************************************************************************************************/
/*!
   	@file		MatrixMulti.h
	@author 	Fanzo
 	@date 		2008/4/10
*/
/*************************************************************************************************/
#pragma		once

///////////////////////////////////////////////////////////////////////////////////////////////////
//include files
#include	"../iCubicStd/Array.h"

#pragma pack( push , 8 )		//set align

namespace icubic
{

///////////////////////////////////////////////////////////////////////////////////////////////////
// preprocessor define

///////////////////////////////////////////////////////////////////////////////////////////////////
// type define

///////////////////////////////////////////////////////////////////////////////////////////////////
// classes define

/**************************************************************************************************
"fVectorMatrixMulti" class 
**************************************************************************************************/
class fVectorMatrixMulti
{
	friend class fSquareMatrixMulti;
	
// variable member
private:
	int				m_num;			//!< number of elements
	Array<float>	m_element;		//!< element storage, resized to m_num by SetSize()

// private functions
private:
// protect functions
protected:
// public functions
public:
//=================================================================================================
//!	construct an empty vector (size 0)
//-------------------------------------------------------------------------------------------------
fVectorMatrixMulti() : m_num( 0 )
{
}
//=================================================================================================
//!	construct a vector holding num elements
//!	@param	num		number of elements
//!	@note	element values are whatever Array<float>::Resize leaves them as
//!			(NOTE(review): confirm whether Resize zero-fills).
//!	@note	this constructor is not explicit, so an int converts implicitly to a vector.
//-------------------------------------------------------------------------------------------------
fVectorMatrixMulti
		(
		int		num
		) : m_num( 0 )
{
	SetSize( num );
}
//=================================================================================================
//!	copy construct
//!	@param	obj		vector to copy
//-------------------------------------------------------------------------------------------------
fVectorMatrixMulti
		(
		const fVectorMatrixMulti&	obj
		) : m_num( 0 )
{
	*this	= obj;
}
//=================================================================================================
//!	substitution : resize to obj's size and copy every element
//!	@param	obj		source vector
//!	@retval			reference to this
//!	@note	self-assignment is harmless: SetSize() returns early when the size already matches.
//-------------------------------------------------------------------------------------------------
fVectorMatrixMulti& operator=
		(
		const fVectorMatrixMulti& obj
		)
{
	SetSize( obj.m_num );
	int		off;
	for( off = 0 ; off < m_num ; off++ )
	{
		Elem( off ) = obj.Elem( off );
	}
	return *this;
}
//=================================================================================================
//!	add : element-wise sum
//!	@param	obj		vector of the same size
//!	@retval			new vector ( *this + obj )
//-------------------------------------------------------------------------------------------------
fVectorMatrixMulti operator+
		(
		const fVectorMatrixMulti& obj
		)const
{
	cb_assert( m_num == obj.m_num , L"fVectorMatrixMulti:invalid size" );
	
	fVectorMatrixMulti	mt = *this;
	mt += obj;
	return mt;
}
//=================================================================================================
//!	add : element-wise in-place sum
//!	@param	obj		vector of the same size
//-------------------------------------------------------------------------------------------------
void operator+=
		(
		const fVectorMatrixMulti& obj
		)
{
	cb_assert( m_num == obj.m_num , L"fVectorMatrixMulti:invalid size" );
	int		off;
	for( off = 0 ; off < m_num ; off++ )
	{
		Elem( off ) += obj.Elem( off );
	}
}
//=================================================================================================
//!	sub : element-wise difference
//!	@param	obj		vector of the same size
//!	@retval			new vector ( *this - obj )
//-------------------------------------------------------------------------------------------------
fVectorMatrixMulti operator-
		(
		const fVectorMatrixMulti& obj
		)const
{
	cb_assert( m_num == obj.m_num , L"fVectorMatrixMulti:invalid size" );

	fVectorMatrixMulti	mt = *this;
	mt -= obj;
	return mt;
}
//=================================================================================================
//!	sub : element-wise in-place difference
//!	@param	obj		vector of the same size
//-------------------------------------------------------------------------------------------------
void operator-=
		(
		const fVectorMatrixMulti& obj
		)
{
	cb_assert( m_num == obj.m_num , L"fVectorMatrixMulti:invalid size" );
	int		off;
	for( off = 0 ; off < m_num ; off++ )
	{
		Elem( off ) -= obj.Elem( off );
	}
}
//=================================================================================================
//!	negate : unary minus
//!	@retval			new vector whose every element is the negation of this one's
//-------------------------------------------------------------------------------------------------
fVectorMatrixMulti operator-()const
{
	// the result must be sized before Elem() is written; a default-constructed
	// vector has 0 elements and every Elem( off ) would be out of range.
	fVectorMatrixMulti	mt( m_num );
	int		off;
	for( off = 0 ; off < m_num ; off++ )
	{
		mt.Elem( off ) = -Elem( off );
	}
	return mt;
}
//=================================================================================================
//!	mul : dot (inner) product
//!	@param	obj		vector of the same size
//!	@retval			sum over all elements of ( this[i] * obj[i] ); 0 for empty vectors
//-------------------------------------------------------------------------------------------------
float operator*
		(
		const fVectorMatrixMulti&	obj
		)const
{
	cb_assert( m_num == obj.m_num , L"fVectorMatrixMulti:invalid size" );

	float	e = 0.0f;
	int		off;
	for( off = 0 ; off < m_num ; off++ )
	{
		e += Elem( off ) * obj.Elem( off );
	}
	return e;
}
//=================================================================================================
//!	set size
//!	@param	num		new number of elements (must be >= 0)
//!	@note	does nothing when num equals the current size; otherwise the storage is
//!			resized through Array<float>::Resize (NOTE(review): whether existing
//!			values survive depends on Array's implementation).
//-------------------------------------------------------------------------------------------------
void SetSize
		(
		int		num
		)
{
	cb_assert( 0 <= num , L"fVectorMatrixMulti:invalid size" );
	if( m_num == num )
		return;
	m_num	= num;
	m_element.Resize( num );
}
//=================================================================================================
//!	get size
//!	@retval			number of elements
//-------------------------------------------------------------------------------------------------
int GetSize()const
{
	return m_num;
}
//=================================================================================================
//!	element access (mutable)
//!	@param	off		index in [ 0 , GetSize() )
//!	@retval			reference to the element
//-------------------------------------------------------------------------------------------------
float& Elem
		(
		int		off
		)
{
	cb_assert( 0 <= off && off < m_num , L"out of range." );
	return m_element[ off ];
}
//=================================================================================================
//!	element access (read only)
//!	@param	off		index in [ 0 , GetSize() )
//!	@retval			const reference to the element
//-------------------------------------------------------------------------------------------------
const float& Elem
		(
		int		off
		)const
{
	cb_assert( 0 <= off && off < m_num , L"out of range." );
	return m_element[ off ];
}
};

/**************************************************************************************************
"fSquareMatrixMulti" class 
**************************************************************************************************/
class fSquareMatrixMulti
{
// variable member
private:
	int				m_num;			//!< number of rows (== number of columns)
	Array<float>	m_element;		//!< row-major storage, m_num * m_num entries
		
// private functions
private:
//=================================================================================================
//!	absolute value of a float (local helper so no extra header is required)
//-------------------------------------------------------------------------------------------------
static float AbsValue
		(
		float	v
		)
{
	return ( v < 0.0f ) ? -v : v;
}
// protect functions
protected:
// public functions
public:
//=================================================================================================
//!	construct an empty matrix (0 x 0)
//-------------------------------------------------------------------------------------------------
fSquareMatrixMulti() : m_num( 0 )
{
}
//=================================================================================================
//!	construct a num x num matrix
//!	@param	num		number of rows / columns
//!	@note	element values are whatever Array<float>::Resize leaves them as
//!			(NOTE(review): confirm whether Resize zero-fills).
//!	@note	this constructor is not explicit, so an int converts implicitly to a matrix.
//-------------------------------------------------------------------------------------------------
fSquareMatrixMulti
		(
		int		num
		) : m_num( 0 )
{
	SetSize( num );
}
//=================================================================================================
//!	copy construct
//!	@param	obj		matrix to copy
//-------------------------------------------------------------------------------------------------
fSquareMatrixMulti
		(
		const fSquareMatrixMulti&	obj
		) : m_num( 0 )
{
	*this	= obj;
}
//=================================================================================================
//!	substitution : resize to obj's size and copy every element
//!	@param	obj		source matrix
//!	@retval			reference to this
//!	@note	self-assignment is harmless: SetSize() returns early when the size already matches.
//-------------------------------------------------------------------------------------------------
fSquareMatrixMulti& operator=
		(
		const fSquareMatrixMulti& obj
		)
{
	SetSize( obj.m_num );
	int		row , col;
	for( row = 0 ; row < m_num ; row++ )
	{
		for( col = 0 ; col < m_num ; col++ )
		{
			Elem( row , col ) = obj.Elem( row , col );
		}
	}
	return *this;
}
//=================================================================================================
//!	add : element-wise sum
//!	@param	obj		matrix of the same size
//!	@retval			new matrix ( *this + obj )
//-------------------------------------------------------------------------------------------------
fSquareMatrixMulti operator+
		(
		const fSquareMatrixMulti& obj
		)const
{
	cb_assert( m_num == obj.m_num , L"invalid size" );
	
	fSquareMatrixMulti	mt = *this;
	mt += obj;
	return mt;
}
//=================================================================================================
//!	add : element-wise in-place sum
//!	@param	obj		matrix of the same size
//-------------------------------------------------------------------------------------------------
void operator+=
		(
		const fSquareMatrixMulti& obj
		)
{
	cb_assert( m_num == obj.m_num , L"invalid size" );
	int		row , col;
	for( row = 0 ; row < m_num ; row++ )
	{
		for( col = 0 ; col < m_num ; col++ )
		{
			Elem( row , col ) += obj.Elem( row , col );
		}
	}
}
//=================================================================================================
//!	sub : element-wise difference
//!	@param	obj		matrix of the same size
//!	@retval			new matrix ( *this - obj )
//-------------------------------------------------------------------------------------------------
fSquareMatrixMulti operator-
		(
		const fSquareMatrixMulti& obj
		)const
{
	cb_assert( m_num == obj.m_num , L"invalid size" );

	fSquareMatrixMulti	mt = *this;
	mt -= obj;
	return mt;
}
//=================================================================================================
//!	sub : element-wise in-place difference
//!	@param	obj		matrix of the same size
//-------------------------------------------------------------------------------------------------
void operator-=
		(
		const fSquareMatrixMulti& obj
		)
{
	cb_assert( m_num == obj.m_num , L"invalid size" );
	int		row , col;
	for( row = 0 ; row < m_num ; row++ )
	{
		for( col = 0 ; col < m_num ; col++ )
		{
			Elem( row , col ) -= obj.Elem( row , col );
		}
	}
}
//=================================================================================================
//!	negate : unary minus
//!	@retval			new matrix whose every element is the negation of this one's
//-------------------------------------------------------------------------------------------------
fSquareMatrixMulti operator-()const
{
	// the result must be sized before Elem() is written; a default-constructed
	// matrix is 0 x 0 and every Elem( row , col ) would be out of range.
	fSquareMatrixMulti	mt( m_num );
	int		row , col;
	for( row = 0 ; row < m_num ; row++ )
	{
		for( col = 0 ; col < m_num ; col++ )
		{
			mt.Elem( row , col ) = -Elem( row , col );
		}
	}
	return mt;
}
//=================================================================================================
//!	mul : matrix product
//!	@param	obj		matrix of the same size
//!	@retval			new matrix ( *this x obj )
//-------------------------------------------------------------------------------------------------
fSquareMatrixMulti operator*
		(
		const fSquareMatrixMulti&	obj
		)const
{
	cb_assert( m_num == obj.m_num , L"invalid size" );

	// every element of mt is overwritten below, so there is no need to copy *this.
	// mt is a separate object, so ( a * a ) does not read partially-written results.
	fSquareMatrixMulti	mt( m_num );
	int		row , col;
	for( row = 0 ; row < m_num ; row++ )
	{
		for( col = 0 ; col < m_num ; col++ )
		{
			float	e = 0.0f;
			int		r;
			for( r = 0 ; r < m_num ; r++ )
			{
				e += Elem( row , r ) * obj.Elem( r , col );
			}
			mt.Elem( row , col ) = e;
		}
	}
	return mt;
}
//=================================================================================================
//!	mul : in-place matrix product ( *this = *this x obj )
//!	@param	obj		matrix of the same size
//-------------------------------------------------------------------------------------------------
void operator*=
		(
		const fSquareMatrixMulti&	obj
		)
{
	cb_assert( m_num == obj.m_num , L"invalid size" );

	*this	= ( *this ) * obj;
}
//=================================================================================================
//!	mul : matrix x column vector
//!	@param	obj		vector whose size equals the matrix size
//!	@retval			new vector r where r[row] = sum over off of Elem( row , off ) * obj[off]
//-------------------------------------------------------------------------------------------------
fVectorMatrixMulti operator*
		(
		const fVectorMatrixMulti&	obj
		)const
{
	cb_assert( m_num == obj.m_num , L"invalid size" );
	fVectorMatrixMulti	r( m_num );
	
	int		row;
	for( row = 0 ; row < m_num ; row++ )
	{
		float	e = 0.0f;
		int		off;
		for( off = 0 ; off < m_num ; off++ )
		{
			e += Elem( row , off ) * obj.Elem( off );
		}
		r.Elem( row ) = e;
	}
	return r;
}
//=================================================================================================
//!	set size
//!	@param	num		new number of rows / columns (must be >= 0)
//!	@note	does nothing when num equals the current size; otherwise the storage is
//!			resized to num * num through Array<float>::Resize. Because storage is
//!			row-major with stride m_num, previous values are not meaningfully kept.
//-------------------------------------------------------------------------------------------------
void SetSize
		(
		int		num
		)
{
	cb_assert( 0 <= num , L"invalid size" );
	if( m_num == num )
		return;
	m_num	= num;
	m_element.Resize( num * num );
}
//=================================================================================================
//!	get size
//!	@retval			number of rows (== number of columns)
//-------------------------------------------------------------------------------------------------
int GetSize()const
{
	return m_num;
}
//=================================================================================================
//!	element access (mutable)
//!	@param	row		row index in [ 0 , GetSize() )
//!	@param	col		column index in [ 0 , GetSize() )
//!	@retval			reference to the element
//-------------------------------------------------------------------------------------------------
float& Elem
		(
		int		row , 
		int		col
		)
{
	cb_assert( 0 <= row && row < m_num , L"out of range." );
	cb_assert( 0 <= col && col < m_num , L"out of range." );
	return m_element[ row * m_num + col ];
}
//=================================================================================================
//!	element access (read only)
//!	@param	row		row index in [ 0 , GetSize() )
//!	@param	col		column index in [ 0 , GetSize() )
//!	@retval			const reference to the element
//-------------------------------------------------------------------------------------------------
const float& Elem
		(
		int		row , 
		int		col
		)const
{
	cb_assert( 0 <= row && row < m_num , L"out of range." );
	cb_assert( 0 <= col && col < m_num , L"out of range." );
	return m_element[ row * m_num + col ];
}
//=================================================================================================
//!	unit matrix : set 1 on the diagonal and 0 elsewhere
//-------------------------------------------------------------------------------------------------
void SetUnit()
{
	int		row , col;
	for( row = 0 ; row < m_num ; row++ )
	{
		for( col = 0 ; col < m_num ; col++ )
		{
			Elem( row , col ) = ( row == col ) ? 1.0f : 0.0f;
		}
	}
}
//=================================================================================================
//!	AddRow : row[dest_row] += row[src_row]
//!	@param	dest_row	row that is modified
//!	@param	src_row		row that is added
//-------------------------------------------------------------------------------------------------
void AddRow
		(
		int		dest_row , 
		int		src_row
		)
{
	cb_assert( 0 <= dest_row && dest_row < m_num , L"out of range." );
	cb_assert( 0 <= src_row && src_row < m_num , L"out of range." );
	int		col;
	for( col = 0 ; col < m_num ; col++ )
	{
		Elem( dest_row , col ) += Elem( src_row , col );
	}
}
//=================================================================================================
//!	SubRow : row[dest_row] -= row[src_row]
//!	@param	dest_row	row that is modified
//!	@param	src_row		row that is subtracted
//-------------------------------------------------------------------------------------------------
void SubRow
		(
		int		dest_row , 
		int		src_row
		)
{
	cb_assert( 0 <= dest_row && dest_row < m_num , L"out of range." );
	cb_assert( 0 <= src_row && src_row < m_num , L"out of range." );
	int		col;
	for( col = 0 ; col < m_num ; col++ )
	{
		Elem( dest_row , col ) -= Elem( src_row , col );
	}
}
//=================================================================================================
//!	MulRow : row[row] *= val
//!	@param	row		row that is modified
//!	@param	val		scale factor
//-------------------------------------------------------------------------------------------------
void MulRow
		(
		int		row , 
		float	val
		)
{
	cb_assert( 0 <= row && row < m_num , L"out of range." );
	int		col;
	for( col = 0 ; col < m_num ; col++ )
	{
		Elem( row , col ) *= val;
	}
}
//=================================================================================================
//!	DivRow : row[row] /= val
//!	@param	row		row that is modified
//!	@param	val		divisor (no zero check is done here)
//-------------------------------------------------------------------------------------------------
void DivRow
		(
		int		row , 
		float	val
		)
{
	cb_assert( 0 <= row && row < m_num , L"out of range." );
	int		col;
	for( col = 0 ; col < m_num ; col++ )
	{
		Elem( row , col ) /= val;
	}
}
//=================================================================================================
//!	AddMulRow : row[dest_row] += row[src_row] * mul
//!	@param	dest_row	row that is modified
//!	@param	src_row		row that is scaled and added
//!	@param	mul			scale factor
//-------------------------------------------------------------------------------------------------
void AddMulRow
		(
		int		dest_row , 
		int		src_row , 
		float	mul
		)
{
	cb_assert( 0 <= dest_row && dest_row < m_num , L"out of range." );
	cb_assert( 0 <= src_row && src_row < m_num , L"out of range." );
	int		col;
	for( col = 0 ; col < m_num ; col++ )
	{
		Elem( dest_row , col ) += ( Elem( src_row , col ) * mul );
	}
}
//=================================================================================================
//!	SubMulRow : row[dest_row] -= row[src_row] * mul
//!	@param	dest_row	row that is modified
//!	@param	src_row		row that is scaled and subtracted
//!	@param	mul			scale factor
//-------------------------------------------------------------------------------------------------
void SubMulRow
		(
		int		dest_row , 
		int		src_row , 
		float	mul
		)
{
	cb_assert( 0 <= dest_row && dest_row < m_num , L"out of range." );
	cb_assert( 0 <= src_row && src_row < m_num , L"out of range." );
	int		col;
	for( col = 0 ; col < m_num ; col++ )
	{
		Elem( dest_row , col ) -= ( Elem( src_row , col ) * mul );
	}
}
//=================================================================================================
//!	AddDivRow : row[dest_row] += row[src_row] / div
//!	@param	dest_row	row that is modified
//!	@param	src_row		row that is divided and added
//!	@param	div			divisor (no zero check is done here)
//-------------------------------------------------------------------------------------------------
void AddDivRow
		(
		int		dest_row , 
		int		src_row , 
		float	div
		)
{
	cb_assert( 0 <= dest_row && dest_row < m_num , L"out of range." );
	cb_assert( 0 <= src_row && src_row < m_num , L"out of range." );
	int		col;
	for( col = 0 ; col < m_num ; col++ )
	{
		Elem( dest_row , col ) += ( Elem( src_row , col ) / div );
	}
}
//=================================================================================================
//!	SubDivRow : row[dest_row] -= row[src_row] / div
//!	@param	dest_row	row that is modified
//!	@param	src_row		row that is divided and subtracted
//!	@param	div			divisor (no zero check is done here)
//-------------------------------------------------------------------------------------------------
void SubDivRow
		(
		int		dest_row , 
		int		src_row , 
		float	div
		)
{
	cb_assert( 0 <= dest_row && dest_row < m_num , L"out of range." );
	cb_assert( 0 <= src_row && src_row < m_num , L"out of range." );
	int		col;
	for( col = 0 ; col < m_num ; col++ )
	{
		Elem( dest_row , col ) -= ( Elem( src_row , col ) / div );
	}
}
//=================================================================================================
//!	SwapRow : exchange the contents of two rows
//!	@param	row_a	first row
//!	@param	row_b	second row
//-------------------------------------------------------------------------------------------------
void SwapRow
		(
		int		row_a , 
		int		row_b
		)
{
	cb_assert( 0 <= row_a && row_a < m_num , L"out of range." );
	cb_assert( 0 <= row_b && row_b < m_num , L"out of range." );
	if( row_a == row_b )
		return;
	int		col;
	for( col = 0 ; col < m_num ; col++ )
	{
		float	t = Elem( row_a , col );
		Elem( row_a , col ) = Elem( row_b , col );
		Elem( row_b , col ) = t;
	}
}
//=================================================================================================
//!	inverse matrix (Gauss-Jordan elimination with partial pivoting)
//!	@param	mt		[out] receives the inverse; must not be null. May point to this.
//!	@retval	true	inverse computed and stored in *mt
//!	@retval	false	a column had no non-zero pivot candidate (matrix is singular);
//!					*mt is left unmodified
//!	@note	singularity is only detected for exactly-zero pivots; nearly singular
//!			matrices yield an inverse with very large elements.
//-------------------------------------------------------------------------------------------------
bool GetInverse
		(
		fSquareMatrixMulti	*mt
		)
{
	cb_assert( mt != 0 , L"invalid pointer." );

	fSquareMatrixMulti	tm( *this );		// working copy, reduced to the unit matrix
	fSquareMatrixMulti	sm( m_num );		// receives the same row operations -> inverse
	sm.SetUnit();
	
	int		row , col;
	for( col = 0 ; col < m_num ; col++ )
	{
		// partial pivoting: use the row (at or below col) with the largest |value| in this
		// column. Without it, an invertible matrix such as [[0,1],[1,0]] would be rejected
		// and small pivots would amplify rounding error.
		int		pivot	= col;
		float	best	= AbsValue( tm.Elem( col , col ) );
		for( row = col + 1 ; row < m_num ; row++ )
		{
			float	a = AbsValue( tm.Elem( row , col ) );
			if( a > best )
			{
				best	= a;
				pivot	= row;
			}
		}
		if( best == 0.0f )
			return false;

		// a row swap is an elementary row operation, so applying it to both sides
		// keeps sm equal to (operations applied so far) x unit.
		tm.SwapRow( col , pivot );
		sm.SwapRow( col , pivot );

		float	e = tm.Elem( col , col );
		tm.DivRow( col , e );
		sm.DivRow( col , e );
					
		// eliminate this column from every other row
		for( row = 0 ; row < m_num ; row++ )
		{
			if( row == col )
				continue;
			float	d = tm.Elem( row , col );
			if( d != 0.0f )
			{
				tm.SubMulRow( row , col , d );
				sm.SubMulRow( row , col , d );
			}
		}
	}
	*mt	= sm;
	return true;
}
};

///////////////////////////////////////////////////////////////////////////////////////////////////
// global variable define

///////////////////////////////////////////////////////////////////////////////////////////////////
// global functions define

};	//namespace

//using namespace icubic;		

#pragma pack( pop )			//release align
