Update N64Recomp for new modding infrastructure, begin adding support for offline compiled mods

This commit is contained in:
Mr-Wiseguy 2024-08-28 01:12:18 -04:00
parent 7e063bfdae
commit 60a25faebe
7 changed files with 300 additions and 103 deletions

@ -1 +1 @@
Subproject commit a9f9b0ddc4c7f860017d247c02d8507ca301352f Subproject commit 131157dad85953c3f484cea5a67c61c49f38d0b3

View file

@ -48,8 +48,8 @@ if (WIN32)
add_compile_definitions(NOMINMAX) add_compile_definitions(NOMINMAX)
endif() endif()
add_subdirectory(${PROJECT_SOURCE_DIR}/../thirdparty/miniz ${CMAKE_BINARY_DIR}/miniz) add_subdirectory(${PROJECT_SOURCE_DIR}/../thirdparty/miniz ${CMAKE_CURRENT_BINARY_DIR}/miniz)
add_subdirectory(${PROJECT_SOURCE_DIR}/../N64Recomp ${CMAKE_BINARY_DIR}/N64Recomp EXCLUDE_FROM_ALL) add_subdirectory(${PROJECT_SOURCE_DIR}/../N64Recomp ${CMAKE_CURRENT_BINARY_DIR}/N64Recomp EXCLUDE_FROM_ALL)
target_link_libraries(librecomp PRIVATE ultramodern N64Recomp) target_link_libraries(librecomp PRIVATE ultramodern N64Recomp)
target_link_libraries(librecomp PUBLIC miniz) target_link_libraries(librecomp PUBLIC miniz)

View file

@ -19,6 +19,7 @@
#include "miniz_zip.h" #include "miniz_zip.h"
#include "librecomp/recomp.h" #include "librecomp/recomp.h"
#include "librecomp/sections.h"
namespace recomp { namespace recomp {
namespace mods { namespace mods {
@ -44,6 +45,9 @@ namespace recomp {
Good, Good,
FailedToLoadSyms, FailedToLoadSyms,
FailedToLoadBinary, FailedToLoadBinary,
FailedToLoadNativeCode,
InvalidReferenceSymbol,
InvalidImport,
InvalidFunctionReplacement, InvalidFunctionReplacement,
FailedToFindReplacement, FailedToFindReplacement,
ReplacementConflict, ReplacementConflict,
@ -129,6 +133,7 @@ namespace recomp {
ModContext(); ModContext();
~ModContext(); ~ModContext();
void setup_sections();
std::vector<ModOpenErrorDetails> scan_mod_folder(const std::filesystem::path& mod_folder); std::vector<ModOpenErrorDetails> scan_mod_folder(const std::filesystem::path& mod_folder);
void enable_mod(const std::string& mod_id, bool enabled); void enable_mod(const std::string& mod_id, bool enabled);
bool is_mod_enabled(const std::string& mod_id); bool is_mod_enabled(const std::string& mod_id);
@ -138,12 +143,14 @@ namespace recomp {
// const ModManifest& get_mod_manifest(size_t mod_index); // const ModManifest& get_mod_manifest(size_t mod_index);
private: private:
ModOpenError open_mod(const std::filesystem::path& mod_path, std::string& error_param); ModOpenError open_mod(const std::filesystem::path& mod_path, std::string& error_param);
ModLoadError load_mod(uint8_t* rdram, const std::unordered_map<uint32_t, uint16_t>& section_map, recomp::mods::ModHandle& handle, int32_t load_address, uint32_t& ram_used, std::string& error_param);
void add_opened_mod(ModManifest&& manifest); void add_opened_mod(ModManifest&& manifest);
std::vector<ModHandle> opened_mods; std::vector<ModHandle> opened_mods;
std::unordered_set<std::string> mod_ids; std::unordered_set<std::string> mod_ids;
std::unordered_set<std::string> enabled_mods; std::unordered_set<std::string> enabled_mods;
std::unordered_map<recomp_func_t*, PatchData> patched_funcs; std::unordered_map<recomp_func_t*, PatchData> patched_funcs;
std::unordered_map<uint32_t, size_t> sections_by_vrom;
}; };
} }
}; };

View file

@ -3,6 +3,8 @@
#include <cstdint> #include <cstdint>
#include <cstddef> #include <cstddef>
#include <string>
#include <unordered_map>
#include "sections.h" #include "sections.h"
namespace recomp { namespace recomp {
@ -21,10 +23,13 @@ namespace recomp {
void register_overlays(const overlay_section_table_data_t& sections, const overlays_by_index_t& overlays); void register_overlays(const overlay_section_table_data_t& sections, const overlays_by_index_t& overlays);
void register_patches(const char* patch_data, size_t patch_size, SectionTableEntry* code_sections, size_t num_sections); void register_patches(const char* patch_data, size_t patch_size, SectionTableEntry* code_sections, size_t num_sections);
void register_base_exports(const FunctionExport* exports);
void read_patch_data(uint8_t* rdram, gpr patch_data_address); void read_patch_data(uint8_t* rdram, gpr patch_data_address);
void init_overlays(); void init_overlays();
const std::unordered_map<uint32_t, uint16_t>& get_vrom_to_section_map();
recomp_func_t* get_func_by_section_ram(uint32_t section_rom, uint32_t function_vram); recomp_func_t* get_func_by_section_ram(uint32_t section_rom, uint32_t function_vram);
recomp_func_t* get_base_export(const std::string& export_name);
} }
}; };

