// $Id: pe.cpp,v 1.1 2003/01/10 04:42:50 yuya Exp $

////////////////////////////////////////////////////////////////////////////////

#include "exerb.h"

////////////////////////////////////////////////////////////////////////////////

static PIMAGE_IMPORT_DESCRIPTOR ExGetFirstImportDescriptor(DWORD dwBaseAddress, DWORD *pdwImportTableDelta);
static bool ExGetImportTable(PIMAGE_NT_HEADERS pNtHeader, DWORD *pdwAddress, DWORD *pdwDelta);
static void ExReplaceImportDllName(DWORD dwBaseAddress, DWORD dwImportTableDelta, PDWORD pdwNamePoolAddress, PDWORD pdwNamePoolSize, PIMAGE_IMPORT_DESCRIPTOR pFirstDescriptor, char* pszSrc, char* pszDest);
static void ExReplaceImportFunctionName(DWORD dwOffsetOfName, PIMAGE_IMPORT_DESCRIPTOR pFirstDescriptor, const char* pszDllName, const char* pszSource, const char* pszDestination);
static PIMAGE_NT_HEADERS ExGetNtHeader(PIMAGE_DOS_HEADER pDosHeader);
static bool  ExGetSectionUnusedArea(PIMAGE_NT_HEADERS pNtHeader, char *pszName, PDWORD pdwAddress, PDWORD pdwSize);
static PIMAGE_SECTION_HEADER ExGetEnclosingSectionHeader(PIMAGE_NT_HEADERS pNtHeader, DWORD dwRVA);
static PIMAGE_SECTION_HEADER ExFindSection(PIMAGE_NT_HEADERS pNtHeader, char *pszSectionName);

////////////////////////////////////////////////////////////////////////////////

extern char g_szPhiSoFileName[MAX_PATH];

////////////////////////////////////////////////////////////////////////////////

bool
ExReplaceImportTable(void *pvBuffer)
{
	const DWORD dwBaseAddress = (DWORD)pvBuffer;

	DWORD dwImportTableDelta = 0;
	const PIMAGE_IMPORT_DESCRIPTOR pDescriptor = ::ExGetFirstImportDescriptor(dwBaseAddress, &dwImportTableDelta);
	if ( !pDescriptor ) return false;

	DWORD dwNamePoolAddress = 0;
	DWORD dwNamePoolSize    = 0;
	PIMAGE_NT_HEADERS pNtHeader = ::ExGetNtHeader((PIMAGE_DOS_HEADER)dwBaseAddress);

	::ExGetSectionUnusedArea(pNtHeader, ".idata", &dwNamePoolAddress, &dwNamePoolSize);

	char szSelfFileName[MAX_PATH] = "";
	::ExGetSelfFileName(szSelfFileName, sizeof(szSelfFileName));

#ifdef RUBY18
	::ExReplaceImportDllName(dwBaseAddress, dwImportTableDelta, &dwNamePoolAddress, &dwNamePoolSize, pDescriptor, "msvcrt-ruby17.dll",  szSelfFileName);
	::ExReplaceImportDllName(dwBaseAddress, dwImportTableDelta, &dwNamePoolAddress, &dwNamePoolSize, pDescriptor, "msvcrt-ruby18.dll",  szSelfFileName);
	::ExReplaceImportDllName(dwBaseAddress, dwImportTableDelta, &dwNamePoolAddress, &dwNamePoolSize, pDescriptor, "cygwin-ruby17.dll",  szSelfFileName);
	::ExReplaceImportDllName(dwBaseAddress, dwImportTableDelta, &dwNamePoolAddress, &dwNamePoolSize, pDescriptor, "cygwin-ruby18.dll",  szSelfFileName);
#else
	::ExReplaceImportDllName(dwBaseAddress, dwImportTableDelta, &dwNamePoolAddress, &dwNamePoolSize, pDescriptor, "mswin32-ruby16.dll", szSelfFileName);
	::ExReplaceImportDllName(dwBaseAddress, dwImportTableDelta, &dwNamePoolAddress, &dwNamePoolSize, pDescriptor, "mingw32-ruby16.dll", szSelfFileName);
	::ExReplaceImportDllName(dwBaseAddress, dwImportTableDelta, &dwNamePoolAddress, &dwNamePoolSize, pDescriptor, "cygwin-ruby16.dll",  szSelfFileName);
#endif
	::ExReplaceImportDllName(dwBaseAddress, dwImportTableDelta, &dwNamePoolAddress, &dwNamePoolSize, pDescriptor, "ruby.exe",           szSelfFileName);

	if ( ::strlen(g_szPhiSoFileName) > 0 ) {
		::ExReplaceImportDllName(dwBaseAddress, dwImportTableDelta, &dwNamePoolAddress, &dwNamePoolSize, pDescriptor, "phi.so", g_szPhiSoFileName);
	}

	const DWORD dwOffsetOfName = dwBaseAddress - dwImportTableDelta;
	::ExReplaceImportFunctionName(dwOffsetOfName, pDescriptor, szSelfFileName, "rb_require",   "ex_require");
	::ExReplaceImportFunctionName(dwOffsetOfName, pDescriptor, szSelfFileName, "rb_f_require", "ex_f_require");

	return true;
}

