From 904f17a87248ce89ada7f2d52f2187592c07e560 Mon Sep 17 00:00:00 2001 From: Mr-Wiseguy Date: Fri, 30 Aug 2024 02:59:34 -0400 Subject: [PATCH] Implement importing functions from other mods --- N64Recomp | 2 +- librecomp/include/librecomp/mods.hpp | 37 +++++++---- librecomp/src/mods.cpp | 94 ++++++++++++++++++++++++---- 3 files changed, 108 insertions(+), 25 deletions(-) diff --git a/N64Recomp b/N64Recomp index 131157d..747cd9f 160000 --- a/N64Recomp +++ b/N64Recomp @@ -1 +1 @@ -Subproject commit 131157dad85953c3f484cea5a67c61c49f38d0b3 +Subproject commit 747cd9f6acc09d20ea9a8148def8cea88728a5cb diff --git a/librecomp/include/librecomp/mods.hpp b/librecomp/include/librecomp/mods.hpp index 853cda7..fb9c406 100644 --- a/librecomp/include/librecomp/mods.hpp +++ b/librecomp/include/librecomp/mods.hpp @@ -133,7 +133,9 @@ namespace recomp { std::string mod_id; }; - struct ModHandle; + using GenericFunction = std::variant; + + class ModHandle; class ModContext { public: ModContext(); @@ -160,25 +162,29 @@ namespace recomp { std::unordered_set enabled_mods; std::unordered_map patched_funcs; std::unordered_map loaded_mods_by_id; + // // Maps (mod id, export name) to (mod index, function index). + // std::unordered_map, std::pair> mod_exports; + // // Maps (mod id, event name) to a vector of callback functions attached to that event. + // std::unordered_map, std::vector> callbacks; }; - using ModFunction = std::variant; - class ModCodeHandle { public: virtual ~ModCodeHandle() {} virtual bool good() = 0; - virtual void set_imported_function_pointer(size_t import_index, recomp_func_t* ptr) = 0; + virtual void set_imported_function(size_t import_index, GenericFunction func) = 0; virtual void set_reference_symbol_pointer(size_t symbol_index, recomp_func_t* ptr) = 0; virtual void set_event_index(size_t local_event_index, uint32_t global_event_index) = 0; virtual void set_recomp_trigger_event_pointer(void (*ptr)(uint8_t* rdram, recomp_context* ctx, uint32_t index)) = 0; virtual void set_get_function_pointer(recomp_func_t* (*ptr)(int32_t)) = 0; virtual void set_reference_section_addresses_pointer(int32_t* ptr) = 0; virtual void set_local_section_address(size_t section_index, int32_t address) = 0; - virtual ModFunction get_function_handle(size_t func_index) = 0; + virtual GenericFunction get_function_handle(size_t func_index) = 0; }; - struct ModHandle { + class ModHandle { + public: + // TODO make these private and expose methods for the functionality they're currently used in. ModManifest manifest; std::unique_ptr code_handle; std::unique_ptr recompiler_context; @@ -190,6 +196,17 @@ namespace recomp { ModHandle(ModHandle&& rhs); ModHandle& operator=(ModHandle&& rhs); ~ModHandle(); + + size_t num_exports() const; + size_t num_events() const; + + ModLoadError populate_exports(std::string& error_param); + bool get_export_function(const std::string& export_name, GenericFunction& out) const; + private: + // Mapping of export name to function index. + std::unordered_map exports_by_name; + // List of global event indices ordered by the event's local index. + std::vector global_event_indices; }; class DynamicLibrary; @@ -198,9 +215,7 @@ namespace recomp { NativeCodeHandle(const std::filesystem::path& dll_path, const N64Recomp::Context& context); ~NativeCodeHandle() = default; bool good() final; - void set_imported_function_pointer(size_t import_index, recomp_func_t* ptr) final { - imported_funcs[import_index] = ptr; - } + void set_imported_function(size_t import_index, GenericFunction func) final; void set_reference_symbol_pointer(size_t symbol_index, recomp_func_t* ptr) final { reference_symbol_funcs[symbol_index] = ptr; }; @@ -219,8 +234,8 @@ namespace recomp { void set_local_section_address(size_t section_index, int32_t address) final { section_addresses[section_index] = address; }; - ModFunction get_function_handle(size_t func_index) final { - return ModFunction{ functions[func_index] }; + GenericFunction get_function_handle(size_t func_index) final { + return GenericFunction{ functions[func_index] }; } private: void set_bad(); diff --git a/librecomp/src/mods.cpp b/librecomp/src/mods.cpp index 5728ac9..2cd4a36 100644 --- a/librecomp/src/mods.cpp +++ b/librecomp/src/mods.cpp @@ -15,11 +15,45 @@ #endif recomp::mods::ModHandle::ModHandle(ModManifest&& manifest) : - manifest(std::move(manifest)), code_handle(), recompiler_context{std::make_unique()} { } + manifest(std::move(manifest)), + code_handle(), + recompiler_context{std::make_unique()} +{ + +} + recomp::mods::ModHandle::ModHandle(ModHandle&& rhs) = default; recomp::mods::ModHandle& recomp::mods::ModHandle::operator=(ModHandle&& rhs) = default; recomp::mods::ModHandle::~ModHandle() = default; +size_t recomp::mods::ModHandle::num_exports() const { + return recompiler_context->exported_funcs.size(); +} + +size_t recomp::mods::ModHandle::num_events() const { + return recompiler_context->event_symbols.size(); +} + + +recomp::mods::ModLoadError recomp::mods::ModHandle::populate_exports(std::string& error_param) { + for (size_t func_index : recompiler_context->exported_funcs) { + const auto& func_handle = recompiler_context->functions[func_index]; + exports_by_name.emplace(func_handle.name, func_index); + } + + return ModLoadError::Good; +} + +bool recomp::mods::ModHandle::get_export_function(const std::string& export_name, GenericFunction& out) const { + auto find_it = exports_by_name.find(export_name); + if (find_it == exports_by_name.end()) { + return false; + } + + out = code_handle->get_function_handle(find_it->second); + return true; +} + template struct overloaded : Ts... { using Ts::operator()...; }; template @@ -99,8 +133,13 @@ recomp::mods::NativeCodeHandle::NativeCodeHandle(const std::filesystem::path& dl // Fill out the list of function pointers. functions.resize(context.functions.size()); for (size_t i = 0; i < functions.size(); i++) { - std::string func_name = "mod_func_" + std::to_string(i); - is_good &= dynamic_lib->get_dll_symbol(functions[i], func_name.c_str()); + if(!context.functions[i].name.empty()) { + is_good &= dynamic_lib->get_dll_symbol(functions[i], context.functions[i].name.c_str()); + } + else { + std::string func_name = "mod_func_" + std::to_string(i); + is_good &= dynamic_lib->get_dll_symbol(functions[i], func_name.c_str()); + } if (!is_good) { return; } @@ -126,7 +165,15 @@ void recomp::mods::NativeCodeHandle::set_bad() { is_good = false; } -void patch_func(recomp_func_t* target_func, recomp::mods::ModFunction replacement_func) { +void recomp::mods::NativeCodeHandle::set_imported_function(size_t import_index, GenericFunction func) { + std::visit(overloaded { + [this, import_index](recomp_func_t* native_func) { + imported_funcs[import_index] = native_func; + } + }, func); +} + +void patch_func(recomp_func_t* target_func, recomp::mods::GenericFunction replacement_func) { static const uint8_t movabs_rax[] = {0x48, 0xB8}; static const uint8_t jmp_rax[] = {0xFF, 0xE0}; uint8_t* target_func_u8 = reinterpret_cast(target_func); @@ -184,7 +231,7 @@ recomp::mods::ModLoadError recomp::mods::ModContext::load_mod(uint8_t* rdram, co std::span binary_span {reinterpret_cast(binary_data.data()), binary_data.size() }; - // Parse the symbol file into the recompiler contexts. + // Parse the symbol file into the recompiler context. N64Recomp::ModSymbolsError symbol_load_error = N64Recomp::parse_mod_symbols(syms_data, binary_span, section_vrom_map, *handle.recompiler_context); if (symbol_load_error != N64Recomp::ModSymbolsError::Good) { return ModLoadError::FailedToLoadSyms; @@ -230,8 +277,8 @@ std::vector recomp::mods::ModContext::scan_mo } // Nothing needed for these two, they just need to be explicitly declared outside the header to allow forward declaration of ModHandle. -recomp::mods::ModContext::ModContext() {} -recomp::mods::ModContext::~ModContext() {} +recomp::mods::ModContext::ModContext() = default; +recomp::mods::ModContext::~ModContext() = default; void recomp::mods::ModContext::enable_mod(const std::string& mod_id, bool enabled) { if (enabled) { @@ -406,7 +453,14 @@ recomp::mods::ModLoadError recomp::mods::ModContext::load_mod_code(recomp::mods: return ModLoadError::FailedToLoadNativeCode; } - // TODO exports + // Populate the mod's export map. + std::string export_error_param; + ModLoadError export_error = mod.populate_exports(export_error_param); + + if (export_error != ModLoadError::Good) { + error_param = std::move(export_error_param); + return export_error; + } // TODO events @@ -440,19 +494,33 @@ recomp::mods::ModLoadError recomp::mods::ModContext::resolve_dependencies(recomp const N64Recomp::ImportSymbol& imported_func = mod.recompiler_context->import_symbols[import_index]; const N64Recomp::Dependency& dependency = mod.recompiler_context->dependencies[imported_func.dependency_index]; - recomp_func_t* found_func = nullptr; + GenericFunction func_handle{}; + bool did_find_func = false; if (dependency.mod_id == N64Recomp::DependencyBaseRecomp) { - found_func = recomp::overlays::get_base_export(imported_func.base.name); + recomp_func_t* func_ptr = recomp::overlays::get_base_export(imported_func.base.name); + did_find_func = func_ptr != nullptr; + func_handle = func_ptr; + } + else if (dependency.mod_id == N64Recomp::DependencySelf) { + did_find_func = mod.get_export_function(imported_func.base.name, func_handle); + } + else { + auto find_mod_it = loaded_mods_by_id.find(dependency.mod_id); + if (find_mod_it == loaded_mods_by_id.end()) { + error_param = dependency.mod_id; + return ModLoadError::MissingDependency; + } + const auto& dependency = opened_mods[find_mod_it->second]; + did_find_func = dependency.get_export_function(imported_func.base.name, func_handle); } - // TODO DependencySelf and other mods - if (found_func == nullptr) { + if (!did_find_func) { error_param = dependency.mod_id + ":" + imported_func.base.name; return ModLoadError::InvalidImport; } - mod.code_handle->set_imported_function_pointer(import_index, found_func); + mod.code_handle->set_imported_function(import_index, func_handle); } // TODO event_indices