From e7afded4ebc08ee3dd3e9ecc3739b938015e24ef Mon Sep 17 00:00:00 2001 From: Mr-Wiseguy Date: Wed, 28 Aug 2024 23:54:57 -0400 Subject: [PATCH] Reorganized mod loading, added mod dependency validation --- librecomp/include/librecomp/mods.hpp | 8 +- librecomp/src/mod_manifest.cpp | 15 +- librecomp/src/mods.cpp | 383 ++++++++++++++++++--------- 3 files changed, 270 insertions(+), 136 deletions(-) diff --git a/librecomp/include/librecomp/mods.hpp b/librecomp/include/librecomp/mods.hpp index 7913beb..4a74b61 100644 --- a/librecomp/include/librecomp/mods.hpp +++ b/librecomp/include/librecomp/mods.hpp @@ -51,7 +51,8 @@ namespace recomp { InvalidFunctionReplacement, FailedToFindReplacement, ReplacementConflict, - MissingDependencies, + MissingDependency, + WrongDependencyVersion, ModConflict, }; @@ -144,13 +145,16 @@ namespace recomp { private: ModOpenError open_mod(const std::filesystem::path& mod_path, std::string& error_param); ModLoadError load_mod(uint8_t* rdram, const std::unordered_map& section_map, recomp::mods::ModHandle& handle, int32_t load_address, uint32_t& ram_used, std::string& error_param); + void check_dependencies(recomp::mods::ModHandle& mod, std::vector>& errors); + ModLoadError load_mod_code(recomp::mods::ModHandle& mod, std::string& error_param); + ModLoadError resolve_dependencies(recomp::mods::ModHandle& mod, std::string& error_param); void add_opened_mod(ModManifest&& manifest); std::vector opened_mods; std::unordered_set mod_ids; std::unordered_set enabled_mods; std::unordered_map patched_funcs; - std::unordered_map sections_by_vrom; + std::unordered_map loaded_mods_by_id; }; } }; diff --git a/librecomp/src/mod_manifest.cpp b/librecomp/src/mod_manifest.cpp index db1ffa4..6e7566e 100644 --- a/librecomp/src/mod_manifest.cpp +++ b/librecomp/src/mod_manifest.cpp @@ -138,8 +138,7 @@ enum class ManifestField { BinaryPath, BinarySymsPath, RomPatchPath, - RomPatchSymsPath, - Invalid, + RomPatchSymsPath }; const std::string mod_id_key = "id"; @@ -425,14 +424,22 @@ std::string recomp::mods::error_to_string(ModLoadError error) { return "Failed to load mod symbol file"; case ModLoadError::FailedToLoadBinary: return "Failed to load mod binary file"; + case ModLoadError::FailedToLoadNativeCode: + return "Failed to load mod DLL"; + case ModLoadError::InvalidReferenceSymbol: + return "Reference symbol does not exist"; + case ModLoadError::InvalidImport: + return "Imported function not found"; case ModLoadError::InvalidFunctionReplacement: return "Function to be replaced does not exist"; case ModLoadError::FailedToFindReplacement: return "Failed to find replacement function"; case ModLoadError::ReplacementConflict: return "Attempted to replace a function that cannot be replaced"; - case ModLoadError::MissingDependencies: - return "Missing dependencies"; + case ModLoadError::MissingDependency: + return "Missing dependency"; + case ModLoadError::WrongDependencyVersion: + return "Wrong dependency version"; case ModLoadError::ModConflict: return "Conflicts with other mod"; } diff --git a/librecomp/src/mods.cpp b/librecomp/src/mods.cpp index 9333a87..ce45354 100644 --- a/librecomp/src/mods.cpp +++ b/librecomp/src/mods.cpp @@ -125,6 +125,7 @@ namespace recomp { ModManifest manifest; std::unique_ptr code_handle; N64Recomp::Context recompiler_context; + std::vector section_load_addresses; ModHandle(ModManifest&& manifest) : manifest(std::move(manifest)), code_handle(), recompiler_context{} { @@ -188,140 +189,46 @@ void recomp::mods::ModContext::add_opened_mod(ModManifest&& manifest) { recomp::mods::ModLoadError recomp::mods::ModContext::load_mod(uint8_t* rdram, const std::unordered_map& section_vrom_map, recomp::mods::ModHandle& handle, int32_t load_address, uint32_t& ram_used, std::string& error_param) { using namespace recomp::mods; - std::vector section_load_addresses{}; + handle.section_load_addresses.clear(); + + // Load the mod symbol data from the file provided in the manifest. + bool binary_syms_exists = false; + std::vector syms_data = handle.manifest.file_handle->read_file(handle.manifest.binary_syms_path, binary_syms_exists); - { - // Load the mod symbol data from the file provided in the manifest. - bool binary_syms_exists = false; - std::vector syms_data = handle.manifest.file_handle->read_file(handle.manifest.binary_syms_path, binary_syms_exists); + if (!binary_syms_exists) { + return recomp::mods::ModLoadError::FailedToLoadSyms; + } + + // Load the binary data from the file provided in the manifest. + bool binary_exists = false; + std::vector binary_data = handle.manifest.file_handle->read_file(handle.manifest.binary_path, binary_exists); - if (!binary_syms_exists) { - return recomp::mods::ModLoadError::FailedToLoadSyms; - } - - // Load the binary data from the file provided in the manifest. - bool binary_exists = false; - std::vector binary_data = handle.manifest.file_handle->read_file(handle.manifest.binary_path, binary_exists); - - if (!binary_exists) { - return recomp::mods::ModLoadError::FailedToLoadBinary; - } - - std::span binary_span {reinterpret_cast(binary_data.data()), binary_data.size() }; - - // Parse the symbol file into the recompiler contexts. - N64Recomp::ModSymbolsError symbol_load_error = N64Recomp::parse_mod_symbols(syms_data, binary_span, section_vrom_map, handle.recompiler_context); - if (symbol_load_error != N64Recomp::ModSymbolsError::Good) { - return ModLoadError::FailedToLoadSyms; - } - - section_load_addresses.resize(handle.recompiler_context.sections.size()); - - // Copy each section's binary into rdram, leaving room for the section's bss before the next one. - int32_t cur_section_addr = load_address; - for (size_t section_index = 0; section_index < handle.recompiler_context.sections.size(); section_index++) { - const auto& section = handle.recompiler_context.sections[section_index]; - for (size_t i = 0; i < section.size; i++) { - MEM_B(i, (gpr)cur_section_addr) = binary_data[section.rom_addr + i]; - } - section_load_addresses[section_index] = cur_section_addr; - cur_section_addr += section.size + section.bss_size; - } - - ram_used = cur_section_addr - load_address; + if (!binary_exists) { + return recomp::mods::ModLoadError::FailedToLoadBinary; } - // TODO implement LuaJIT recompilation and allow it instead of native code loading via a mod manifest flag. - std::filesystem::path dll_path = handle.manifest.mod_root_path; - dll_path.replace_extension(".dll"); - handle.code_handle = std::make_unique(dll_path, handle.recompiler_context); - if (!handle.code_handle->good()) { - printf("Failed to open mod dll: %ls\n", dll_path.c_str()); - handle.code_handle.reset(); - return ModLoadError::FailedToLoadNativeCode; + std::span binary_span {reinterpret_cast(binary_data.data()), binary_data.size() }; + + // Parse the symbol file into the recompiler contexts. + N64Recomp::ModSymbolsError symbol_load_error = N64Recomp::parse_mod_symbols(syms_data, binary_span, section_vrom_map, handle.recompiler_context); + if (symbol_load_error != N64Recomp::ModSymbolsError::Good) { + return ModLoadError::FailedToLoadSyms; + } + + handle.section_load_addresses.resize(handle.recompiler_context.sections.size()); + + // Copy each section's binary into rdram, leaving room for the section's bss before the next one. + int32_t cur_section_addr = load_address; + for (size_t section_index = 0; section_index < handle.recompiler_context.sections.size(); section_index++) { + const auto& section = handle.recompiler_context.sections[section_index]; + for (size_t i = 0; i < section.size; i++) { + MEM_B(i, (gpr)cur_section_addr) = binary_data[section.rom_addr + i]; + } + handle.section_load_addresses[section_index] = cur_section_addr; + cur_section_addr += section.size + section.bss_size; } - // TODO dependency resolution - - // TODO imported_funcs from other mods - for (size_t import_index = 0; import_index < handle.recompiler_context.import_symbols.size(); import_index++) { - const N64Recomp::ImportSymbol& imported_func = handle.recompiler_context.import_symbols[import_index]; - const N64Recomp::Dependency& dependency = handle.recompiler_context.dependencies[imported_func.dependency_index]; - - recomp_func_t* found_func = nullptr; - - if (dependency.mod_id == "*") { - found_func = recomp::overlays::get_base_export(imported_func.base.name); - } - - if (found_func == nullptr) { - error_param = dependency.mod_id + ":" + imported_func.base.name; - return ModLoadError::InvalidImport; - } - - handle.code_handle->set_imported_function_pointer(import_index, found_func); - } - - for (size_t reference_sym_index = 0; reference_sym_index < handle.recompiler_context.num_regular_reference_symbols(); reference_sym_index++) { - const N64Recomp::ReferenceSymbol& reference_sym = handle.recompiler_context.get_regular_reference_symbol(reference_sym_index); - uint32_t reference_section_vrom = handle.recompiler_context.get_reference_section_rom(reference_sym.section_index); - uint32_t reference_section_vram = handle.recompiler_context.get_reference_section_vram(reference_sym.section_index); - uint32_t reference_symbol_vram = reference_section_vram + reference_sym.section_offset; - - recomp_func_t* found_func = recomp::overlays::get_func_by_section_ram(reference_section_vrom, reference_symbol_vram); - - if (found_func == nullptr) { - std::stringstream error_param_stream{}; - error_param_stream << std::hex << - "section: 0x" << reference_section_vrom << - " func: 0x" << std::setfill('0') << std::setw(8) << reference_symbol_vram; - error_param = error_param_stream.str(); - return ModLoadError::InvalidReferenceSymbol; - } - - handle.code_handle->set_reference_symbol_pointer(reference_sym_index, found_func); - } - - // TODO event_indices - // TODO recomp_trigger_event - - handle.code_handle->set_get_function_pointer(get_function); - handle.code_handle->set_reference_section_addresses_pointer(section_addresses); - - for (size_t section_index = 0; section_index < section_load_addresses.size(); section_index++) { - handle.code_handle->set_local_section_address(section_index, section_load_addresses[section_index]); - } - - // Apply all the function replacements in the mod. - for (const auto& replacement : handle.recompiler_context.replacements) { - recomp_func_t* to_replace = recomp::overlays::get_func_by_section_ram(replacement.original_section_vrom, replacement.original_vram); - - if (to_replace == nullptr) { - std::stringstream error_param_stream{}; - error_param_stream << std::hex << - "section: 0x" << replacement.original_section_vrom << - " func: 0x" << std::setfill('0') << std::setw(8) << replacement.original_vram; - error_param = error_param_stream.str(); - return ModLoadError::InvalidFunctionReplacement; - } - - // Check if this function has already been replaced. - auto find_patch_it = patched_funcs.find(to_replace); - if (find_patch_it != patched_funcs.end()) { - error_param = find_patch_it->second.mod_id; - return ModLoadError::ModConflict; - } - - // Copy the original bytes so they can be restored later after the mod is unloaded. - PatchData& cur_replacement_data = patched_funcs[to_replace]; - memcpy(cur_replacement_data.replaced_bytes.data(), to_replace, cur_replacement_data.replaced_bytes.size()); - cur_replacement_data.mod_id = handle.manifest.mod_id; - - // Patch the function to redirect it to the replacement. - patch_func(to_replace, handle.code_handle->get_function_handle(replacement.func_index)); - } - - // TODO perform mips32 relocations + ram_used = cur_section_addr - load_address; return ModLoadError::Good; } @@ -379,8 +286,15 @@ std::vector recomp::mods::ModContext::load_mo const std::unordered_map& section_vrom_map = recomp::overlays::get_vrom_to_section_map(); - for (auto& mod : opened_mods) { + std::vector active_mods{}; + + // Find and load active mods. + for (size_t mod_index = 0; mod_index < opened_mods.size(); mod_index++) { + auto& mod = opened_mods[mod_index]; if (enabled_mods.contains(mod.manifest.mod_id)) { + active_mods.push_back(mod_index); + loaded_mods_by_id.emplace(mod.manifest.mod_id, mod_index); + printf("Loading mod %s\n", mod.manifest.mod_id.c_str()); uint32_t cur_ram_used = 0; std::string load_error_param; @@ -396,17 +310,226 @@ std::vector recomp::mods::ModContext::load_mo } } + // Exit early if errors were found. if (!ret.empty()) { - printf("Mod loading failed, unpatching funcs\n"); unload_mods(); + return ret; + } + + // Check that mod dependencies are met. + for (size_t mod_index : active_mods) { + auto& mod = opened_mods[mod_index]; + std::vector> cur_errors; + check_dependencies(mod, cur_errors); + + if (!cur_errors.empty()) { + for (auto const& [cur_error, cur_error_param] : cur_errors) { + ret.emplace_back(mod.manifest.mod_id, cur_error, cur_error_param); + } + } + } + + // Exit early if errors were found. + if (!ret.empty()) { + unload_mods(); + return ret; + } + + // Load the code and exports from all mods. + for (size_t mod_index : active_mods) { + auto& mod = opened_mods[mod_index]; + std::string cur_error_param; + ModLoadError cur_error = load_mod_code(mod, cur_error_param); + if (cur_error != ModLoadError::Good) { + ret.emplace_back(mod.manifest.mod_id, cur_error, cur_error_param); + } + } + + // Exit early if errors were found. + if (!ret.empty()) { + unload_mods(); + return ret; + } + + // Resolve dependencies for all mods. + for (size_t mod_index : active_mods) { + auto& mod = opened_mods[mod_index]; + std::string cur_error_param; + ModLoadError cur_error = resolve_dependencies(mod, cur_error_param); + if (cur_error != ModLoadError::Good) { + ret.emplace_back(mod.manifest.mod_id, cur_error, cur_error_param); + } + } + + // Exit early if errors were found. + if (!ret.empty()) { + unload_mods(); + return ret; } return ret; } +bool dependency_version_met(uint8_t major, uint8_t minor, uint8_t patch, uint8_t major_target, uint8_t minor_target, uint8_t patch_target) { + if (major > major_target) { + return true; + } + else if (major < major_target) { + return false; + } + + if (minor > minor_target) { + return true; + } + else if (minor < minor_target) { + return false; + } + + if (patch >= patch_target) { + return true; + } + return false; +} + +void recomp::mods::ModContext::check_dependencies(recomp::mods::ModHandle& mod, std::vector>& errors) { + errors.clear(); + for (N64Recomp::Dependency& cur_dep : mod.recompiler_context.dependencies) { + // Handle special dependency names. + if (cur_dep.mod_id == N64Recomp::DependencyBaseRecomp || cur_dep.mod_id == N64Recomp::DependencySelf) { + continue; + } + + // Look for the dependency in the loaded mod mapping. + auto find_it = loaded_mods_by_id.find(cur_dep.mod_id); + if (find_it == loaded_mods_by_id.end()) { + errors.emplace_back(ModLoadError::MissingDependency, cur_dep.mod_id); + continue; + } + + const auto& mod = opened_mods[find_it->second]; + if (!dependency_version_met( + mod.manifest.major_version, mod.manifest.minor_version, mod.manifest.patch_version, + cur_dep.major_version, cur_dep.minor_version, cur_dep.patch_version)) + { + std::stringstream error_param_stream{}; + error_param_stream << "requires mod \"" << cur_dep.mod_id << "\" " << + (int)cur_dep.major_version << "." << (int)cur_dep.minor_version << "." << (int)cur_dep.patch_version << ", got " << + (int)mod.manifest.major_version << "." << (int)mod.manifest.minor_version << "." << (int)mod.manifest.patch_version << ""; + errors.emplace_back(ModLoadError::WrongDependencyVersion, error_param_stream.str()); + } + } +} + +recomp::mods::ModLoadError recomp::mods::ModContext::load_mod_code(recomp::mods::ModHandle& mod, std::string& error_param) { + // TODO implement LuaJIT recompilation and allow it instead of native code loading via a mod manifest flag. + std::filesystem::path dll_path = mod.manifest.mod_root_path; + dll_path.replace_extension(".dll"); + mod.code_handle = std::make_unique(dll_path, mod.recompiler_context); + if (!mod.code_handle->good()) { + mod.code_handle.reset(); + error_param = dll_path.string(); + return ModLoadError::FailedToLoadNativeCode; + } + + // TODO exports + + // TODO events + + return ModLoadError::Good; +} + +recomp::mods::ModLoadError recomp::mods::ModContext::resolve_dependencies(recomp::mods::ModHandle& mod, std::string& error_param) { + // Reference symbols from the base recomp. + for (size_t reference_sym_index = 0; reference_sym_index < mod.recompiler_context.num_regular_reference_symbols(); reference_sym_index++) { + const N64Recomp::ReferenceSymbol& reference_sym = mod.recompiler_context.get_regular_reference_symbol(reference_sym_index); + uint32_t reference_section_vrom = mod.recompiler_context.get_reference_section_rom(reference_sym.section_index); + uint32_t reference_section_vram = mod.recompiler_context.get_reference_section_vram(reference_sym.section_index); + uint32_t reference_symbol_vram = reference_section_vram + reference_sym.section_offset; + + recomp_func_t* found_func = recomp::overlays::get_func_by_section_ram(reference_section_vrom, reference_symbol_vram); + + if (found_func == nullptr) { + std::stringstream error_param_stream{}; + error_param_stream << std::hex << + "section: 0x" << reference_section_vrom << + " func: 0x" << std::setfill('0') << std::setw(8) << reference_symbol_vram; + error_param = error_param_stream.str(); + return ModLoadError::InvalidReferenceSymbol; + } + + mod.code_handle->set_reference_symbol_pointer(reference_sym_index, found_func); + } + + // Imported symbols. + for (size_t import_index = 0; import_index < mod.recompiler_context.import_symbols.size(); import_index++) { + const N64Recomp::ImportSymbol& imported_func = mod.recompiler_context.import_symbols[import_index]; + const N64Recomp::Dependency& dependency = mod.recompiler_context.dependencies[imported_func.dependency_index]; + + recomp_func_t* found_func = nullptr; + + if (dependency.mod_id == N64Recomp::DependencyBaseRecomp) { + found_func = recomp::overlays::get_base_export(imported_func.base.name); + } + // TODO DependencySelf and other mods + + if (found_func == nullptr) { + error_param = dependency.mod_id + ":" + imported_func.base.name; + return ModLoadError::InvalidImport; + } + + mod.code_handle->set_imported_function_pointer(import_index, found_func); + } + + // TODO event_indices + // TODO recomp_trigger_event + + mod.code_handle->set_get_function_pointer(get_function); + mod.code_handle->set_reference_section_addresses_pointer(section_addresses); + + for (size_t section_index = 0; section_index < mod.section_load_addresses.size(); section_index++) { + mod.code_handle->set_local_section_address(section_index, mod.section_load_addresses[section_index]); + } + + // Apply all the function replacements in the mod. + for (const auto& replacement : mod.recompiler_context.replacements) { + recomp_func_t* to_replace = recomp::overlays::get_func_by_section_ram(replacement.original_section_vrom, replacement.original_vram); + + if (to_replace == nullptr) { + std::stringstream error_param_stream{}; + error_param_stream << std::hex << + "section: 0x" << replacement.original_section_vrom << + " func: 0x" << std::setfill('0') << std::setw(8) << replacement.original_vram; + error_param = error_param_stream.str(); + return ModLoadError::InvalidFunctionReplacement; + } + + // Check if this function has already been replaced. + auto find_patch_it = patched_funcs.find(to_replace); + if (find_patch_it != patched_funcs.end()) { + error_param = find_patch_it->second.mod_id; + return ModLoadError::ModConflict; + } + + // Copy the original bytes so they can be restored later after the mod is unloaded. + PatchData& cur_replacement_data = patched_funcs[to_replace]; + memcpy(cur_replacement_data.replaced_bytes.data(), to_replace, cur_replacement_data.replaced_bytes.size()); + cur_replacement_data.mod_id = mod.manifest.mod_id; + + // Patch the function to redirect it to the replacement. + patch_func(to_replace, mod.code_handle->get_function_handle(replacement.func_index)); + } + + // TODO perform mips32 relocations + + // TODO hook up callbacks + + return ModLoadError::Good; +} + void recomp::mods::ModContext::unload_mods() { for (auto& [replacement_func, replacement_data] : patched_funcs) { unpatch_func(replacement_func, replacement_data); } patched_funcs.clear(); + loaded_mods_by_id.clear(); }