////////////////////////////////////////////////////////////////////////////////

// Returns a pointer to the first IMAGE_IMPORT_DESCRIPTOR of the image whose
// file-layout bytes start at dwBaseAddress, or NULL if no section encloses the
// import directory. *pdwImportTableDelta receives the enclosing section's
// (VirtualAddress - PointerToRawData), used to turn RVAs into file offsets.
static PIMAGE_IMPORT_DESCRIPTOR
ExGetFirstImportDescriptor(DWORD dwBaseAddress, DWORD *pdwImportTableDelta)
{
	const PIMAGE_NT_HEADERS pHeaders = ::ExGetNtHeader((PIMAGE_DOS_HEADER)dwBaseAddress);

	DWORD dwTableOffset = 0;
	if ( !::ExGetImportTable(pHeaders, &dwTableOffset, pdwImportTableDelta) ) {
		return NULL;
	}

	return (PIMAGE_IMPORT_DESCRIPTOR)(dwBaseAddress + dwTableOffset);
}

// Locates the import directory in file layout.
// On success, *pdwDelta = enclosing section's VirtualAddress - PointerToRawData and
// *pdwAddress = the import directory's file offset (RVA - delta); returns true.
// On failure (no section encloses the directory RVA) both outputs are 0; returns false.
static bool
ExGetImportTable(PIMAGE_NT_HEADERS pNtHeader, DWORD *pdwAddress, DWORD *pdwDelta)
{
	const DWORD dwDirectoryRVA = pNtHeader->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT].VirtualAddress;
	const PIMAGE_SECTION_HEADER pOwner = ::ExGetEnclosingSectionHeader(pNtHeader, dwDirectoryRVA);

	if ( pOwner == NULL ) {
		*pdwDelta   = 0;
		*pdwAddress = 0;
		return false;
	}

	const DWORD dwSectionDelta = pOwner->VirtualAddress - pOwner->PointerToRawData;
	*pdwDelta   = dwSectionDelta;
	*pdwAddress = dwDirectoryRVA - dwSectionDelta;
	return true;
}