View file

@ -20,4 +20,9 @@ typedef struct {
size_t index; size_t index;
} SectionTableEntry; } SectionTableEntry;
typedef struct {
const char* name;
uint32_t ram_addr;
} FunctionExport;
#endif #endif

View file

@ -1,5 +1,6 @@
#include <span> #include <span>
#include <fstream> #include <fstream>
#include <variant>
#define WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN
#include <Windows.h> #include <Windows.h>
@ -14,11 +15,130 @@
#define PATHFMT "%s" #define PATHFMT "%s"
#endif #endif
template<class... Ts>
struct overloaded : Ts... { using Ts::operator()...; };
template<class... Ts>
overloaded(Ts...) -> overloaded<Ts...>;
namespace recomp {
namespace mods {
using ModFunction = std::variant<recomp_func_t*>;
class ModCodeHandle {
public:
virtual ~ModCodeHandle() {}
virtual bool good() = 0;
virtual void set_imported_function_pointer(size_t import_index, recomp_func_t* ptr) = 0;
virtual void set_reference_symbol_pointer(size_t symbol_index, recomp_func_t* ptr) = 0;
virtual void set_event_index(size_t local_event_index, uint32_t global_event_index) = 0;
virtual void set_recomp_trigger_event_pointer(void (*ptr)(uint8_t* rdram, recomp_context* ctx, uint32_t index)) = 0;
virtual void set_get_function_pointer(recomp_func_t* (*ptr)(int32_t)) = 0;
virtual void set_reference_section_addresses_pointer(int32_t* ptr) = 0;
virtual void set_local_section_address(size_t section_index, int32_t address) = 0;
virtual ModFunction get_function_handle(size_t func_index) = 0;
};
class NativeCodeHandle : public ModCodeHandle {
public:
NativeCodeHandle(const std::filesystem::path& dll_path, const N64Recomp::Context& context) {
// Load the DLL.
mod_dll = LoadLibraryW(dll_path.c_str());
if (mod_dll == nullptr) {
set_bad();
return;
}
// Fill out the list of function pointers.
functions.resize(context.functions.size());
for (size_t i = 0; i < functions.size(); i++) {
std::string func_name = "mod_func_" + std::to_string(i);
functions[i] = (recomp_func_t*)GetProcAddress(mod_dll, func_name.c_str());
if (functions[i] == nullptr) {
set_bad();
return;
}
}
// Get the standard exported symbols.
get_dll_func(imported_funcs, "imported_funcs");
get_dll_func(reference_symbol_funcs, "reference_symbol_funcs");
get_dll_func(event_indices, "event_indices");
get_dll_func(recomp_trigger_event, "recomp_trigger_event");
get_dll_func(get_function, "get_function");
get_dll_func(reference_section_addresses, "reference_section_addresses");
get_dll_func(section_addresses, "section_addresses");
}
~NativeCodeHandle() = default;
bool good() final {
return mod_dll != nullptr;
}
void set_imported_function_pointer(size_t import_index, recomp_func_t* ptr) final {
imported_funcs[import_index] = ptr;
}
void set_reference_symbol_pointer(size_t symbol_index, recomp_func_t* ptr) final {
reference_symbol_funcs[symbol_index] = ptr;
};
void set_event_index(size_t local_event_index, uint32_t global_event_index) final {
event_indices[local_event_index] = global_event_index;
};
void set_recomp_trigger_event_pointer(void (*ptr)(uint8_t* rdram, recomp_context* ctx, uint32_t index)) final {
*recomp_trigger_event = ptr;
};
void set_get_function_pointer(recomp_func_t* (*ptr)(int32_t)) final {
*get_function = ptr;
};
void set_reference_section_addresses_pointer(int32_t* ptr) final {
*reference_section_addresses = ptr;
};
void set_local_section_address(size_t section_index, int32_t address) final {
section_addresses[section_index] = address;
};
ModFunction get_function_handle(size_t func_index) final {
return ModFunction{ functions[func_index] };
}
private:
template <typename T>
void get_dll_func(T& out, const char* name) {
out = (T)GetProcAddress(mod_dll, name);
if (out == nullptr) {
set_bad();
}
};
void set_bad() {
if (mod_dll) {
FreeLibrary(mod_dll);
}
mod_dll = nullptr;
}
HMODULE mod_dll;
std::vector<recomp_func_t*> functions;
recomp_func_t** imported_funcs;
recomp_func_t** reference_symbol_funcs;
uint32_t* event_indices;
void (**recomp_trigger_event)(uint8_t* rdram, recomp_context* ctx, uint32_t index);
recomp_func_t* (**get_function)(int32_t vram);
int32_t** reference_section_addresses;
int32_t* section_addresses;
};
struct ModHandle {
ModManifest manifest;
std::unique_ptr<ModCodeHandle> code_handle;
N64Recomp::Context recompiler_context;
ModHandle(ModManifest&& manifest) :
manifest(std::move(manifest)), code_handle(), recompiler_context{} {
}
};
}
}
void unprotect(void* target_func, DWORD* old_flags) { void unprotect(void* target_func, DWORD* old_flags) {
BOOL result = VirtualProtect(target_func, BOOL result = VirtualProtect(target_func,
16, 16,
PAGE_READWRITE, PAGE_READWRITE,
old_flags); old_flags);
(void)result;
} }
void protect(void* target_func, DWORD old_flags) { void protect(void* target_func, DWORD old_flags) {
@ -27,25 +147,32 @@ void protect(void* target_func, DWORD old_flags) {
16, 16,
old_flags, old_flags,
&dummy_old_flags); &dummy_old_flags);
(void)result;
} }
void patch_func(void* target_func, void* replacement_func) { void patch_func(recomp_func_t* target_func, recomp::mods::ModFunction replacement_func) {
static uint8_t movabs_rax[] = {0x48, 0xB8}; static const uint8_t movabs_rax[] = {0x48, 0xB8};
static uint8_t jmp_rax[] = {0xFF, 0xE0}; static const uint8_t jmp_rax[] = {0xFF, 0xE0};
uint8_t* target_func_u8 = reinterpret_cast<uint8_t*>(target_func); uint8_t* target_func_u8 = reinterpret_cast<uint8_t*>(target_func);
size_t offset = 0; size_t offset = 0;
auto write_bytes = [&](void* bytes, size_t count) { auto write_bytes = [&](const void* bytes, size_t count) {
memcpy(target_func_u8 + offset, bytes, count); memcpy(target_func_u8 + offset, bytes, count);
offset += count; offset += count;
}; };
DWORD old_flags; DWORD old_flags;
unprotect(target_func, &old_flags); unprotect(target_func_u8, &old_flags);
write_bytes(movabs_rax, sizeof(movabs_rax));
write_bytes(&replacement_func, sizeof(&replacement_func)); std::visit(overloaded {
write_bytes(jmp_rax, sizeof(jmp_rax)); [&write_bytes](recomp_func_t* native_func) {
protect(target_func, old_flags); write_bytes(movabs_rax, sizeof(movabs_rax));
write_bytes(&native_func, sizeof(&native_func));
write_bytes(jmp_rax, sizeof(jmp_rax));
}
}, replacement_func);
protect(target_func_u8, old_flags);
} }
void unpatch_func(void* target_func, const recomp::mods::PatchData& data) { void unpatch_func(void* target_func, const recomp::mods::PatchData& data) {
@ -55,25 +182,11 @@ void unpatch_func(void* target_func, const recomp::mods::PatchData& data) {
protect(target_func, old_flags); protect(target_func, old_flags);
} }
namespace recomp {
namespace mods {
struct ModHandle {
ModManifest manifest;
N64Recomp::Context recompiler_context;
N64Recomp::ModContext recompiler_mod_context;
// TODO temporary solution for loading mod DLLs, replace with LuaJIT recompilation (including patching LO16/HI16 relocs).
HMODULE mod_dll;
ModHandle(ModManifest&& manifest) : manifest(std::move(manifest)), recompiler_context{}, recompiler_mod_context{} {}
};
}
}
void recomp::mods::ModContext::add_opened_mod(ModManifest&& manifest) { void recomp::mods::ModContext::add_opened_mod(ModManifest&& manifest) {
opened_mods.emplace_back(std::move(manifest)); opened_mods.emplace_back(std::move(manifest));
} }
recomp::mods::ModLoadError load_mod(uint8_t* rdram, recomp::mods::ModHandle& handle, int32_t load_address, uint32_t& ram_used, std::string& error_param, std::unordered_map<recomp_func_t*, recomp::mods::PatchData>& patched_funcs) { recomp::mods::ModLoadError recomp::mods::ModContext::load_mod(uint8_t* rdram, const std::unordered_map<uint32_t, uint16_t>& section_vrom_map, recomp::mods::ModHandle& handle, int32_t load_address, uint32_t& ram_used, std::string& error_param) {
using namespace recomp::mods; using namespace recomp::mods;
std::vector<int32_t> section_load_addresses{}; std::vector<int32_t> section_load_addresses{};
@ -97,7 +210,7 @@ recomp::mods::ModLoadError load_mod(uint8_t* rdram, recomp::mods::ModHandle& han
std::span<uint8_t> binary_span {reinterpret_cast<uint8_t*>(binary_data.data()), binary_data.size() }; std::span<uint8_t> binary_span {reinterpret_cast<uint8_t*>(binary_data.data()), binary_data.size() };
// Parse the symbol file into the recompiler contexts. // Parse the symbol file into the recompiler contexts.
N64Recomp::ModSymbolsError symbol_load_error = N64Recomp::parse_mod_symbols(syms_data, binary_span, {}, handle.recompiler_context, handle.recompiler_mod_context); N64Recomp::ModSymbolsError symbol_load_error = N64Recomp::parse_mod_symbols(syms_data, binary_span, section_vrom_map, handle.recompiler_context);
if (symbol_load_error != N64Recomp::ModSymbolsError::Good) { if (symbol_load_error != N64Recomp::ModSymbolsError::Good) {
return ModLoadError::FailedToLoadSyms; return ModLoadError::FailedToLoadSyms;
} }
@ -118,64 +231,94 @@ recomp::mods::ModLoadError load_mod(uint8_t* rdram, recomp::mods::ModHandle& han
ram_used = cur_section_addr - load_address; ram_used = cur_section_addr - load_address;
} }
// TODO temporary solution for loading mod DLLs, replace with LuaJIT recompilation (including patching LO16/HI16 relocs). // TODO implement LuaJIT recompilation and allow it instead of native code loading via a mod manifest flag.
// N64Recomp::recompile_function(...);
std::filesystem::path dll_path = handle.manifest.mod_root_path; std::filesystem::path dll_path = handle.manifest.mod_root_path;
dll_path.replace_extension(".dll"); dll_path.replace_extension(".dll");
handle.mod_dll = LoadLibraryW(dll_path.c_str()); handle.code_handle = std::make_unique<NativeCodeHandle>(dll_path, handle.recompiler_context);
if (!handle.code_handle->good()) {
if (!handle.mod_dll) {
printf("Failed to open mod dll: %ls\n", dll_path.c_str()); printf("Failed to open mod dll: %ls\n", dll_path.c_str());
return ModLoadError::Good; handle.code_handle.reset();
return ModLoadError::FailedToLoadNativeCode;
} }
// TODO track replacements by mod to find conflicts // TODO dependency resolution
uint32_t total_func_count = 0;
for (size_t section_index = 0; section_index < handle.recompiler_context.sections.size(); section_index++) {
const auto& section = handle.recompiler_context.sections[section_index];
const auto& mod_section = handle.recompiler_mod_context.section_info[section_index];
// TODO check that section original_vrom is nonzero if it has replacements.
for (const auto& replacement : mod_section.replacements) {
recomp_func_t* to_replace = recomp::overlays::get_func_by_section_ram(mod_section.original_rom_addr, replacement.original_vram);
if (to_replace == nullptr) { // TODO imported_funcs from other mods
std::stringstream error_param_stream{}; for (size_t import_index = 0; import_index < handle.recompiler_context.import_symbols.size(); import_index++) {
error_param_stream << std::hex << const N64Recomp::ImportSymbol& imported_func = handle.recompiler_context.import_symbols[import_index];
"section: 0x" << mod_section.original_rom_addr << const N64Recomp::Dependency& dependency = handle.recompiler_context.dependencies[imported_func.dependency_index];
" func: 0x" << std::setfill('0') << std::setw(8) << replacement.original_vram;
error_param = error_param_stream.str();
return ModLoadError::InvalidFunctionReplacement;
}
uint32_t section_func_index = replacement.func_index; recomp_func_t* found_func = nullptr;
// TODO temporary solution for loading mod DLLs, replace with LuaJIT recompilation. if (dependency.mod_id == "*") {
std::string section_func_name = "mod_func_" + std::to_string(total_func_count + section_func_index); found_func = recomp::overlays::get_base_export(imported_func.base.name);
void* replacement_func = GetProcAddress(handle.mod_dll, section_func_name.c_str());
if (!replacement_func) {
printf("Failed to find func in dll: %s\n", section_func_name.c_str());
return ModLoadError::FailedToFindReplacement;
}
printf("found replacement func: 0x%016llX\n", (uintptr_t)to_replace);
// Check if this function has already been replaced.
auto find_patch_it = patched_funcs.find(to_replace);
if (find_patch_it != patched_funcs.end()) {
error_param = find_patch_it->second.mod_id;
return ModLoadError::ModConflict;
}
// Copy the original bytes so they can be restored later after the mod is unloaded.
PatchData& cur_replacement_data = patched_funcs[to_replace];
memcpy(cur_replacement_data.replaced_bytes.data(), to_replace, cur_replacement_data.replaced_bytes.size());
cur_replacement_data.mod_id = handle.manifest.mod_id;
// Patch the function to redirect it to the replacement.
patch_func(to_replace, replacement_func);
} }
total_func_count += mod_section.replacements.size();
if (found_func == nullptr) {
error_param = dependency.mod_id + ":" + imported_func.base.name;
return ModLoadError::InvalidImport;
}
handle.code_handle->set_imported_function_pointer(import_index, found_func);
}
for (size_t reference_sym_index = 0; reference_sym_index < handle.recompiler_context.num_regular_reference_symbols(); reference_sym_index++) {
const N64Recomp::ReferenceSymbol& reference_sym = handle.recompiler_context.get_regular_reference_symbol(reference_sym_index);
uint32_t reference_section_vrom = handle.recompiler_context.get_reference_section_rom(reference_sym.section_index);
uint32_t reference_section_vram = handle.recompiler_context.get_reference_section_vram(reference_sym.section_index);
uint32_t reference_symbol_vram = reference_section_vram + reference_sym.section_offset;
recomp_func_t* found_func = recomp::overlays::get_func_by_section_ram(reference_section_vrom, reference_symbol_vram);
if (found_func == nullptr) {
std::stringstream error_param_stream{};
error_param_stream << std::hex <<
"section: 0x" << reference_section_vrom <<
" func: 0x" << std::setfill('0') << std::setw(8) << reference_symbol_vram;
error_param = error_param_stream.str();
return ModLoadError::InvalidReferenceSymbol;
}
handle.code_handle->set_reference_symbol_pointer(reference_sym_index, found_func);
}
// TODO event_indices
// TODO recomp_trigger_event
handle.code_handle->set_get_function_pointer(get_function);
handle.code_handle->set_reference_section_addresses_pointer(section_addresses);
for (size_t section_index = 0; section_index < section_load_addresses.size(); section_index++) {
handle.code_handle->set_local_section_address(section_index, section_load_addresses[section_index]);
}
// Apply all the function replacements in the mod.
for (const auto& replacement : handle.recompiler_context.replacements) {
recomp_func_t* to_replace = recomp::overlays::get_func_by_section_ram(replacement.original_section_vrom, replacement.original_vram);
if (to_replace == nullptr) {
std::stringstream error_param_stream{};
error_param_stream << std::hex <<
"section: 0x" << replacement.original_section_vrom <<
" func: 0x" << std::setfill('0') << std::setw(8) << replacement.original_vram;
error_param = error_param_stream.str();
return ModLoadError::InvalidFunctionReplacement;
}
// Check if this function has already been replaced.
auto find_patch_it = patched_funcs.find(to_replace);
if (find_patch_it != patched_funcs.end()) {
error_param = find_patch_it->second.mod_id;
return ModLoadError::ModConflict;
}
// Copy the original bytes so they can be restored later after the mod is unloaded.
PatchData& cur_replacement_data = patched_funcs[to_replace];
memcpy(cur_replacement_data.replaced_bytes.data(), to_replace, cur_replacement_data.replaced_bytes.size());
cur_replacement_data.mod_id = handle.manifest.mod_id;
// Patch the function to redirect it to the replacement.
patch_func(to_replace, handle.code_handle->get_function_handle(replacement.func_index));
} }
// TODO perform mips32 relocations // TODO perform mips32 relocations
@ -234,12 +377,14 @@ std::vector<recomp::mods::ModLoadErrorDetails> recomp::mods::ModContext::load_mo
return {}; return {};
} }
const std::unordered_map<uint32_t, uint16_t>& section_vrom_map = recomp::overlays::get_vrom_to_section_map();
for (auto& mod : opened_mods) { for (auto& mod : opened_mods) {
if (enabled_mods.contains(mod.manifest.mod_id)) { if (enabled_mods.contains(mod.manifest.mod_id)) {
printf("Loading mod %s\n", mod.manifest.mod_id.c_str()); printf("Loading mod %s\n", mod.manifest.mod_id.c_str());
uint32_t cur_ram_used = 0; uint32_t cur_ram_used = 0;
std::string load_error_param; std::string load_error_param;
ModLoadError load_error = load_mod(rdram, mod, load_address, cur_ram_used, load_error_param, patched_funcs); ModLoadError load_error = load_mod(rdram, section_vrom_map, mod, load_address, cur_ram_used, load_error_param);
if (load_error != ModLoadError::Good) { if (load_error != ModLoadError::Good) {
ret.emplace_back(mod.manifest.mod_id, load_error, load_error_param); ret.emplace_back(mod.manifest.mod_id, load_error, load_error_param);

View file

@ -18,19 +18,6 @@ static SectionTableEntry* patch_code_sections = nullptr;
size_t num_patch_code_sections = 0; size_t num_patch_code_sections = 0;
static std::vector<char> patch_data; static std::vector<char> patch_data;
void recomp::overlays::register_overlays(const overlay_section_table_data_t& sections, const overlays_by_index_t& overlays) {
sections_info = sections;
overlays_info = overlays;
}
void recomp::overlays::register_patches(const char* patch, std::size_t size, SectionTableEntry* sections, size_t num_sections) {
patch_code_sections = sections;
num_patch_code_sections = num_sections;
patch_data.resize(size);
std::memcpy(patch_data.data(), patch, size);
}
struct LoadedSection { struct LoadedSection {
int32_t loaded_ram_addr; int32_t loaded_ram_addr;
size_t section_table_index; size_t section_table_index;
@ -45,8 +32,62 @@ struct LoadedSection {
} }
}; };
std::vector<LoadedSection> loaded_sections{}; static std::unordered_map<uint32_t, uint16_t> code_sections_by_rom{};
std::unordered_map<int32_t, recomp_func_t*> func_map{}; static std::vector<LoadedSection> loaded_sections{};
static std::unordered_map<int32_t, recomp_func_t*> func_map{};
static std::unordered_map<std::string, recomp_func_t*> base_exports{};
extern "C" {
int32_t* section_addresses = nullptr;
}
void recomp::overlays::register_overlays(const overlay_section_table_data_t& sections, const overlays_by_index_t& overlays) {
sections_info = sections;
overlays_info = overlays;
}
void recomp::overlays::register_patches(const char* patch, std::size_t size, SectionTableEntry* sections, size_t num_sections) {
patch_code_sections = sections;
num_patch_code_sections = num_sections;
patch_data.resize(size);
std::memcpy(patch_data.data(), patch, size);
}
void recomp::overlays::register_base_exports(const FunctionExport* export_list) {
std::unordered_map<uint32_t, recomp_func_t*> patch_func_vram_map{};
// Iterate over all patch functions to set up a mapping of their vram address.
for (size_t patch_section_index = 0; patch_section_index < num_patch_code_sections; patch_section_index++) {
const SectionTableEntry* cur_section = &patch_code_sections[patch_section_index];
for (size_t func_index = 0; func_index < cur_section->num_funcs; func_index++) {
const FuncEntry* cur_func = &cur_section->funcs[func_index];
patch_func_vram_map.emplace(cur_section->ram_addr + cur_func->offset, cur_func->func);
}
}
// Iterate over exports, using the vram mapping to create a name mapping.
for (const FunctionExport* cur_export = &export_list[0]; cur_export->name != nullptr; cur_export++) {
auto it = patch_func_vram_map.find(cur_export->ram_addr);
if (it == patch_func_vram_map.end()) {
assert(false && "Failed to find exported function in patch function sections!");
}
base_exports.emplace(cur_export->name, it->second);
}
}
recomp_func_t* recomp::overlays::get_base_export(const std::string& export_name) {
auto it = base_exports.find(export_name);
if (it == base_exports.end()) {
return nullptr;
}
return it->second;
}
const std::unordered_map<uint32_t, uint16_t>& recomp::overlays::get_vrom_to_section_map() {
return code_sections_by_rom;
}
void load_overlay(size_t section_table_index, int32_t ram) { void load_overlay(size_t section_table_index, int32_t ram) {
const SectionTableEntry& section = sections_info.code_sections[section_table_index]; const SectionTableEntry& section = sections_info.code_sections[section_table_index];
@ -83,10 +124,6 @@ void recomp::overlays::read_patch_data(uint8_t* rdram, gpr patch_data_address) {
} }
} }
extern "C" {
int32_t* section_addresses = nullptr;
}
extern "C" void load_overlays(uint32_t rom, int32_t ram_addr, uint32_t size) { extern "C" void load_overlays(uint32_t rom, int32_t ram_addr, uint32_t size) {
// Search for the first section that's included in the loaded rom range // Search for the first section that's included in the loaded rom range
// Sections were sorted by `init_overlays` so we can use the bounds functions // Sections were sorted by `init_overlays` so we can use the bounds functions
@ -173,8 +210,6 @@ extern "C" void unload_overlays(int32_t ram_addr, uint32_t size) {
} }
} }
std::unordered_map<uint32_t, SectionTableEntry*> sections_by_rom{};
void recomp::overlays::init_overlays() { void recomp::overlays::init_overlays() {
section_addresses = (int32_t *)calloc(sections_info.total_num_sections, sizeof(int32_t)); section_addresses = (int32_t *)calloc(sections_info.total_num_sections, sizeof(int32_t));
@ -189,19 +224,19 @@ void recomp::overlays::init_overlays() {
SectionTableEntry* code_section = &sections_info.code_sections[section_index]; SectionTableEntry* code_section = &sections_info.code_sections[section_index];
section_addresses[sections_info.code_sections[section_index].index] = code_section->ram_addr; section_addresses[sections_info.code_sections[section_index].index] = code_section->ram_addr;
sections_by_rom[code_section->rom_addr] = code_section; code_sections_by_rom[code_section->rom_addr] = section_index;
} }
load_patch_functions(); load_patch_functions();
} }
recomp_func_t* recomp::overlays::get_func_by_section_ram(uint32_t section_rom, uint32_t function_vram) { recomp_func_t* recomp::overlays::get_func_by_section_ram(uint32_t section_rom, uint32_t function_vram) {
auto find_section_it = sections_by_rom.find(section_rom); auto find_section_it = code_sections_by_rom.find(section_rom);
if (find_section_it == sections_by_rom.end()) { if (find_section_it == code_sections_by_rom.end()) {
return nullptr; return nullptr;
} }
SectionTableEntry* section = find_section_it->second; SectionTableEntry* section = &sections_info.code_sections[find_section_it->second];
if (function_vram < section->ram_addr || function_vram >= section->ram_addr + section->size) { if (function_vram < section->ram_addr || function_vram >= section->ram_addr + section->size) {
return nullptr; return nullptr;
} }