Reorganized mod loading, added mod dependency validation

This commit is contained in:
Mr-Wiseguy 2024-08-28 23:54:57 -04:00
parent 60a25faebe
commit e7afded4eb
3 changed files with 270 additions and 136 deletions

View file

@ -51,7 +51,8 @@ namespace recomp {
InvalidFunctionReplacement, InvalidFunctionReplacement,
FailedToFindReplacement, FailedToFindReplacement,
ReplacementConflict, ReplacementConflict,
MissingDependencies, MissingDependency,
WrongDependencyVersion,
ModConflict, ModConflict,
}; };
@ -144,13 +145,16 @@ namespace recomp {
private: private:
ModOpenError open_mod(const std::filesystem::path& mod_path, std::string& error_param); ModOpenError open_mod(const std::filesystem::path& mod_path, std::string& error_param);
ModLoadError load_mod(uint8_t* rdram, const std::unordered_map<uint32_t, uint16_t>& section_map, recomp::mods::ModHandle& handle, int32_t load_address, uint32_t& ram_used, std::string& error_param); ModLoadError load_mod(uint8_t* rdram, const std::unordered_map<uint32_t, uint16_t>& section_map, recomp::mods::ModHandle& handle, int32_t load_address, uint32_t& ram_used, std::string& error_param);
void check_dependencies(recomp::mods::ModHandle& mod, std::vector<std::pair<recomp::mods::ModLoadError, std::string>>& errors);
ModLoadError load_mod_code(recomp::mods::ModHandle& mod, std::string& error_param);
ModLoadError resolve_dependencies(recomp::mods::ModHandle& mod, std::string& error_param);
void add_opened_mod(ModManifest&& manifest); void add_opened_mod(ModManifest&& manifest);
std::vector<ModHandle> opened_mods; std::vector<ModHandle> opened_mods;
std::unordered_set<std::string> mod_ids; std::unordered_set<std::string> mod_ids;
std::unordered_set<std::string> enabled_mods; std::unordered_set<std::string> enabled_mods;
std::unordered_map<recomp_func_t*, PatchData> patched_funcs; std::unordered_map<recomp_func_t*, PatchData> patched_funcs;
std::unordered_map<uint32_t, size_t> sections_by_vrom; std::unordered_map<std::string, size_t> loaded_mods_by_id;
}; };
} }
}; };

View file

@ -138,8 +138,7 @@ enum class ManifestField {
BinaryPath, BinaryPath,
BinarySymsPath, BinarySymsPath,
RomPatchPath, RomPatchPath,
RomPatchSymsPath, RomPatchSymsPath
Invalid,
}; };
const std::string mod_id_key = "id"; const std::string mod_id_key = "id";
@ -425,14 +424,22 @@ std::string recomp::mods::error_to_string(ModLoadError error) {
return "Failed to load mod symbol file"; return "Failed to load mod symbol file";
case ModLoadError::FailedToLoadBinary: case ModLoadError::FailedToLoadBinary:
return "Failed to load mod binary file"; return "Failed to load mod binary file";
case ModLoadError::FailedToLoadNativeCode:
return "Failed to load mod DLL";
case ModLoadError::InvalidReferenceSymbol:
return "Reference symbol does not exist";
case ModLoadError::InvalidImport:
return "Imported function not found";
case ModLoadError::InvalidFunctionReplacement: case ModLoadError::InvalidFunctionReplacement:
return "Function to be replaced does not exist"; return "Function to be replaced does not exist";
case ModLoadError::FailedToFindReplacement: case ModLoadError::FailedToFindReplacement:
return "Failed to find replacement function"; return "Failed to find replacement function";
case ModLoadError::ReplacementConflict: case ModLoadError::ReplacementConflict:
return "Attempted to replace a function that cannot be replaced"; return "Attempted to replace a function that cannot be replaced";
case ModLoadError::MissingDependencies: case ModLoadError::MissingDependency:
return "Missing dependencies"; return "Missing dependency";
case ModLoadError::WrongDependencyVersion:
return "Wrong dependency version";
case ModLoadError::ModConflict: case ModLoadError::ModConflict:
return "Conflicts with other mod"; return "Conflicts with other mod";
} }

View file

