Reorganized mod code handle code

This commit is contained in:
Mr-Wiseguy 2024-08-29 23:47:54 -04:00
parent e7afded4eb
commit 3dc622d895
2 changed files with 193 additions and 138 deletions

View file

@ -12,6 +12,7 @@
#include <unordered_map>
#include <array>
#include <cstddef>
#include <variant>
#define MINIZ_NO_DEFLATE_APIS
#define MINIZ_NO_ARCHIVE_WRITING_APIS
@ -21,6 +22,10 @@
#include "librecomp/recomp.h"
#include "librecomp/sections.h"
namespace N64Recomp {
class Context;
};
namespace recomp {
namespace mods {
enum class ModOpenError {
@ -156,6 +161,81 @@ namespace recomp {
std::unordered_map<recomp_func_t*, PatchData> patched_funcs;
std::unordered_map<std::string, size_t> loaded_mods_by_id;
};
using ModFunction = std::variant<recomp_func_t*>;
class ModCodeHandle {
public:
virtual ~ModCodeHandle() {}
virtual bool good() = 0;
virtual void set_imported_function_pointer(size_t import_index, recomp_func_t* ptr) = 0;
virtual void set_reference_symbol_pointer(size_t symbol_index, recomp_func_t* ptr) = 0;
virtual void set_event_index(size_t local_event_index, uint32_t global_event_index) = 0;
virtual void set_recomp_trigger_event_pointer(void (*ptr)(uint8_t* rdram, recomp_context* ctx, uint32_t index)) = 0;
virtual void set_get_function_pointer(recomp_func_t* (*ptr)(int32_t)) = 0;
virtual void set_reference_section_addresses_pointer(int32_t* ptr) = 0;
virtual void set_local_section_address(size_t section_index, int32_t address) = 0;
virtual ModFunction get_function_handle(size_t func_index) = 0;
};
struct ModHandle {
ModManifest manifest;
std::unique_ptr<ModCodeHandle> code_handle;
std::unique_ptr<N64Recomp::Context> recompiler_context;
std::vector<uint32_t> section_load_addresses;
ModHandle(ModManifest&& manifest);
ModHandle(const ModHandle& rhs) = delete;
ModHandle& operator=(const ModHandle& rhs) = delete;
ModHandle(ModHandle&& rhs);
ModHandle& operator=(ModHandle&& rhs);
~ModHandle();
};
class DynamicLibrary;
class NativeCodeHandle : public ModCodeHandle {
public:
NativeCodeHandle(const std::filesystem::path& dll_path, const N64Recomp::Context& context);
~NativeCodeHandle() = default;
bool good() final;
void set_imported_function_pointer(size_t import_index, recomp_func_t* ptr) final {
imported_funcs[import_index] = ptr;
}
void set_reference_symbol_pointer(size_t symbol_index, recomp_func_t* ptr) final {
reference_symbol_funcs[symbol_index] = ptr;
};
void set_event_index(size_t local_event_index, uint32_t global_event_index) final {
event_indices[local_event_index] = global_event_index;
};
void set_recomp_trigger_event_pointer(void (*ptr)(uint8_t* rdram, recomp_context* ctx, uint32_t index)) final {
*recomp_trigger_event = ptr;
};
void set_get_function_pointer(recomp_func_t* (*ptr)(int32_t)) final {
*get_function = ptr;
};
void set_reference_section_addresses_pointer(int32_t* ptr) final {
*reference_section_addresses = ptr;
};
void set_local_section_address(size_t section_index, int32_t address) final {
section_addresses[section_index] = address;
};
ModFunction get_function_handle(size_t func_index) final {
return ModFunction{ functions[func_index] };
}
private:
void set_bad();
bool is_good;
std::unique_ptr<DynamicLibrary> dynamic_lib;
std::vector<recomp_func_t*> functions;
recomp_func_t** imported_funcs;
recomp_func_t** reference_symbol_funcs;
uint32_t* event_indices;
void (**recomp_trigger_event)(uint8_t* rdram, recomp_context* ctx, uint32_t index);
recomp_func_t* (**get_function)(int32_t vram);
int32_t** reference_section_addresses;
int32_t* section_addresses;
};
}
};

View file

@ -1,6 +1,5 @@
#include <span>
#include <fstream>
#include <variant>
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
@ -15,140 +14,116 @@
#define PATHFMT "%s"
#endif
recomp::mods::ModHandle::ModHandle(ModManifest&& manifest) :
manifest(std::move(manifest)), code_handle(), recompiler_context{std::make_unique<N64Recomp::Context>()} { }
recomp::mods::ModHandle::ModHandle(ModHandle&& rhs) = default;
recomp::mods::ModHandle& recomp::mods::ModHandle::operator=(ModHandle&& rhs) = default;
recomp::mods::ModHandle::~ModHandle() = default;
template<class... Ts>
struct overloaded : Ts... { using Ts::operator()...; };
template<class... Ts>
overloaded(Ts...) -> overloaded<Ts...>;
namespace recomp {
namespace mods {
using ModFunction = std::variant<recomp_func_t*>;
#if defined(_WIN32)
# define WIN32_LEAN_AND_MEAN
# include "Windows.h"
class ModCodeHandle {
public:
virtual ~ModCodeHandle() {}
virtual bool good() = 0;
virtual void set_imported_function_pointer(size_t import_index, recomp_func_t* ptr) = 0;
virtual void set_reference_symbol_pointer(size_t symbol_index, recomp_func_t* ptr) = 0;
virtual void set_event_index(size_t local_event_index, uint32_t global_event_index) = 0;
virtual void set_recomp_trigger_event_pointer(void (*ptr)(uint8_t* rdram, recomp_context* ctx, uint32_t index)) = 0;
virtual void set_get_function_pointer(recomp_func_t* (*ptr)(int32_t)) = 0;
virtual void set_reference_section_addresses_pointer(int32_t* ptr) = 0;
virtual void set_local_section_address(size_t section_index, int32_t address) = 0;
virtual ModFunction get_function_handle(size_t func_index) = 0;
class recomp::mods::DynamicLibrary {
public:
DynamicLibrary() = default;
DynamicLibrary(const std::filesystem::path& path) {
mod_dll = LoadLibraryW(path.c_str());
}
~DynamicLibrary() {
unload();
}
DynamicLibrary(const DynamicLibrary&) = delete;
DynamicLibrary& operator=(const DynamicLibrary&) = delete;
DynamicLibrary(DynamicLibrary&&) = delete;
DynamicLibrary& operator=(DynamicLibrary&&) = delete;
void unload() {
if (mod_dll != nullptr) {
FreeLibrary(mod_dll);
}
mod_dll = nullptr;
}
bool good() {
return mod_dll != nullptr;
}
template <typename T>
bool get_dll_symbol(T& out, const char* name) const {
out = (T)GetProcAddress(mod_dll, name);
if (out == nullptr) {
return false;
}
return true;
};
private:
HMODULE mod_dll;
};
class NativeCodeHandle : public ModCodeHandle {
public:
NativeCodeHandle(const std::filesystem::path& dll_path, const N64Recomp::Context& context) {
// Load the DLL.
mod_dll = LoadLibraryW(dll_path.c_str());
if (mod_dll == nullptr) {
set_bad();
return;
}
// Fill out the list of function pointers.
functions.resize(context.functions.size());
for (size_t i = 0; i < functions.size(); i++) {
std::string func_name = "mod_func_" + std::to_string(i);
functions[i] = (recomp_func_t*)GetProcAddress(mod_dll, func_name.c_str());
if (functions[i] == nullptr) {
set_bad();
return;
}
}
// Get the standard exported symbols.
get_dll_func(imported_funcs, "imported_funcs");
get_dll_func(reference_symbol_funcs, "reference_symbol_funcs");
get_dll_func(event_indices, "event_indices");
get_dll_func(recomp_trigger_event, "recomp_trigger_event");
get_dll_func(get_function, "get_function");
get_dll_func(reference_section_addresses, "reference_section_addresses");
get_dll_func(section_addresses, "section_addresses");
}
~NativeCodeHandle() = default;
bool good() final {
return mod_dll != nullptr;
}
void set_imported_function_pointer(size_t import_index, recomp_func_t* ptr) final {
imported_funcs[import_index] = ptr;
}
void set_reference_symbol_pointer(size_t symbol_index, recomp_func_t* ptr) final {
reference_symbol_funcs[symbol_index] = ptr;
};
void set_event_index(size_t local_event_index, uint32_t global_event_index) final {
event_indices[local_event_index] = global_event_index;
};
void set_recomp_trigger_event_pointer(void (*ptr)(uint8_t* rdram, recomp_context* ctx, uint32_t index)) final {
*recomp_trigger_event = ptr;
};
void set_get_function_pointer(recomp_func_t* (*ptr)(int32_t)) final {
*get_function = ptr;
};
void set_reference_section_addresses_pointer(int32_t* ptr) final {
*reference_section_addresses = ptr;
};
void set_local_section_address(size_t section_index, int32_t address) final {
section_addresses[section_index] = address;
};
ModFunction get_function_handle(size_t func_index) final {
return ModFunction{ functions[func_index] };
}
private:
template <typename T>
void get_dll_func(T& out, const char* name) {
out = (T)GetProcAddress(mod_dll, name);
if (out == nullptr) {
set_bad();
}
};
void set_bad() {
if (mod_dll) {
FreeLibrary(mod_dll);
}
mod_dll = nullptr;
}
HMODULE mod_dll;
std::vector<recomp_func_t*> functions;
recomp_func_t** imported_funcs;
recomp_func_t** reference_symbol_funcs;
uint32_t* event_indices;
void (**recomp_trigger_event)(uint8_t* rdram, recomp_context* ctx, uint32_t index);
recomp_func_t* (**get_function)(int32_t vram);
int32_t** reference_section_addresses;
int32_t* section_addresses;
};
struct ModHandle {
ModManifest manifest;
std::unique_ptr<ModCodeHandle> code_handle;
N64Recomp::Context recompiler_context;
std::vector<uint32_t> section_load_addresses;
ModHandle(ModManifest&& manifest) :
manifest(std::move(manifest)), code_handle(), recompiler_context{} {
}
};
void unprotect(void* target_func, uint64_t* old_flags) {
DWORD old_flags_dword;
BOOL result = VirtualProtect(target_func,
16,
PAGE_READWRITE,
&old_flags_dword);
*old_flags = old_flags_dword;
(void)result;
}
void protect(void* target_func, uint64_t old_flags) {
DWORD dummy_old_flags;
BOOL result = VirtualProtect(target_func,
16,
static_cast<DWORD>(old_flags),
&dummy_old_flags);
(void)result;
}
#else
# error "Mods not implemented yet on this platform"
#endif
recomp::mods::NativeCodeHandle::NativeCodeHandle(const std::filesystem::path& dll_path, const N64Recomp::Context& context) {
// Load the DLL.
dynamic_lib = std::make_unique<DynamicLibrary>(dll_path);
if (!dynamic_lib->good()) {
is_good = false;
return;
}
// Fill out the list of function pointers.
functions.resize(context.functions.size());
for (size_t i = 0; i < functions.size(); i++) {
std::string func_name = "mod_func_" + std::to_string(i);
is_good &= dynamic_lib->get_dll_symbol(functions[i], func_name.c_str());
if (!is_good) {
return;
}
}
// Get the standard exported symbols.
is_good = true;
is_good &= dynamic_lib->get_dll_symbol(imported_funcs, "imported_funcs");
is_good &= dynamic_lib->get_dll_symbol(reference_symbol_funcs, "reference_symbol_funcs");
is_good &= dynamic_lib->get_dll_symbol(event_indices, "event_indices");
is_good &= dynamic_lib->get_dll_symbol(recomp_trigger_event, "recomp_trigger_event");
is_good &= dynamic_lib->get_dll_symbol(get_function, "get_function");
is_good &= dynamic_lib->get_dll_symbol(reference_section_addresses, "reference_section_addresses");
is_good &= dynamic_lib->get_dll_symbol(section_addresses, "section_addresses");
}
void unprotect(void* target_func, DWORD* old_flags) {
BOOL result = VirtualProtect(target_func,
16,
PAGE_READWRITE,
old_flags);
(void)result;
bool recomp::mods::NativeCodeHandle::good() {
return dynamic_lib->good() && is_good;
}
void protect(void* target_func, DWORD old_flags) {
DWORD dummy_old_flags;
BOOL result = VirtualProtect(target_func,
16,
old_flags,
&dummy_old_flags);
(void)result;
void recomp::mods::NativeCodeHandle::set_bad() {
dynamic_lib.reset();
is_good = false;
}
void patch_func(recomp_func_t* target_func, recomp::mods::ModFunction replacement_func) {
@ -162,7 +137,7 @@ void patch_func(recomp_func_t* target_func, recomp::mods::ModFunction replacemen
offset += count;
};
DWORD old_flags;
uint64_t old_flags;
unprotect(target_func_u8, &old_flags);
std::visit(overloaded {
@ -177,7 +152,7 @@ void patch_func(recomp_func_t* target_func, recomp::mods::ModFunction replacemen
}
void unpatch_func(void* target_func, const recomp::mods::PatchData& data) {
DWORD old_flags;
uint64_t old_flags;
unprotect(target_func, &old_flags);
memcpy(target_func, data.replaced_bytes.data(), data.replaced_bytes.size());
protect(target_func, old_flags);
@ -210,17 +185,17 @@ recomp::mods::ModLoadError recomp::mods::ModContext::load_mod(uint8_t* rdram, co
std::span<uint8_t> binary_span {reinterpret_cast<uint8_t*>(binary_data.data()), binary_data.size() };
// Parse the symbol file into the recompiler contexts.
N64Recomp::ModSymbolsError symbol_load_error = N64Recomp::parse_mod_symbols(syms_data, binary_span, section_vrom_map, handle.recompiler_context);
N64Recomp::ModSymbolsError symbol_load_error = N64Recomp::parse_mod_symbols(syms_data, binary_span, section_vrom_map, *handle.recompiler_context);
if (symbol_load_error != N64Recomp::ModSymbolsError::Good) {
return ModLoadError::FailedToLoadSyms;
}
handle.section_load_addresses.resize(handle.recompiler_context.sections.size());
handle.section_load_addresses.resize(handle.recompiler_context->sections.size());
// Copy each section's binary into rdram, leaving room for the section's bss before the next one.
int32_t cur_section_addr = load_address;
for (size_t section_index = 0; section_index < handle.recompiler_context.sections.size(); section_index++) {
const auto& section = handle.recompiler_context.sections[section_index];
for (size_t section_index = 0; section_index < handle.recompiler_context->sections.size(); section_index++) {
const auto& section = handle.recompiler_context->sections[section_index];
for (size_t i = 0; i < section.size; i++) {
MEM_B(i, (gpr)cur_section_addr) = binary_data[section.rom_addr + i];
}
@ -393,7 +368,7 @@ bool dependency_version_met(uint8_t major, uint8_t minor, uint8_t patch, uint8_t
void recomp::mods::ModContext::check_dependencies(recomp::mods::ModHandle& mod, std::vector<std::pair<recomp::mods::ModLoadError, std::string>>& errors) {
errors.clear();
for (N64Recomp::Dependency& cur_dep : mod.recompiler_context.dependencies) {
for (N64Recomp::Dependency& cur_dep : mod.recompiler_context->dependencies) {
// Handle special dependency names.
if (cur_dep.mod_id == N64Recomp::DependencyBaseRecomp || cur_dep.mod_id == N64Recomp::DependencySelf) {
continue;
@ -424,7 +399,7 @@ recomp::mods::ModLoadError recomp::mods::ModContext::load_mod_code(recomp::mods:
// TODO implement LuaJIT recompilation and allow it instead of native code loading via a mod manifest flag.
std::filesystem::path dll_path = mod.manifest.mod_root_path;
dll_path.replace_extension(".dll");
mod.code_handle = std::make_unique<NativeCodeHandle>(dll_path, mod.recompiler_context);
mod.code_handle = std::make_unique<NativeCodeHandle>(dll_path, *mod.recompiler_context);
if (!mod.code_handle->good()) {
mod.code_handle.reset();
error_param = dll_path.string();
@ -440,10 +415,10 @@ recomp::mods::ModLoadError recomp::mods::ModContext::load_mod_code(recomp::mods:
recomp::mods::ModLoadError recomp::mods::ModContext::resolve_dependencies(recomp::mods::ModHandle& mod, std::string& error_param) {
// Reference symbols from the base recomp.
for (size_t reference_sym_index = 0; reference_sym_index < mod.recompiler_context.num_regular_reference_symbols(); reference_sym_index++) {
const N64Recomp::ReferenceSymbol& reference_sym = mod.recompiler_context.get_regular_reference_symbol(reference_sym_index);
uint32_t reference_section_vrom = mod.recompiler_context.get_reference_section_rom(reference_sym.section_index);
uint32_t reference_section_vram = mod.recompiler_context.get_reference_section_vram(reference_sym.section_index);
for (size_t reference_sym_index = 0; reference_sym_index < mod.recompiler_context->num_regular_reference_symbols(); reference_sym_index++) {
const N64Recomp::ReferenceSymbol& reference_sym = mod.recompiler_context->get_regular_reference_symbol(reference_sym_index);
uint32_t reference_section_vrom = mod.recompiler_context->get_reference_section_rom(reference_sym.section_index);
uint32_t reference_section_vram = mod.recompiler_context->get_reference_section_vram(reference_sym.section_index);
uint32_t reference_symbol_vram = reference_section_vram + reference_sym.section_offset;
recomp_func_t* found_func = recomp::overlays::get_func_by_section_ram(reference_section_vrom, reference_symbol_vram);
@ -461,9 +436,9 @@ recomp::mods::ModLoadError recomp::mods::ModContext::resolve_dependencies(recomp
}
// Imported symbols.
for (size_t import_index = 0; import_index < mod.recompiler_context.import_symbols.size(); import_index++) {
const N64Recomp::ImportSymbol& imported_func = mod.recompiler_context.import_symbols[import_index];
const N64Recomp::Dependency& dependency = mod.recompiler_context.dependencies[imported_func.dependency_index];
for (size_t import_index = 0; import_index < mod.recompiler_context->import_symbols.size(); import_index++) {
const N64Recomp::ImportSymbol& imported_func = mod.recompiler_context->import_symbols[import_index];
const N64Recomp::Dependency& dependency = mod.recompiler_context->dependencies[imported_func.dependency_index];
recomp_func_t* found_func = nullptr;
@ -491,7 +466,7 @@ recomp::mods::ModLoadError recomp::mods::ModContext::resolve_dependencies(recomp
}
// Apply all the function replacements in the mod.
for (const auto& replacement : mod.recompiler_context.replacements) {
for (const auto& replacement : mod.recompiler_context->replacements) {
recomp_func_t* to_replace = recomp::overlays::get_func_by_section_ram(replacement.original_section_vrom, replacement.original_vram);
if (to_replace == nullptr) {