// Replaces every import descriptor's DLL name that matches pszSrc
// (case-insensitively) with pszDest.
// - If pszDest fits in the existing name's storage, it is copied in place.
// - Otherwise pszDest (with its NUL terminator) is stored in the name pool, which is
//   consumed downward from *pdwNamePoolAddress (a buffer offset), and the
//   descriptor's Name RVA is redirected to it. The pool address/size are updated.
// - If neither fits, a Ruby LoadError is raised.
// NOTE(review): the pool offset is converted to an RVA with dwImportTableDelta, which
// assumes the pool lies in the same section as the import table — confirm.
static void
ExReplaceImportDllName(DWORD dwBaseAddress, DWORD dwImportTableDelta, PDWORD pdwNamePoolAddress, PDWORD pdwNamePoolSize, PIMAGE_IMPORT_DESCRIPTOR pFirstDescriptor, char* pszSrc, char* pszDest)
{
	DEBUGMSG2("ExReplaceImportDllName(..., '%s', '%s')\n", pszSrc, pszDest);

	const DWORD dwOffsetOfName = dwBaseAddress - dwImportTableDelta;
	const DWORD dwDestLength   = ::strlen(pszDest);
	const DWORD dwDestSize     = dwDestLength + 1; // including the terminating NUL

	for ( PIMAGE_IMPORT_DESCRIPTOR pDescriptor = pFirstDescriptor; pDescriptor->Name; pDescriptor++ ) {
		char *pszName = (char*)(dwOffsetOfName + pDescriptor->Name);

		if ( ::stricmp(pszName, pszSrc) != 0 ) {
			continue;
		}

		if ( dwDestLength <= ::strlen(pszName) ) {
			::strcpy(pszName, pszDest);
		} else if ( dwDestSize <= *pdwNamePoolSize ) {
			const DWORD dwAddress = *pdwNamePoolAddress - dwDestSize;
			pDescriptor->Name = dwAddress + dwImportTableDelta;

			// Copy the terminator explicitly: the pool's padding bytes are not
			// guaranteed to be zero, and the reservation already includes them.
			::memcpy((void*)(dwBaseAddress + dwAddress), pszDest, dwDestSize);

			*pdwNamePoolAddress = dwAddress;
			*pdwNamePoolSize   -= dwDestSize;
		} else {
			::rb_raise(rb_eLoadError, "Fail to modify the import table. exe/dll file name is too long.");
		}
	}
}

// For every import descriptor whose DLL name equals pszDllName (case-insensitively),
// renames each by-name import pszSource to pszDestination in place, in the
// hint/name entries.
// dwOffsetOfName is added to RVAs to obtain buffer addresses (same bias as the
// import table). Ordinal imports are skipped.
// Raises a Ruby LoadError if pszDestination is longer than the name it would
// overwrite, because an in-place copy would spill into the next entry.
static void
ExReplaceImportFunctionName(DWORD dwOffsetOfName, PIMAGE_IMPORT_DESCRIPTOR pFirstDescriptor, const char* pszDllName, const char* pszSource, const char* pszDestination)
{
	const size_t nDestinationLength = ::strlen(pszDestination);

	for ( PIMAGE_IMPORT_DESCRIPTOR pDescriptor = pFirstDescriptor; pDescriptor->Name; pDescriptor++ ) {
		const char *pszName = (const char*)(dwOffsetOfName + pDescriptor->Name);

		if ( ::stricmp(pszName, pszDllName) != 0 ) {
			continue;
		}

		// Walk the name table (Characteristics/OriginalFirstThunk); fall back to
		// FirstThunk when the former is absent. Both are RVAs here.
		DWORD dwThunkRVA = (DWORD)pDescriptor->Characteristics;
		if ( !dwThunkRVA ) {
			dwThunkRVA = (DWORD)pDescriptor->FirstThunk;
			if ( !dwThunkRVA ) {
				continue;
			}
		}

		for ( PIMAGE_THUNK_DATA pThunk = (PIMAGE_THUNK_DATA)(dwThunkRVA + dwOffsetOfName); pThunk->u1.AddressOfData; pThunk++ ) {
			if ( pThunk->u1.Ordinal & IMAGE_ORDINAL_FLAG ) {
				continue; // imported by ordinal: there is no name to rewrite
			}

			const PIMAGE_IMPORT_BY_NAME pImportByName = (PIMAGE_IMPORT_BY_NAME)((DWORD)(pThunk->u1.AddressOfData) + dwOffsetOfName);
			const LPSTR pszFunctionName = (LPSTR)pImportByName->Name;

			if ( ::strcmp(pszFunctionName, pszSource) != 0 ) {
				continue;
			}

			if ( nDestinationLength > ::strlen(pszFunctionName) ) {
				::rb_raise(rb_eLoadError, "Fail to modify the import table. function name is too long.");
			}

			::strcpy(pszFunctionName, pszDestination);
		}
	}
}