@ -125,6 +125,7 @@ namespace recomp {
ModManifest manifest; ModManifest manifest;
std::unique_ptr<ModCodeHandle> code_handle; std::unique_ptr<ModCodeHandle> code_handle;
N64Recomp::Context recompiler_context; N64Recomp::Context recompiler_context;
std::vector<uint32_t> section_load_addresses;
ModHandle(ModManifest&& manifest) : ModHandle(ModManifest&& manifest) :
manifest(std::move(manifest)), code_handle(), recompiler_context{} { manifest(std::move(manifest)), code_handle(), recompiler_context{} {
@ -188,9 +189,8 @@ void recomp::mods::ModContext::add_opened_mod(ModManifest&& manifest) {
recomp::mods::ModLoadError recomp::mods::ModContext::load_mod(uint8_t* rdram, const std::unordered_map<uint32_t, uint16_t>& section_vrom_map, recomp::mods::ModHandle& handle, int32_t load_address, uint32_t& ram_used, std::string& error_param) { recomp::mods::ModLoadError recomp::mods::ModContext::load_mod(uint8_t* rdram, const std::unordered_map<uint32_t, uint16_t>& section_vrom_map, recomp::mods::ModHandle& handle, int32_t load_address, uint32_t& ram_used, std::string& error_param) {
using namespace recomp::mods; using namespace recomp::mods;
std::vector<int32_t> section_load_addresses{}; handle.section_load_addresses.clear();
{
// Load the mod symbol data from the file provided in the manifest. // Load the mod symbol data from the file provided in the manifest.
bool binary_syms_exists = false; bool binary_syms_exists = false;
std::vector<char> syms_data = handle.manifest.file_handle->read_file(handle.manifest.binary_syms_path, binary_syms_exists); std::vector<char> syms_data = handle.manifest.file_handle->read_file(handle.manifest.binary_syms_path, binary_syms_exists);
@ -215,7 +215,7 @@ recomp::mods::ModLoadError recomp::mods::ModContext::load_mod(uint8_t* rdram, co
return ModLoadError::FailedToLoadSyms; return ModLoadError::FailedToLoadSyms;
} }
section_load_addresses.resize(handle.recompiler_context.sections.size()); handle.section_load_addresses.resize(handle.recompiler_context.sections.size());
// Copy each section's binary into rdram, leaving room for the section's bss before the next one. // Copy each section's binary into rdram, leaving room for the section's bss before the next one.
int32_t cur_section_addr = load_address; int32_t cur_section_addr = load_address;
@ -224,104 +224,11 @@ recomp::mods::ModLoadError recomp::mods::ModContext::load_mod(uint8_t* rdram, co
for (size_t i = 0; i < section.size; i++) { for (size_t i = 0; i < section.size; i++) {
MEM_B(i, (gpr)cur_section_addr) = binary_data[section.rom_addr + i]; MEM_B(i, (gpr)cur_section_addr) = binary_data[section.rom_addr + i];
} }
section_load_addresses[section_index] = cur_section_addr; handle.section_load_addresses[section_index] = cur_section_addr;
cur_section_addr += section.size + section.bss_size; cur_section_addr += section.size + section.bss_size;
} }
ram_used = cur_section_addr - load_address; ram_used = cur_section_addr - load_address;
}
// TODO implement LuaJIT recompilation and allow it instead of native code loading via a mod manifest flag.
std::filesystem::path dll_path = handle.manifest.mod_root_path;
dll_path.replace_extension(".dll");
handle.code_handle = std::make_unique<NativeCodeHandle>(dll_path, handle.recompiler_context);
if (!handle.code_handle->good()) {
printf("Failed to open mod dll: %ls\n", dll_path.c_str());
handle.code_handle.reset();
return ModLoadError::FailedToLoadNativeCode;
}
// TODO dependency resolution
// TODO imported_funcs from other mods
for (size_t import_index = 0; import_index < handle.recompiler_context.import_symbols.size(); import_index++) {
const N64Recomp::ImportSymbol& imported_func = handle.recompiler_context.import_symbols[import_index];
const N64Recomp::Dependency& dependency = handle.recompiler_context.dependencies[imported_func.dependency_index];
recomp_func_t* found_func = nullptr;
if (dependency.mod_id == "*") {
found_func = recomp::overlays::get_base_export(imported_func.base.name);
}
if (found_func == nullptr) {
error_param = dependency.mod_id + ":" + imported_func.base.name;
return ModLoadError::InvalidImport;
}
handle.code_handle->set_imported_function_pointer(import_index, found_func);
}
for (size_t reference_sym_index = 0; reference_sym_index < handle.recompiler_context.num_regular_reference_symbols(); reference_sym_index++) {
const N64Recomp::ReferenceSymbol& reference_sym = handle.recompiler_context.get_regular_reference_symbol(reference_sym_index);
uint32_t reference_section_vrom = handle.recompiler_context.get_reference_section_rom(reference_sym.section_index);
uint32_t reference_section_vram = handle.recompiler_context.get_reference_section_vram(reference_sym.section_index);
uint32_t reference_symbol_vram = reference_section_vram + reference_sym.section_offset;
recomp_func_t* found_func = recomp::overlays::get_func_by_section_ram(reference_section_vrom, reference_symbol_vram);
if (found_func == nullptr) {
std::stringstream error_param_stream{};
error_param_stream << std::hex <<
"section: 0x" << reference_section_vrom <<
" func: 0x" << std::setfill('0') << std::setw(8) << reference_symbol_vram;
error_param = error_param_stream.str();
return ModLoadError::InvalidReferenceSymbol;
}
handle.code_handle->set_reference_symbol_pointer(reference_sym_index, found_func);
}
// TODO event_indices
// TODO recomp_trigger_event
handle.code_handle->set_get_function_pointer(get_function);
handle.code_handle->set_reference_section_addresses_pointer(section_addresses);
for (size_t section_index = 0; section_index < section_load_addresses.size(); section_index++) {
handle.code_handle->set_local_section_address(section_index, section_load_addresses[section_index]);
}
// Apply all the function replacements in the mod.
for (const auto& replacement : handle.recompiler_context.replacements) {
recomp_func_t* to_replace = recomp::overlays::get_func_by_section_ram(replacement.original_section_vrom, replacement.original_vram);
if (to_replace == nullptr) {
std::stringstream error_param_stream{};
error_param_stream << std::hex <<
"section: 0x" << replacement.original_section_vrom <<
" func: 0x" << std::setfill('0') << std::setw(8) << replacement.original_vram;
error_param = error_param_stream.str();
return ModLoadError::InvalidFunctionReplacement;
}
// Check if this function has already been replaced.
auto find_patch_it = patched_funcs.find(to_replace);
if (find_patch_it != patched_funcs.end()) {
error_param = find_patch_it->second.mod_id;
return ModLoadError::ModConflict;
}
// Copy the original bytes so they can be restored later after the mod is unloaded.
PatchData& cur_replacement_data = patched_funcs[to_replace];
memcpy(cur_replacement_data.replaced_bytes.data(), to_replace, cur_replacement_data.replaced_bytes.size());
cur_replacement_data.mod_id = handle.manifest.mod_id;
// Patch the function to redirect it to the replacement.
patch_func(to_replace, handle.code_handle->get_function_handle(replacement.func_index));
}
// TODO perform mips32 relocations
return ModLoadError::Good; return ModLoadError::Good;
} }
@ -379,8 +286,15 @@ std::vector<recomp::mods::ModLoadErrorDetails> recomp::mods::ModContext::load_mo
const std::unordered_map<uint32_t, uint16_t>& section_vrom_map = recomp::overlays::get_vrom_to_section_map(); const std::unordered_map<uint32_t, uint16_t>& section_vrom_map = recomp::overlays::get_vrom_to_section_map();
for (auto& mod : opened_mods) { std::vector<size_t> active_mods{};
// Find and load active mods.
for (size_t mod_index = 0; mod_index < opened_mods.size(); mod_index++) {
auto& mod = opened_mods[mod_index];
if (enabled_mods.contains(mod.manifest.mod_id)) { if (enabled_mods.contains(mod.manifest.mod_id)) {
active_mods.push_back(mod_index);
loaded_mods_by_id.emplace(mod.manifest.mod_id, mod_index);
printf("Loading mod %s\n", mod.manifest.mod_id.c_str()); printf("Loading mod %s\n", mod.manifest.mod_id.c_str());
uint32_t cur_ram_used = 0; uint32_t cur_ram_used = 0;
std::string load_error_param; std::string load_error_param;
@ -396,17 +310,226 @@ std::vector<recomp::mods::ModLoadErrorDetails> recomp::mods::ModContext::load_mo
} }
} }
// Exit early if errors were found.
if (!ret.empty()) { if (!ret.empty()) {
printf("Mod loading failed, unpatching funcs\n");
unload_mods(); unload_mods();
return ret;
}
// Check that mod dependencies are met.
for (size_t mod_index : active_mods) {
auto& mod = opened_mods[mod_index];
std::vector<std::pair<ModLoadError, std::string>> cur_errors;
check_dependencies(mod, cur_errors);
if (!cur_errors.empty()) {
for (auto const& [cur_error, cur_error_param] : cur_errors) {
ret.emplace_back(mod.manifest.mod_id, cur_error, cur_error_param);
}
}
}
// Exit early if errors were found.
if (!ret.empty()) {
unload_mods();
return ret;
}
// Load the code and exports from all mods.
for (size_t mod_index : active_mods) {
auto& mod = opened_mods[mod_index];
std::string cur_error_param;
ModLoadError cur_error = load_mod_code(mod, cur_error_param);
if (cur_error != ModLoadError::Good) {
ret.emplace_back(mod.manifest.mod_id, cur_error, cur_error_param);
}
}
// Exit early if errors were found.
if (!ret.empty()) {
unload_mods();
return ret;
}
// Resolve dependencies for all mods.
for (size_t mod_index : active_mods) {
auto& mod = opened_mods[mod_index];
std::string cur_error_param;
ModLoadError cur_error = resolve_dependencies(mod, cur_error_param);
if (cur_error != ModLoadError::Good) {
ret.emplace_back(mod.manifest.mod_id, cur_error, cur_error_param);
}
}
// Exit early if errors were found.
if (!ret.empty()) {
unload_mods();
return ret;
} }
return ret; return ret;
} }
bool dependency_version_met(uint8_t major, uint8_t minor, uint8_t patch, uint8_t major_target, uint8_t minor_target, uint8_t patch_target) {
if (major > major_target) {
return true;
}
else if (major < major_target) {
return false;
}
if (minor > minor_target) {
return true;
}
else if (minor < minor_target) {
return false;
}
if (patch >= patch_target) {
return true;
}
return false;
}
void recomp::mods::ModContext::check_dependencies(recomp::mods::ModHandle& mod, std::vector<std::pair<recomp::mods::ModLoadError, std::string>>& errors) {
errors.clear();
for (N64Recomp::Dependency& cur_dep : mod.recompiler_context.dependencies) {
// Handle special dependency names.
if (cur_dep.mod_id == N64Recomp::DependencyBaseRecomp || cur_dep.mod_id == N64Recomp::DependencySelf) {
continue;
}
// Look for the dependency in the loaded mod mapping.
auto find_it = loaded_mods_by_id.find(cur_dep.mod_id);
if (find_it == loaded_mods_by_id.end()) {
errors.emplace_back(ModLoadError::MissingDependency, cur_dep.mod_id);
continue;
}
const auto& mod = opened_mods[find_it->second];
if (!dependency_version_met(
mod.manifest.major_version, mod.manifest.minor_version, mod.manifest.patch_version,
cur_dep.major_version, cur_dep.minor_version, cur_dep.patch_version))
{
std::stringstream error_param_stream{};
error_param_stream << "requires mod \"" << cur_dep.mod_id << "\" " <<
(int)cur_dep.major_version << "." << (int)cur_dep.minor_version << "." << (int)cur_dep.patch_version << ", got " <<
(int)mod.manifest.major_version << "." << (int)mod.manifest.minor_version << "." << (int)mod.manifest.patch_version << "";
errors.emplace_back(ModLoadError::WrongDependencyVersion, error_param_stream.str());
}
}
}
recomp::mods::ModLoadError recomp::mods::ModContext::load_mod_code(recomp::mods::ModHandle& mod, std::string& error_param) {
// TODO implement LuaJIT recompilation and allow it instead of native code loading via a mod manifest flag.
std::filesystem::path dll_path = mod.manifest.mod_root_path;
dll_path.replace_extension(".dll");
mod.code_handle = std::make_unique<NativeCodeHandle>(dll_path, mod.recompiler_context);
if (!mod.code_handle->good()) {
mod.code_handle.reset();
error_param = dll_path.string();
return ModLoadError::FailedToLoadNativeCode;
}
// TODO exports
// TODO events
return ModLoadError::Good;
}
recomp::mods::ModLoadError recomp::mods::ModContext::resolve_dependencies(recomp::mods::ModHandle& mod, std::string& error_param) {
// Reference symbols from the base recomp.
for (size_t reference_sym_index = 0; reference_sym_index < mod.recompiler_context.num_regular_reference_symbols(); reference_sym_index++) {
const N64Recomp::ReferenceSymbol& reference_sym = mod.recompiler_context.get_regular_reference_symbol(reference_sym_index);
uint32_t reference_section_vrom = mod.recompiler_context.get_reference_section_rom(reference_sym.section_index);
uint32_t reference_section_vram = mod.recompiler_context.get_reference_section_vram(reference_sym.section_index);
uint32_t reference_symbol_vram = reference_section_vram + reference_sym.section_offset;
recomp_func_t* found_func = recomp::overlays::get_func_by_section_ram(reference_section_vrom, reference_symbol_vram);
if (found_func == nullptr) {
std::stringstream error_param_stream{};
error_param_stream << std::hex <<
"section: 0x" << reference_section_vrom <<
" func: 0x" << std::setfill('0') << std::setw(8) << reference_symbol_vram;
error_param = error_param_stream.str();
return ModLoadError::InvalidReferenceSymbol;
}
mod.code_handle->set_reference_symbol_pointer(reference_sym_index, found_func);
}
// Imported symbols.
for (size_t import_index = 0; import_index < mod.recompiler_context.import_symbols.size(); import_index++) {
const N64Recomp::ImportSymbol& imported_func = mod.recompiler_context.import_symbols[import_index];
const N64Recomp::Dependency& dependency = mod.recompiler_context.dependencies[imported_func.dependency_index];
recomp_func_t* found_func = nullptr;
if (dependency.mod_id == N64Recomp::DependencyBaseRecomp) {
found_func = recomp::overlays::get_base_export(imported_func.base.name);
}
// TODO DependencySelf and other mods
if (found_func == nullptr) {
error_param = dependency.mod_id + ":" + imported_func.base.name;
return ModLoadError::InvalidImport;
}
mod.code_handle->set_imported_function_pointer(import_index, found_func);
}
// TODO event_indices
// TODO recomp_trigger_event
mod.code_handle->set_get_function_pointer(get_function);
mod.code_handle->set_reference_section_addresses_pointer(section_addresses);
for (size_t section_index = 0; section_index < mod.section_load_addresses.size(); section_index++) {
mod.code_handle->set_local_section_address(section_index, mod.section_load_addresses[section_index]);
}
// Apply all the function replacements in the mod.
for (const auto& replacement : mod.recompiler_context.replacements) {
recomp_func_t* to_replace = recomp::overlays::get_func_by_section_ram(replacement.original_section_vrom, replacement.original_vram);
if (to_replace == nullptr) {
std::stringstream error_param_stream{};
error_param_stream << std::hex <<
"section: 0x" << replacement.original_section_vrom <<
" func: 0x" << std::setfill('0') << std::setw(8) << replacement.original_vram;
error_param = error_param_stream.str();
return ModLoadError::InvalidFunctionReplacement;
}
// Check if this function has already been replaced.
auto find_patch_it = patched_funcs.find(to_replace);
if (find_patch_it != patched_funcs.end()) {
error_param = find_patch_it->second.mod_id;
return ModLoadError::ModConflict;
}
// Copy the original bytes so they can be restored later after the mod is unloaded.
PatchData& cur_replacement_data = patched_funcs[to_replace];
memcpy(cur_replacement_data.replaced_bytes.data(), to_replace, cur_replacement_data.replaced_bytes.size());
cur_replacement_data.mod_id = mod.manifest.mod_id;
// Patch the function to redirect it to the replacement.
patch_func(to_replace, mod.code_handle->get_function_handle(replacement.func_index));
}
// TODO perform mips32 relocations
// TODO hook up callbacks
return ModLoadError::Good;
}
void recomp::mods::ModContext::unload_mods() { void recomp::mods::ModContext::unload_mods() {
for (auto& [replacement_func, replacement_data] : patched_funcs) { for (auto& [replacement_func, replacement_data] : patched_funcs) {
unpatch_func(replacement_func, replacement_data); unpatch_func(replacement_func, replacement_data);
} }
patched_funcs.clear(); patched_funcs.clear();
loaded_mods_by_id.clear();
} }