From 3dc622d895f02f8283e0f6cf03a9cd9f76e32a81 Mon Sep 17 00:00:00 2001 From: Mr-Wiseguy Date: Thu, 29 Aug 2024 23:47:54 -0400 Subject: [PATCH] Reorganized mod code handle code --- librecomp/include/librecomp/mods.hpp | 80 +++++++++ librecomp/src/mods.cpp | 251 ++++++++++++--------------- 2 files changed, 193 insertions(+), 138 deletions(-) diff --git a/librecomp/include/librecomp/mods.hpp b/librecomp/include/librecomp/mods.hpp index 4a74b61..853cda7 100644 --- a/librecomp/include/librecomp/mods.hpp +++ b/librecomp/include/librecomp/mods.hpp @@ -12,6 +12,7 @@ #include #include #include +#include #define MINIZ_NO_DEFLATE_APIS #define MINIZ_NO_ARCHIVE_WRITING_APIS @@ -21,6 +22,10 @@ #include "librecomp/recomp.h" #include "librecomp/sections.h" +namespace N64Recomp { + class Context; +}; + namespace recomp { namespace mods { enum class ModOpenError { @@ -156,6 +161,81 @@ namespace recomp { std::unordered_map patched_funcs; std::unordered_map loaded_mods_by_id; }; + + using ModFunction = std::variant; + + class ModCodeHandle { + public: + virtual ~ModCodeHandle() {} + virtual bool good() = 0; + virtual void set_imported_function_pointer(size_t import_index, recomp_func_t* ptr) = 0; + virtual void set_reference_symbol_pointer(size_t symbol_index, recomp_func_t* ptr) = 0; + virtual void set_event_index(size_t local_event_index, uint32_t global_event_index) = 0; + virtual void set_recomp_trigger_event_pointer(void (*ptr)(uint8_t* rdram, recomp_context* ctx, uint32_t index)) = 0; + virtual void set_get_function_pointer(recomp_func_t* (*ptr)(int32_t)) = 0; + virtual void set_reference_section_addresses_pointer(int32_t* ptr) = 0; + virtual void set_local_section_address(size_t section_index, int32_t address) = 0; + virtual ModFunction get_function_handle(size_t func_index) = 0; + }; + + struct ModHandle { + ModManifest manifest; + std::unique_ptr code_handle; + std::unique_ptr recompiler_context; + std::vector section_load_addresses; + + ModHandle(ModManifest&& manifest); + ModHandle(const ModHandle& rhs) = delete; + ModHandle& operator=(const ModHandle& rhs) = delete; + ModHandle(ModHandle&& rhs); + ModHandle& operator=(ModHandle&& rhs); + ~ModHandle(); + }; + + class DynamicLibrary; + class NativeCodeHandle : public ModCodeHandle { + public: + NativeCodeHandle(const std::filesystem::path& dll_path, const N64Recomp::Context& context); + ~NativeCodeHandle() = default; + bool good() final; + void set_imported_function_pointer(size_t import_index, recomp_func_t* ptr) final { + imported_funcs[import_index] = ptr; + } + void set_reference_symbol_pointer(size_t symbol_index, recomp_func_t* ptr) final { + reference_symbol_funcs[symbol_index] = ptr; + }; + void set_event_index(size_t local_event_index, uint32_t global_event_index) final { + event_indices[local_event_index] = global_event_index; + }; + void set_recomp_trigger_event_pointer(void (*ptr)(uint8_t* rdram, recomp_context* ctx, uint32_t index)) final { + *recomp_trigger_event = ptr; + }; + void set_get_function_pointer(recomp_func_t* (*ptr)(int32_t)) final { + *get_function = ptr; + }; + void set_reference_section_addresses_pointer(int32_t* ptr) final { + *reference_section_addresses = ptr; + }; + void set_local_section_address(size_t section_index, int32_t address) final { + section_addresses[section_index] = address; + }; + ModFunction get_function_handle(size_t func_index) final { + return ModFunction{ functions[func_index] }; + } + private: + void set_bad(); + bool is_good; + std::unique_ptr dynamic_lib; + std::vector functions; + recomp_func_t** imported_funcs; + recomp_func_t** reference_symbol_funcs; + uint32_t* event_indices; + void (**recomp_trigger_event)(uint8_t* rdram, recomp_context* ctx, uint32_t index); + recomp_func_t* (**get_function)(int32_t vram); + int32_t** reference_section_addresses; + int32_t* section_addresses; + }; + } }; diff --git a/librecomp/src/mods.cpp b/librecomp/src/mods.cpp index ce45354..5728ac9 100644 --- a/librecomp/src/mods.cpp +++ b/librecomp/src/mods.cpp @@ -1,6 +1,5 @@ #include #include -#include #define WIN32_LEAN_AND_MEAN #include @@ -15,140 +14,116 @@ #define PATHFMT "%s" #endif +recomp::mods::ModHandle::ModHandle(ModManifest&& manifest) : + manifest(std::move(manifest)), code_handle(), recompiler_context{std::make_unique()} { } +recomp::mods::ModHandle::ModHandle(ModHandle&& rhs) = default; +recomp::mods::ModHandle& recomp::mods::ModHandle::operator=(ModHandle&& rhs) = default; +recomp::mods::ModHandle::~ModHandle() = default; + template struct overloaded : Ts... { using Ts::operator()...; }; template overloaded(Ts...) -> overloaded; -namespace recomp { - namespace mods { - using ModFunction = std::variant; +#if defined(_WIN32) +# define WIN32_LEAN_AND_MEAN +# include "Windows.h" - class ModCodeHandle { - public: - virtual ~ModCodeHandle() {} - virtual bool good() = 0; - virtual void set_imported_function_pointer(size_t import_index, recomp_func_t* ptr) = 0; - virtual void set_reference_symbol_pointer(size_t symbol_index, recomp_func_t* ptr) = 0; - virtual void set_event_index(size_t local_event_index, uint32_t global_event_index) = 0; - virtual void set_recomp_trigger_event_pointer(void (*ptr)(uint8_t* rdram, recomp_context* ctx, uint32_t index)) = 0; - virtual void set_get_function_pointer(recomp_func_t* (*ptr)(int32_t)) = 0; - virtual void set_reference_section_addresses_pointer(int32_t* ptr) = 0; - virtual void set_local_section_address(size_t section_index, int32_t address) = 0; - virtual ModFunction get_function_handle(size_t func_index) = 0; + class recomp::mods::DynamicLibrary { + public: + DynamicLibrary() = default; + DynamicLibrary(const std::filesystem::path& path) { + mod_dll = LoadLibraryW(path.c_str()); + } + ~DynamicLibrary() { + unload(); + } + DynamicLibrary(const DynamicLibrary&) = delete; + DynamicLibrary& operator=(const DynamicLibrary&) = delete; + DynamicLibrary(DynamicLibrary&&) = delete; + DynamicLibrary& operator=(DynamicLibrary&&) = delete; + + void unload() { + if (mod_dll != nullptr) { + FreeLibrary(mod_dll); + } + mod_dll = nullptr; + } + + bool good() { + return mod_dll != nullptr; + } + + template + bool get_dll_symbol(T& out, const char* name) const { + out = (T)GetProcAddress(mod_dll, name); + if (out == nullptr) { + return false; + } + return true; }; + private: + HMODULE mod_dll; + }; - class NativeCodeHandle : public ModCodeHandle { - public: - NativeCodeHandle(const std::filesystem::path& dll_path, const N64Recomp::Context& context) { - // Load the DLL. - mod_dll = LoadLibraryW(dll_path.c_str()); - if (mod_dll == nullptr) { - set_bad(); - return; - } - - // Fill out the list of function pointers. - functions.resize(context.functions.size()); - for (size_t i = 0; i < functions.size(); i++) { - std::string func_name = "mod_func_" + std::to_string(i); - functions[i] = (recomp_func_t*)GetProcAddress(mod_dll, func_name.c_str()); - if (functions[i] == nullptr) { - set_bad(); - return; - } - } - - // Get the standard exported symbols. - get_dll_func(imported_funcs, "imported_funcs"); - get_dll_func(reference_symbol_funcs, "reference_symbol_funcs"); - get_dll_func(event_indices, "event_indices"); - get_dll_func(recomp_trigger_event, "recomp_trigger_event"); - get_dll_func(get_function, "get_function"); - get_dll_func(reference_section_addresses, "reference_section_addresses"); - get_dll_func(section_addresses, "section_addresses"); - } - ~NativeCodeHandle() = default; - bool good() final { - return mod_dll != nullptr; - } - void set_imported_function_pointer(size_t import_index, recomp_func_t* ptr) final { - imported_funcs[import_index] = ptr; - } - void set_reference_symbol_pointer(size_t symbol_index, recomp_func_t* ptr) final { - reference_symbol_funcs[symbol_index] = ptr; - }; - void set_event_index(size_t local_event_index, uint32_t global_event_index) final { - event_indices[local_event_index] = global_event_index; - }; - void set_recomp_trigger_event_pointer(void (*ptr)(uint8_t* rdram, recomp_context* ctx, uint32_t index)) final { - *recomp_trigger_event = ptr; - }; - void set_get_function_pointer(recomp_func_t* (*ptr)(int32_t)) final { - *get_function = ptr; - }; - void set_reference_section_addresses_pointer(int32_t* ptr) final { - *reference_section_addresses = ptr; - }; - void set_local_section_address(size_t section_index, int32_t address) final { - section_addresses[section_index] = address; - }; - ModFunction get_function_handle(size_t func_index) final { - return ModFunction{ functions[func_index] }; - } - private: - template - void get_dll_func(T& out, const char* name) { - out = (T)GetProcAddress(mod_dll, name); - if (out == nullptr) { - set_bad(); - } - }; - void set_bad() { - if (mod_dll) { - FreeLibrary(mod_dll); - } - mod_dll = nullptr; - } - HMODULE mod_dll; - std::vector functions; - recomp_func_t** imported_funcs; - recomp_func_t** reference_symbol_funcs; - uint32_t* event_indices; - void (**recomp_trigger_event)(uint8_t* rdram, recomp_context* ctx, uint32_t index); - recomp_func_t* (**get_function)(int32_t vram); - int32_t** reference_section_addresses; - int32_t* section_addresses; - }; - - struct ModHandle { - ModManifest manifest; - std::unique_ptr code_handle; - N64Recomp::Context recompiler_context; - std::vector section_load_addresses; - - ModHandle(ModManifest&& manifest) : - manifest(std::move(manifest)), code_handle(), recompiler_context{} { - } - }; + void unprotect(void* target_func, uint64_t* old_flags) { + DWORD old_flags_dword; + BOOL result = VirtualProtect(target_func, + 16, + PAGE_READWRITE, + &old_flags_dword); + *old_flags = old_flags_dword; + (void)result; } + + void protect(void* target_func, uint64_t old_flags) { + DWORD dummy_old_flags; + BOOL result = VirtualProtect(target_func, + 16, + static_cast(old_flags), + &dummy_old_flags); + (void)result; + } +#else +# error "Mods not implemented yet on this platform" +#endif + +recomp::mods::NativeCodeHandle::NativeCodeHandle(const std::filesystem::path& dll_path, const N64Recomp::Context& context) { + // Load the DLL. + dynamic_lib = std::make_unique(dll_path); + if (!dynamic_lib->good()) { + is_good = false; + return; + } + + // Fill out the list of function pointers. + functions.resize(context.functions.size()); + for (size_t i = 0; i < functions.size(); i++) { + std::string func_name = "mod_func_" + std::to_string(i); + is_good &= dynamic_lib->get_dll_symbol(functions[i], func_name.c_str()); + if (!is_good) { + return; + } + } + + // Get the standard exported symbols. + is_good = true; + is_good &= dynamic_lib->get_dll_symbol(imported_funcs, "imported_funcs"); + is_good &= dynamic_lib->get_dll_symbol(reference_symbol_funcs, "reference_symbol_funcs"); + is_good &= dynamic_lib->get_dll_symbol(event_indices, "event_indices"); + is_good &= dynamic_lib->get_dll_symbol(recomp_trigger_event, "recomp_trigger_event"); + is_good &= dynamic_lib->get_dll_symbol(get_function, "get_function"); + is_good &= dynamic_lib->get_dll_symbol(reference_section_addresses, "reference_section_addresses"); + is_good &= dynamic_lib->get_dll_symbol(section_addresses, "section_addresses"); } -void unprotect(void* target_func, DWORD* old_flags) { - BOOL result = VirtualProtect(target_func, - 16, - PAGE_READWRITE, - old_flags); - (void)result; +bool recomp::mods::NativeCodeHandle::good() { + return dynamic_lib->good() && is_good; } -void protect(void* target_func, DWORD old_flags) { - DWORD dummy_old_flags; - BOOL result = VirtualProtect(target_func, - 16, - old_flags, - &dummy_old_flags); - (void)result; +void recomp::mods::NativeCodeHandle::set_bad() { + dynamic_lib.reset(); + is_good = false; } void patch_func(recomp_func_t* target_func, recomp::mods::ModFunction replacement_func) { @@ -162,7 +137,7 @@ void patch_func(recomp_func_t* target_func, recomp::mods::ModFunction replacemen offset += count; }; - DWORD old_flags; + uint64_t old_flags; unprotect(target_func_u8, &old_flags); std::visit(overloaded { @@ -177,7 +152,7 @@ void patch_func(recomp_func_t* target_func, recomp::mods::ModFunction replacemen } void unpatch_func(void* target_func, const recomp::mods::PatchData& data) { - DWORD old_flags; + uint64_t old_flags; unprotect(target_func, &old_flags); memcpy(target_func, data.replaced_bytes.data(), data.replaced_bytes.size()); protect(target_func, old_flags); @@ -210,17 +185,17 @@ recomp::mods::ModLoadError recomp::mods::ModContext::load_mod(uint8_t* rdram, co std::span binary_span {reinterpret_cast(binary_data.data()), binary_data.size() }; // Parse the symbol file into the recompiler contexts. - N64Recomp::ModSymbolsError symbol_load_error = N64Recomp::parse_mod_symbols(syms_data, binary_span, section_vrom_map, handle.recompiler_context); + N64Recomp::ModSymbolsError symbol_load_error = N64Recomp::parse_mod_symbols(syms_data, binary_span, section_vrom_map, *handle.recompiler_context); if (symbol_load_error != N64Recomp::ModSymbolsError::Good) { return ModLoadError::FailedToLoadSyms; } - handle.section_load_addresses.resize(handle.recompiler_context.sections.size()); + handle.section_load_addresses.resize(handle.recompiler_context->sections.size()); // Copy each section's binary into rdram, leaving room for the section's bss before the next one. int32_t cur_section_addr = load_address; - for (size_t section_index = 0; section_index < handle.recompiler_context.sections.size(); section_index++) { - const auto& section = handle.recompiler_context.sections[section_index]; + for (size_t section_index = 0; section_index < handle.recompiler_context->sections.size(); section_index++) { + const auto& section = handle.recompiler_context->sections[section_index]; for (size_t i = 0; i < section.size; i++) { MEM_B(i, (gpr)cur_section_addr) = binary_data[section.rom_addr + i]; } @@ -393,7 +368,7 @@ bool dependency_version_met(uint8_t major, uint8_t minor, uint8_t patch, uint8_t void recomp::mods::ModContext::check_dependencies(recomp::mods::ModHandle& mod, std::vector>& errors) { errors.clear(); - for (N64Recomp::Dependency& cur_dep : mod.recompiler_context.dependencies) { + for (N64Recomp::Dependency& cur_dep : mod.recompiler_context->dependencies) { // Handle special dependency names. if (cur_dep.mod_id == N64Recomp::DependencyBaseRecomp || cur_dep.mod_id == N64Recomp::DependencySelf) { continue; @@ -424,7 +399,7 @@ recomp::mods::ModLoadError recomp::mods::ModContext::load_mod_code(recomp::mods: // TODO implement LuaJIT recompilation and allow it instead of native code loading via a mod manifest flag. std::filesystem::path dll_path = mod.manifest.mod_root_path; dll_path.replace_extension(".dll"); - mod.code_handle = std::make_unique(dll_path, mod.recompiler_context); + mod.code_handle = std::make_unique(dll_path, *mod.recompiler_context); if (!mod.code_handle->good()) { mod.code_handle.reset(); error_param = dll_path.string(); @@ -440,10 +415,10 @@ recomp::mods::ModLoadError recomp::mods::ModContext::load_mod_code(recomp::mods: recomp::mods::ModLoadError recomp::mods::ModContext::resolve_dependencies(recomp::mods::ModHandle& mod, std::string& error_param) { // Reference symbols from the base recomp. - for (size_t reference_sym_index = 0; reference_sym_index < mod.recompiler_context.num_regular_reference_symbols(); reference_sym_index++) { - const N64Recomp::ReferenceSymbol& reference_sym = mod.recompiler_context.get_regular_reference_symbol(reference_sym_index); - uint32_t reference_section_vrom = mod.recompiler_context.get_reference_section_rom(reference_sym.section_index); - uint32_t reference_section_vram = mod.recompiler_context.get_reference_section_vram(reference_sym.section_index); + for (size_t reference_sym_index = 0; reference_sym_index < mod.recompiler_context->num_regular_reference_symbols(); reference_sym_index++) { + const N64Recomp::ReferenceSymbol& reference_sym = mod.recompiler_context->get_regular_reference_symbol(reference_sym_index); + uint32_t reference_section_vrom = mod.recompiler_context->get_reference_section_rom(reference_sym.section_index); + uint32_t reference_section_vram = mod.recompiler_context->get_reference_section_vram(reference_sym.section_index); uint32_t reference_symbol_vram = reference_section_vram + reference_sym.section_offset; recomp_func_t* found_func = recomp::overlays::get_func_by_section_ram(reference_section_vrom, reference_symbol_vram); @@ -461,9 +436,9 @@ recomp::mods::ModLoadError recomp::mods::ModContext::resolve_dependencies(recomp } // Imported symbols. - for (size_t import_index = 0; import_index < mod.recompiler_context.import_symbols.size(); import_index++) { - const N64Recomp::ImportSymbol& imported_func = mod.recompiler_context.import_symbols[import_index]; - const N64Recomp::Dependency& dependency = mod.recompiler_context.dependencies[imported_func.dependency_index]; + for (size_t import_index = 0; import_index < mod.recompiler_context->import_symbols.size(); import_index++) { + const N64Recomp::ImportSymbol& imported_func = mod.recompiler_context->import_symbols[import_index]; + const N64Recomp::Dependency& dependency = mod.recompiler_context->dependencies[imported_func.dependency_index]; recomp_func_t* found_func = nullptr; @@ -491,7 +466,7 @@ recomp::mods::ModLoadError recomp::mods::ModContext::resolve_dependencies(recomp } // Apply all the function replacements in the mod. - for (const auto& replacement : mod.recompiler_context.replacements) { + for (const auto& replacement : mod.recompiler_context->replacements) { recomp_func_t* to_replace = recomp::overlays::get_func_by_section_ram(replacement.original_section_vrom, replacement.original_vram); if (to_replace == nullptr) {