// Returns the NT headers of the image whose DOS header is at pDosHeader, found
// e_lfanew bytes past the start of the DOS header. No signature validation is done.
static PIMAGE_NT_HEADERS
ExGetNtHeader(PIMAGE_DOS_HEADER pDosHeader)
{
	const DWORD dwImageStart = (DWORD)pDosHeader;
	const LONG  lNtOffset    = pDosHeader->e_lfanew;
	return (PIMAGE_NT_HEADERS)(dwImageStart + lNtOffset);
}

// Describes the slack at the end of the section named pszName:
// - *pdwAddress = file offset just past the section's raw data (the pool end;
//   callers consume it downward);
// - *pdwSize = SizeOfRawData - Misc.VirtualSize, clamped to 0 when there is no slack.
// Returns false (both outputs 0) if the section does not exist.
static bool
ExGetSectionUnusedArea(PIMAGE_NT_HEADERS pNtHeader, char *pszName, PDWORD pdwAddress, PDWORD pdwSize)
{
	const PIMAGE_SECTION_HEADER pSection = ::ExFindSection(pNtHeader, pszName);

	if ( !pSection ) {
		*pdwAddress = 0;
		*pdwSize    = 0;
		return false;
	}

	const DWORD dwRawSize = pSection->SizeOfRawData;
	DWORD dwUsedSize      = pSection->Misc.VirtualSize;
	if ( dwUsedSize == 0 ) {
		// Some linkers leave VirtualSize at 0; treat the whole raw data as used
		// rather than as free space that would overwrite real section contents.
		dwUsedSize = dwRawSize;
	}

	*pdwAddress = pSection->PointerToRawData + dwRawSize;
	// DWORD subtraction would wrap to a huge "free" size when VirtualSize exceeds
	// SizeOfRawData; in that case there is no slack at all.
	*pdwSize    = (dwUsedSize < dwRawSize) ? (dwRawSize - dwUsedSize) : 0;
	return true;
}

// Returns the section header whose virtual range contains dwRVA, or NULL if none.
// The range size is Misc.VirtualSize, or SizeOfRawData when a linker left
// VirtualSize at 0.
static PIMAGE_SECTION_HEADER
ExGetEnclosingSectionHeader(PIMAGE_NT_HEADERS pNtHeader, DWORD dwRVA)
{
	PIMAGE_SECTION_HEADER pSection = IMAGE_FIRST_SECTION(pNtHeader);

	for ( int i = 0; i < pNtHeader->FileHeader.NumberOfSections; i++, pSection++ ) {
		DWORD dwSize = pSection->Misc.VirtualSize;
		if ( dwSize == 0 ) {
			dwSize = pSection->SizeOfRawData;
		}

		// Offset form avoids wrap-around in VirtualAddress + dwSize.
		if ( (dwRVA >= pSection->VirtualAddress) && ((dwRVA - pSection->VirtualAddress) < dwSize) ) {
			return pSection;
		}
	}

	return NULL;
}

// Returns the first section header whose name matches pszSectionName
// case-insensitively, or NULL if none does. The comparison is bounded by the
// fixed width of the header's Name field (IMAGE_SIZEOF_SHORT_NAME bytes).
static PIMAGE_SECTION_HEADER
ExFindSection(PIMAGE_NT_HEADERS pNtHeader, char *pszSectionName)
{
	const WORD wSectionCount = pNtHeader->FileHeader.NumberOfSections;
	PIMAGE_SECTION_HEADER pCandidate = IMAGE_FIRST_SECTION(pNtHeader);

	for ( WORD w = 0; w < wSectionCount; ++w, ++pCandidate ) {
		if ( ::strnicmp(pszSectionName, (char*)pCandidate->Name, IMAGE_SIZEOF_SHORT_NAME) == 0 ) {
			return pCandidate;
		}
	}

	return NULL;
}

////////////////////////////////////////////////////////////////////////////////
