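// CoCreateInstance主要的工作是读注册表,然后CoLoadLibrary,调用DllGetClassObject,最后CreateInstance创建对象。
// (CoCreateInstance essentially looks the class up in the registry, loads the server with CoLoadLibrary,
//  calls its DllGetClassObject, and finally calls IClassFactory::CreateInstance to create the object.)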
#include<iostream>
using namespace std;
#include "ComFree_i.h"
#include "ComFree_i.c"
#include <Dbghelp.h>
#pragma comment(lib,"ole32.lib")
#pragma comment(lib,"Dbghelp.lib")
// Function-pointer type for the "DllGetClassObject" export that this file
// looks up with GetProcAddress before asking it for an IClassFactory.
typedef HRESULT (__stdcall *pDllGetClassObject)(REFCLSID clsid, REFIID iid, LPVOID *ppvOut);

// Address bound that HookOle compares imported function addresses against
// before probing them as debug thunks. Starts out NULL (static storage);
// NOTE(review): judging by its name it should hold the highest application
// address — confirm it is assigned before it is relied on.
static PVOID sm_pvMaxAppAddr = NULL;

// First byte of an x86 "PUSH imm32" instruction; HookOle treats an import
// whose code starts with this byte as a wrapper (debug thunk).
static const BYTE cPushOpCode = 0x68;
//普通COM
void TestCom1()
{
cout<<"TestCom1"<<endl;
//声明HRESULT和IRCom接口指针
IRCom* iCom = NULL;
HRESULT hr = CoInitialize(NULL); //初始化COM
//使用SUCCEEDED宏并检查我们是否能得到一个接口指针
if(SUCCEEDED(hr))
{
hr = CoCreateInstance(CLSID_RCom,
NULL,
CLSCTX_INPROC_SERVER,
IID_IRCom,
(void **)&iCom);
//如果成功,则调用Minus方法,否则显示相应的出错信息
if(SUCCEEDED(hr))
{
long ret;
iCom->Minus(8,9,&ret);
cout << "The answer for 8-9 is:" << ret << endl;
iCom->Release();
}
else
{
cout << "CoCreateInstance Failed." << endl;
}
}
CoUninitialize();//释放COM
}
//无注册表COM
void TestCom2()
{
cout<<"TestCom2"<<endl;
//声明HRESULT和IRCom接口指针
IRCom* iCom = NULL;
char *pDLLName = "ComFree.dll";
wchar_t szDLLPath[MAX_PATH];
MultiByteToWideChar(CP_ACP, NULL, pDLLName, strlen(pDLLName)+1, szDLLPath, MAX_PATH);
HMODULE hModule = CoLoadLibrary(szDLLPath, TRUE);
if(hModule)
{
pDllGetClassObject pfGetClassObj = (pDllGetClassObject)GetProcAddress(hModule, "DllGetClassObject");
if (pfGetClassObj)
{
IClassFactory *pFac;
HRESULT hr = pfGetClassObj(CLSID_RCom,IID_IClassFactory,(void **)&pFac);
if (SUCCEEDED(hr))
{
hr = pFac->CreateInstance(NULL, IID_IRCom, (void **)&iCom);
pFac->Release();
long ret;
iCom->Minus(8,9,&ret);
cout << "The answer for 8-9 is:" << ret << endl;
iCom->Release();
}
else
{
cout<<hr<<endl;
cout << "DllGetClassObject Failed." << endl;
}
}
}
else
{
cout << "CoLoadLibrary Failed." << endl;
}
}
// ---- HOOK: make ordinary (registry-based) COM behave as registry-free COM ----

// strcmp-style comparison that ignores ASCII letter case.
// Both strings are compared byte by byte as unsigned char, with 'A'..'Z'
// folded to 'a'..'z'. Returns 0 when equal, otherwise the difference of the
// first pair of folded bytes that differ (negative if dst sorts first).
int CompareStringNoCase(const char* dst, const char* src)
{
	const unsigned char* left = reinterpret_cast<const unsigned char*>(dst);
	const unsigned char* right = reinterpret_cast<const unsigned char*>(src);
	for (;;)
	{
		int a = *left++;
		int b = *right++;
		if (a >= 'A' && a <= 'Z')
			a += 'a' - 'A';
		if (b >= 'A' && b <= 'Z')
			b += 'a' - 'A';
		// Stop at the end of dst or at the first mismatch.
		if (a == 0 || a != b)
			return a - b;
	}
}
// Replacement for CoCreateInstance that HookOle writes into this program's
// import table. It bypasses the registry: ComFree.dll is loaded with
// CoLoadLibrary, its DllGetClassObject export supplies the class factory for
// rclsid, and that factory creates the object.
//
// rclsid       - class requested by the caller.
// pUnkOuter    - controlling IUnknown for aggregation; forwarded to the factory.
// dwClsContext - ignored; the server is always ComFree.dll loaded in-process.
// riid         - interface requested by the caller.
// ppv          - receives the interface pointer; set to NULL up front so it is
//                never left dangling on failure.
// Returns the factory's HRESULT, or a failure code if the DLL or its export
// could not be found. Unlike before, it never returns S_OK with *ppv unset.
HRESULT WINAPI HookCoCreateInstance(REFCLSID rclsid,LPUNKNOWN pUnkOuter,DWORD dwClsContext,REFIID riid,LPVOID * ppv)
{
	cout<<"HookCoCreateInstance"<<endl;
	(void)dwClsContext;

	if (ppv == NULL)
	{
		return E_POINTER;
	}
	*ppv = NULL;

	// CoLoadLibrary takes a non-const wide string, so keep the name in a buffer.
	wchar_t szDLLPath[] = L"ComFree.dll";
	HMODULE hModule = CoLoadLibrary(szDLLPath, TRUE);
	if (hModule == NULL)
	{
		cout << "CoLoadLibrary Failed." << endl;
		return CO_E_DLLNOTFOUND;
	}

	pDllGetClassObject pfGetClassObj = (pDllGetClassObject)GetProcAddress(hModule, "DllGetClassObject");
	if (pfGetClassObj == NULL)
	{
		cout << "GetProcAddress(DllGetClassObject) Failed." << endl;
		return CO_E_ERRORINDLL;
	}

	IClassFactory *pFac = NULL;
	HRESULT hr = pfGetClassObj(rclsid,IID_IClassFactory,(void **)&pFac);
	if (FAILED(hr))
	{
		cout<<hr<<endl;
		cout << "DllGetClassObject Failed." << endl;
		return hr;
	}

	// Forward pUnkOuter so aggregation requests are honoured (or rejected by
	// the factory) rather than silently producing a non-aggregated object.
	hr = pFac->CreateInstance(pUnkOuter, riid, ppv);
	pFac->Release();
	return hr;
}
// Installs an import-address-table (IAT) hook in the current executable so
// that its calls to Ole32!CoCreateInstance go to HookCoCreateInstance.
//
// Walks the executable's import descriptors for Ole32.dll and looks for a slot
// that holds the real CoCreateInstance address. It also matches a wrapper whose
// code begins with PUSH <address> (a debug thunk); in that case the pushed
// address is patched instead. The first match is overwritten.
// Returns TRUE after one slot has been patched; FALSE if CoCreateInstance
// cannot be resolved, no matching import exists, the memory cannot be made
// writable, or an exception occurs while walking the table.
BOOL HookOle()
{
	const char* pszCalleeModName = "Ole32.dll";
	const char* pszFuncName = "CoCreateInstance";

	// Address of the real CoCreateInstance: the value we expect to find in our IAT.
	HMODULE hmodCallee = ::GetModuleHandleA(pszCalleeModName);
	if (hmodCallee == NULL)
	{
		hmodCallee = ::LoadLibraryA(pszCalleeModName);
	}
	PROC pfnCurrent = NULL;
	if (hmodCallee != NULL)
	{
		pfnCurrent = ::GetProcAddress(hmodCallee, pszFuncName);
	}
	// Without a target address, the debug-thunk probe below could "match" a
	// NULL read out of arbitrary code bytes and patch unrelated memory.
	if (pfnCurrent == NULL)
	{
		return FALSE;
	}

	// sm_pvMaxAppAddr is never assigned elsewhere in this file. While it is
	// NULL, every non-matching import would be probed as a debug thunk.
	// Give it the highest application address so only addresses above the
	// user-mode range are probed.
	if (sm_pvMaxAppAddr == NULL)
	{
		SYSTEM_INFO si;
		::GetSystemInfo(&si);
		sm_pvMaxAppAddr = si.lpMaximumApplicationAddress;
	}

	HMODULE hmodCaller = ::GetModuleHandle(NULL);
	PROC pfnNew = reinterpret_cast<PROC>(HookCoCreateInstance);
	try
	{
		ULONG ulSize;
		// Import section of our own executable.
		PIMAGE_IMPORT_DESCRIPTOR pImportDesc =
			(PIMAGE_IMPORT_DESCRIPTOR)::ImageDirectoryEntryToData(
				hmodCaller,
				TRUE,
				IMAGE_DIRECTORY_ENTRY_IMPORT,
				&ulSize
			);
		if (pImportDesc == NULL)
			return FALSE;

		// Keep scanning descriptors: the callee may appear in more than one.
		for (;;)
		{
			// Advance to the next descriptor that names the callee module.
			while (pImportDesc->Name != 0)
			{
				LPCSTR lpszName = (LPCSTR)((PBYTE)hmodCaller + pImportDesc->Name);
				if (CompareStringNoCase(lpszName, pszCalleeModName) == 0)
					break;
				pImportDesc++;
			}
			if (pImportDesc->Name == 0)
				return FALSE; // end of table: callee not (or no longer) imported

			PIMAGE_THUNK_DATA pThunk =
				(PIMAGE_THUNK_DATA)((PBYTE)hmodCaller + (UINT_PTR)pImportDesc->FirstThunk);
			for (; pThunk->u1.Function != 0; pThunk++)
			{
				PROC* ppfn = (PROC*)&pThunk->u1.Function;
				BOOL bFound = (*ppfn == pfnCurrent);
				if (!bFound && (UINT_PTR)*ppfn > (UINT_PTR)sm_pvMaxAppAddr)
				{
					PBYTE pbInFunc = (PBYTE)*ppfn;
					// Is this a wrapper (debug thunk) represented by a PUSH instruction?
					if (pbInFunc[0] == cPushOpCode)
					{
						ppfn = (PROC*)&pbInFunc[1];
						bFound = (*ppfn == pfnCurrent);
					}
				}
				if (!bFound)
					continue;

				MEMORY_BASIC_INFORMATION mbi;
				if (::VirtualQuery(ppfn, &mbi, sizeof(mbi)) == 0)
					return FALSE;
				// Make the region writable just long enough to patch the slot.
				DWORD dwOldProtect;
				if (!::VirtualProtect(mbi.BaseAddress, mbi.RegionSize,
					PAGE_READWRITE, &dwOldProtect))
				{
					return FALSE;
				}
				*ppfn = pfnNew;
				// Restore the original protection.
				DWORD dwIgnored;
				::VirtualProtect(mbi.BaseAddress, mbi.RegionSize,
					dwOldProtect, &dwIgnored);
				return TRUE;
			}
			pImportDesc++;
		}
	}
	catch(...)
	{
		// Best effort: any exception while walking the import table just means
		// "not hooked". (Access violations only reach here under MSVC /EHa.)
	}
	return FALSE;
}
int main(int argc, char* argv[])
{
if(HookOle())//如果把这行注释掉,TestCom1将失败
{
TestCom1();
}
TestCom2();
return 0;
}