Implement mod loading on posix systems

This commit is contained in:
Mr-Wiseguy 2024-09-08 22:43:03 -04:00
parent b9592c625d
commit a7e5a7770f
5 changed files with 90 additions and 13 deletions

@ -1 +1 @@
Subproject commit b8dcb21dec80048b2530b4258b57a87a37343008
Subproject commit e098b0e018d8b5b5657a4cea0018652d00e96eff

View file

@ -83,7 +83,7 @@ namespace recomp {
virtual bool file_exists(const std::string& filepath) const = 0;
};
struct ZipModFileHandle : public ModFileHandle {
struct ZipModFileHandle final : public ModFileHandle {
FILE* file_handle = nullptr;
std::unique_ptr<mz_zip_archive> archive;
@ -95,7 +95,7 @@ namespace recomp {
bool file_exists(const std::string& filepath) const final;
};
struct LooseModFileHandle : public ModFileHandle {
struct LooseModFileHandle final : public ModFileHandle {
std::filesystem::path root_path;
LooseModFileHandle() = default;
@ -286,7 +286,7 @@ namespace recomp {
}
private:
void set_bad();
bool is_good;
bool is_good = false;
std::unique_ptr<DynamicLibrary> dynamic_lib;
std::vector<recomp_func_t*> functions;
recomp_func_t** imported_funcs;

View file

@ -24,7 +24,7 @@ recomp::mods::ZipModFileHandle::ZipModFileHandle(const std::filesystem::path& mo
return;
}
#else
file_handle = fopen(mod_path.c_str(), L"rb");
file_handle = fopen(mod_path.c_str(), "rb");
if (!file_handle) {
error = ModOpenError::FileError;
return;
@ -473,6 +473,8 @@ std::string recomp::mods::error_to_string(ModOpenError error) {
return "Invalid version string in manifest.json";
case ModOpenError::InvalidMinimumRecompVersionString:
return "Invalid minimum recomp version string in manifest.json";
case ModOpenError::InvalidDependencyString:
return "Invalid dependency string in manifest.json";
case ModOpenError::MissingManifestField:
return "Missing required field in manifest";
case ModOpenError::DuplicateMod:
@ -498,9 +500,9 @@ std::string recomp::mods::error_to_string(ModLoadError error) {
case ModLoadError::FailedToParseSyms:
return "Failed to parse mod symbol file";
case ModLoadError::FailedToLoadNativeCode:
return "Failed to load mod code DLL";
return "Failed to load offline mod library";
case ModLoadError::FailedToLoadNativeLibrary:
return "Failed to load mod library DLL";
return "Failed to load mod library";
case ModLoadError::FailedToFindNativeExport:
return "Failed to find native export";
case ModLoadError::InvalidReferenceSymbol:

View file

@ -1,7 +1,5 @@
#include <span>
#include <fstream>
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#include "librecomp/mods.hpp"
#include "librecomp/overlays.hpp"
@ -95,7 +93,82 @@ void protect(void* target_func, uint64_t old_flags) {
(void)result;
}
#else
# error "Mods not implemented yet on this platform"
# include <dlfcn.h>
# include <sys/mman.h>
class recomp::mods::DynamicLibrary {
public:
static constexpr std::string_view PlatformExtension = ".so";
DynamicLibrary() = default;
DynamicLibrary(const std::filesystem::path& path) {
native_handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL);
if (good()) {
uint32_t* recomp_api_version;
if (get_dll_symbol(recomp_api_version, "recomp_api_version")) {
api_version = *recomp_api_version;
}
else {
api_version = (uint32_t)-1;
}
}
}
~DynamicLibrary() {
unload();
}
DynamicLibrary(const DynamicLibrary&) = delete;
DynamicLibrary& operator=(const DynamicLibrary&) = delete;
DynamicLibrary(DynamicLibrary&&) = delete;
DynamicLibrary& operator=(DynamicLibrary&&) = delete;
void unload() {
if (native_handle != nullptr) {
dlclose(native_handle);
}
native_handle = nullptr;
}
bool good() const {
return native_handle != nullptr;
}
template <typename T>
bool get_dll_symbol(T& out, const char* name) const {
out = (T)dlsym(native_handle, name);
if (out == nullptr) {
return false;
}
return true;
};
uint32_t get_api_version() {
return api_version;
}
private:
void* native_handle;
uint32_t api_version;
};
void unprotect(void* target_func, uint64_t* old_flags) {
// Align the address to a page boundary.
uintptr_t page_start = (uintptr_t)target_func;
int page_size = getpagesize();
page_start = (page_start / page_size) * page_size;
int result = mprotect((void*)page_start, page_size, PROT_READ | PROT_WRITE);
*old_flags = 0;
(void)result;
}
void protect(void* target_func, uint64_t old_flags) {
// Align the address to a page boundary.
uintptr_t page_start = (uintptr_t)target_func;
int page_size = getpagesize();
page_start = (page_start / page_size) * page_size;
int result = mprotect((void*)page_start, page_size, PROT_READ | PROT_EXEC);
(void)result;
}
#endif
namespace modpaths {
@ -107,7 +180,7 @@ recomp::mods::ModLoadError recomp::mods::validate_api_version(uint32_t api_versi
switch (api_version) {
case 1:
return ModLoadError::Good;
case (size_t)-1:
case (uint32_t)-1:
return ModLoadError::NoSpecifiedApiVersion;
default:
error_param = std::to_string(api_version);
@ -226,6 +299,7 @@ bool recomp::mods::ModHandle::get_global_event_index(const std::string& event_na
}
recomp::mods::NativeCodeHandle::NativeCodeHandle(const std::filesystem::path& dll_path, const N64Recomp::Context& context) {
is_good = true;
// Load the DLL.
dynamic_lib = std::make_unique<DynamicLibrary>(dll_path);
if (!dynamic_lib->good()) {
@ -768,7 +842,7 @@ recomp::mods::ModLoadError recomp::mods::ModContext::resolve_dependencies(recomp
// Copy the original bytes so they can be restored later after the mod is unloaded.
PatchData& cur_replacement_data = patched_funcs[to_replace];
memcpy(cur_replacement_data.replaced_bytes.data(), to_replace, cur_replacement_data.replaced_bytes.size());
memcpy(cur_replacement_data.replaced_bytes.data(), reinterpret_cast<void*>(to_replace), cur_replacement_data.replaced_bytes.size());
cur_replacement_data.mod_id = mod.manifest.mod_id;
// Patch the function to redirect it to the replacement.
@ -782,7 +856,7 @@ recomp::mods::ModLoadError recomp::mods::ModContext::resolve_dependencies(recomp
void recomp::mods::ModContext::unload_mods() {
for (auto& [replacement_func, replacement_data] : patched_funcs) {
unpatch_func(replacement_func, replacement_data);
unpatch_func(reinterpret_cast<void*>(replacement_func), replacement_data);
}
patched_funcs.clear();
loaded_mods_by_id.clear();

View file

@ -12,6 +12,7 @@
#include <array>
#include <cinttypes>
#include <cuchar>
#include <charconv>
#include "librecomp/recomp.h"
#include "librecomp/overlays.hpp"