Add support for exporting functions from native libraries in mods

This commit is contained in:
Mr-Wiseguy 2024-09-01 20:05:32 -04:00
parent 073bd0e4c2
commit 37ce5c2c1c
3 changed files with 265 additions and 91 deletions

View file

@ -51,6 +51,8 @@ namespace recomp {
FailedToLoadSyms,
FailedToLoadBinary,
FailedToLoadNativeCode,
FailedToLoadNativeLibrary,
FailedToFindNativeExport,
InvalidReferenceSymbol,
InvalidImport,
InvalidCallbackEvent,
@ -60,6 +62,9 @@ namespace recomp {
MissingDependency,
WrongDependencyVersion,
ModConflict,
DuplicateExport,
NoSpecifiedApiVersion,
UnsupportedApiVersion,
};
std::string error_to_string(ModLoadError);
@ -93,6 +98,11 @@ namespace recomp {
bool file_exists(const std::string& filepath) const final;
};
struct NativeLibraryManifest {
std::string name;
std::vector<std::string> exports;
};
struct ModManifest {
std::filesystem::path mod_root_path;
@ -107,6 +117,7 @@ namespace recomp {
std::string binary_syms_path;
std::string rom_patch_path;
std::string rom_patch_syms_path;
std::vector<NativeLibraryManifest> native_libraries;
std::unique_ptr<ModFileHandle> file_handle;
};
@ -170,6 +181,7 @@ namespace recomp {
public:
virtual ~ModCodeHandle() {}
virtual bool good() = 0;
virtual uint32_t get_api_version() = 0;
virtual void set_imported_function(size_t import_index, GenericFunction func) = 0;
virtual void set_reference_symbol_pointer(size_t symbol_index, recomp_func_t* ptr) = 0;
virtual void set_base_event_index(uint32_t global_event_index) = 0;
@ -181,6 +193,7 @@ namespace recomp {
virtual GenericFunction get_function_handle(size_t func_index) = 0;
};
class DynamicLibrary;
class ModHandle {
public:
// TODO make these private and expose methods for the functionality they're currently used in.
@ -203,19 +216,24 @@ namespace recomp {
bool get_export_function(const std::string& export_name, GenericFunction& out) const;
ModLoadError populate_events(size_t base_event_index, std::string& error_param);
bool get_global_event_index(const std::string& event_name, size_t& event_index_out) const;
ModLoadError load_native_library(const NativeLibraryManifest& lib_manifest, std::string& error_param);
private:
// Mapping of export name to function index.
std::unordered_map<std::string, size_t> exports_by_name;
// Mapping of export name to native library function pointer.
std::unordered_map<std::string, recomp_func_t*> native_library_exports;
// Mapping of event name to local index.
std::unordered_map<std::string, size_t> events_by_name;
// Loaded dynamic libraries.
std::vector<std::unique_ptr<DynamicLibrary>> native_libraries; // Vector of pointers so that implementation can be elsewhere.
};
class DynamicLibrary;
class NativeCodeHandle : public ModCodeHandle {
public:
NativeCodeHandle(const std::filesystem::path& dll_path, const N64Recomp::Context& context);
~NativeCodeHandle() = default;
bool good() final;
uint32_t get_api_version() final;
void set_imported_function(size_t import_index, GenericFunction func) final;
void set_reference_symbol_pointer(size_t symbol_index, recomp_func_t* ptr) final {
reference_symbol_funcs[symbol_index] = ptr;
@ -258,6 +276,7 @@ namespace recomp {
void setup_events(size_t num_events);
void register_event_callback(size_t event_index, GenericFunction callback);
void reset_events();
ModLoadError validate_api_version(uint32_t api_version, std::string& error_param);
}
};

View file

@ -138,7 +138,8 @@ enum class ManifestField {
BinaryPath,
BinarySymsPath,
RomPatchPath,
RomPatchSymsPath
RomPatchSymsPath,
NativeLibraryPaths,
};
const std::string mod_id_key = "id";
@ -149,16 +150,18 @@ const std::string binary_path_key = "binary";
const std::string binary_syms_path_key = "binary_syms";
const std::string rom_patch_path_key = "rom_patch";
const std::string rom_patch_syms_path_key = "rom_patch_syms";
const std::string native_library_paths_key = "native_libraries";
std::unordered_map<std::string, ManifestField> field_map {
{ mod_id_key, ManifestField::Id },
{ major_version_key, ManifestField::MajorVersion },
{ minor_version_key, ManifestField::MinorVersion },
{ patch_version_key, ManifestField::PatchVersion },
{ binary_path_key, ManifestField::BinaryPath },
{ binary_syms_path_key, ManifestField::BinarySymsPath },
{ rom_patch_path_key, ManifestField::RomPatchPath },
{ rom_patch_syms_path_key, ManifestField::RomPatchSymsPath },
{ mod_id_key, ManifestField::Id },
{ major_version_key, ManifestField::MajorVersion },
{ minor_version_key, ManifestField::MinorVersion },
{ patch_version_key, ManifestField::PatchVersion },
{ binary_path_key, ManifestField::BinaryPath },
{ binary_syms_path_key, ManifestField::BinarySymsPath },
{ rom_patch_path_key, ManifestField::RomPatchPath },
{ rom_patch_syms_path_key, ManifestField::RomPatchSymsPath },
{ native_library_paths_key, ManifestField::NativeLibraryPaths },
};
template <typename T1, typename T2>
@ -172,6 +175,28 @@ bool get_to(const nlohmann::json& val, T2& out) {
return true;
}
template <typename T1, typename T2>
bool get_to_vec(const nlohmann::json& val, std::vector<T2>& out) {
const nlohmann::json::array_t* ptr = val.get_ptr<const nlohmann::json::array_t*>();
if (ptr == nullptr) {
return false;
}
out.clear();
for (const nlohmann::json& cur_val : *ptr) {
const T1* temp_ptr = cur_val.get_ptr<const T1*>();
if (temp_ptr == nullptr) {
out.clear();
return false;
}
out.emplace_back(*temp_ptr);
}
return true;
}
recomp::mods::ModOpenError parse_manifest(recomp::mods::ModManifest& ret, const std::vector<char>& manifest_data, std::string& error_param) {
using json = nlohmann::json;
json manifest_json = json::parse(manifest_data.begin(), manifest_data.end(), nullptr, false);
@ -242,6 +267,23 @@ recomp::mods::ModOpenError parse_manifest(recomp::mods::ModManifest& ret, const
return recomp::mods::ModOpenError::IncorrectManifestFieldType;
}
break;
case ManifestField::NativeLibraryPaths:
{
if (!val.is_object()) {
error_param = key;
return recomp::mods::ModOpenError::IncorrectManifestFieldType;
}
for (const auto& [lib_name, lib_exports] : val.items()) {
recomp::mods::NativeLibraryManifest& cur_lib = ret.native_libraries.emplace_back();
cur_lib.name = lib_name;
if (!get_to_vec<std::string>(lib_exports, cur_lib.exports)) {
error_param = key;
return recomp::mods::ModOpenError::IncorrectManifestFieldType;
}
}
}
break;
}
}
@ -413,7 +455,7 @@ std::string recomp::mods::error_to_string(ModOpenError error) {
case ModOpenError::DuplicateMod:
return "Duplicate mod found";
}
return "Unknown error " + std::to_string((int)error);
return "Unknown mod opening error: " + std::to_string((int)error);
}
std::string recomp::mods::error_to_string(ModLoadError error) {
@ -425,7 +467,11 @@ std::string recomp::mods::error_to_string(ModLoadError error) {
case ModLoadError::FailedToLoadBinary:
return "Failed to load mod binary file";
case ModLoadError::FailedToLoadNativeCode:
return "Failed to load mod DLL";
return "Failed to load mod code DLL";
case ModLoadError::FailedToLoadNativeLibrary:
return "Failed to load mod library DLL";
case ModLoadError::FailedToFindNativeExport:
return "Failed to find native export";
case ModLoadError::InvalidReferenceSymbol:
return "Reference symbol does not exist";
case ModLoadError::InvalidImport:
@ -444,6 +490,12 @@ std::string recomp::mods::error_to_string(ModLoadError error) {
return "Wrong dependency version";
case ModLoadError::ModConflict:
return "Conflicts with other mod";
case ModLoadError::DuplicateExport:
return "Duplicate exports in mod";
case ModLoadError::NoSpecifiedApiVersion:
return "Mod DLL does not specify an API version";
case ModLoadError::UnsupportedApiVersion:
return "Mod DLL has an unsupported API version";
}
return "Unknown error " + std::to_string((int)error);
return "Unknown mod loading error " + std::to_string((int)error);
}

View file

@ -14,6 +14,102 @@
#define PATHFMT "%s"
#endif
template<class... Ts>
struct overloaded : Ts... { using Ts::operator()...; };
template<class... Ts>
overloaded(Ts...) -> overloaded<Ts...>;
#if defined(_WIN32)
# define WIN32_LEAN_AND_MEAN
# include "Windows.h"
class recomp::mods::DynamicLibrary {
public:
static constexpr std::string_view PlatformExtension = ".dll";
DynamicLibrary() = default;
DynamicLibrary(const std::filesystem::path& path) {
native_handle = LoadLibraryW(path.c_str());
if (good()) {
uint32_t* recomp_api_version;
if (get_dll_symbol(recomp_api_version, "recomp_api_version")) {
api_version = *recomp_api_version;
}
else {
api_version = (uint32_t)-1;
}
}
}
~DynamicLibrary() {
unload();
}
DynamicLibrary(const DynamicLibrary&) = delete;
DynamicLibrary& operator=(const DynamicLibrary&) = delete;
DynamicLibrary(DynamicLibrary&&) = delete;
DynamicLibrary& operator=(DynamicLibrary&&) = delete;
void unload() {
if (native_handle != nullptr) {
FreeLibrary(native_handle);
}
native_handle = nullptr;
}
bool good() const {
return native_handle != nullptr;
}
template <typename T>
bool get_dll_symbol(T& out, const char* name) const {
out = (T)GetProcAddress(native_handle, name);
if (out == nullptr) {
return false;
}
return true;
};
uint32_t get_api_version() {
return api_version;
}
private:
HMODULE native_handle;
uint32_t api_version;
};
void unprotect(void* target_func, uint64_t* old_flags) {
DWORD old_flags_dword;
BOOL result = VirtualProtect(target_func,
16,
PAGE_READWRITE,
&old_flags_dword);
*old_flags = old_flags_dword;
(void)result;
}
void protect(void* target_func, uint64_t old_flags) {
DWORD dummy_old_flags;
BOOL result = VirtualProtect(target_func,
16,
static_cast<DWORD>(old_flags),
&dummy_old_flags);
(void)result;
}
#else
# error "Mods not implemented yet on this platform"
#endif
recomp::mods::ModLoadError recomp::mods::validate_api_version(uint32_t api_version, std::string& error_param) {
switch (api_version) {
case 1:
return ModLoadError::Good;
case (size_t)-1:
return ModLoadError::NoSpecifiedApiVersion;
default:
error_param = std::to_string(api_version);
return ModLoadError::UnsupportedApiVersion;
}
}
recomp::mods::ModHandle::ModHandle(ModManifest&& manifest) :
manifest(std::move(manifest)),
code_handle(),
@ -34,7 +130,6 @@ size_t recomp::mods::ModHandle::num_events() const {
return recompiler_context->event_symbols.size();
}
recomp::mods::ModLoadError recomp::mods::ModHandle::populate_exports(std::string& error_param) {
for (size_t func_index : recompiler_context->exported_funcs) {
const auto& func_handle = recompiler_context->functions[func_index];
@ -44,14 +139,64 @@ recomp::mods::ModLoadError recomp::mods::ModHandle::populate_exports(std::string
return ModLoadError::Good;
}
bool recomp::mods::ModHandle::get_export_function(const std::string& export_name, GenericFunction& out) const {
auto find_it = exports_by_name.find(export_name);
if (find_it == exports_by_name.end()) {
return false;
recomp::mods::ModLoadError recomp::mods::ModHandle::load_native_library(const recomp::mods::NativeLibraryManifest& lib_manifest, std::string& error_param) {
std::string lib_filename = lib_manifest.name + std::string{DynamicLibrary::PlatformExtension};
std::filesystem::path lib_path = manifest.mod_root_path.parent_path() / lib_filename;
std::unique_ptr<DynamicLibrary>& lib = native_libraries.emplace_back(std::make_unique<DynamicLibrary>(lib_path));
if (!lib->good()) {
error_param = lib_filename;
return ModLoadError::FailedToLoadNativeLibrary;
}
std::string api_error_param;
ModLoadError api_error = validate_api_version(lib->get_api_version(), api_error_param);
if (api_error != ModLoadError::Good) {
if (api_error_param.empty()) {
error_param = lib_filename;
}
else {
error_param = lib_filename + ":" + api_error_param;
}
return api_error;
}
out = code_handle->get_function_handle(find_it->second);
return true;
for (const std::string& export_name : lib_manifest.exports) {
recomp_func_t* cur_func;
if (native_library_exports.contains(export_name)) {
error_param = export_name;
return ModLoadError::DuplicateExport;
}
if (!lib->get_dll_symbol(cur_func, export_name.c_str())) {
error_param = lib_manifest.name + ":" + export_name;
return ModLoadError::FailedToFindNativeExport;
}
native_library_exports.emplace(export_name, cur_func);
}
return ModLoadError::Good;
}
bool recomp::mods::ModHandle::get_export_function(const std::string& export_name, GenericFunction& out) const {
// First, check the code exports.
auto code_find_it = exports_by_name.find(export_name);
if (code_find_it != exports_by_name.end()) {
out = code_handle->get_function_handle(code_find_it->second);
return true;
}
// Next, check the native library exports.
auto native_find_it = native_library_exports.find(export_name);
if (native_find_it != native_library_exports.end()) {
out = native_find_it->second;
return true;
}
// Nothing found.
return false;
}
recomp::mods::ModLoadError recomp::mods::ModHandle::populate_events(size_t base_event_index, std::string& error_param) {
@ -74,74 +219,6 @@ bool recomp::mods::ModHandle::get_global_event_index(const std::string& event_na
return true;
}
template<class... Ts>
struct overloaded : Ts... { using Ts::operator()...; };
template<class... Ts>
overloaded(Ts...) -> overloaded<Ts...>;
#if defined(_WIN32)
# define WIN32_LEAN_AND_MEAN
# include "Windows.h"
class recomp::mods::DynamicLibrary {
public:
DynamicLibrary() = default;
DynamicLibrary(const std::filesystem::path& path) {
mod_dll = LoadLibraryW(path.c_str());
}
~DynamicLibrary() {
unload();
}
DynamicLibrary(const DynamicLibrary&) = delete;
DynamicLibrary& operator=(const DynamicLibrary&) = delete;
DynamicLibrary(DynamicLibrary&&) = delete;
DynamicLibrary& operator=(DynamicLibrary&&) = delete;
void unload() {
if (mod_dll != nullptr) {
FreeLibrary(mod_dll);
}
mod_dll = nullptr;
}
bool good() {
return mod_dll != nullptr;
}
template <typename T>
bool get_dll_symbol(T& out, const char* name) const {
out = (T)GetProcAddress(mod_dll, name);
if (out == nullptr) {
return false;
}
return true;
};
private:
HMODULE mod_dll;
};
void unprotect(void* target_func, uint64_t* old_flags) {
DWORD old_flags_dword;
BOOL result = VirtualProtect(target_func,
16,
PAGE_READWRITE,
&old_flags_dword);
*old_flags = old_flags_dword;
(void)result;
}
void protect(void* target_func, uint64_t old_flags) {
DWORD dummy_old_flags;
BOOL result = VirtualProtect(target_func,
16,
static_cast<DWORD>(old_flags),
&dummy_old_flags);
(void)result;
}
#else
# error "Mods not implemented yet on this platform"
#endif
recomp::mods::NativeCodeHandle::NativeCodeHandle(const std::filesystem::path& dll_path, const N64Recomp::Context& context) {
// Load the DLL.
dynamic_lib = std::make_unique<DynamicLibrary>(dll_path);
@ -180,6 +257,10 @@ bool recomp::mods::NativeCodeHandle::good() {
return dynamic_lib->good() && is_good;
}
uint32_t recomp::mods::NativeCodeHandle::get_api_version() {
return dynamic_lib->get_api_version();
}
void recomp::mods::NativeCodeHandle::set_bad() {
dynamic_lib.reset();
is_good = false;
@ -469,7 +550,7 @@ void recomp::mods::ModContext::check_dependencies(recomp::mods::ModHandle& mod,
recomp::mods::ModLoadError recomp::mods::ModContext::load_mod_code(recomp::mods::ModHandle& mod, std::string& error_param) {
// TODO implement LuaJIT recompilation and allow it instead of native code loading via a mod manifest flag.
std::filesystem::path dll_path = mod.manifest.mod_root_path;
dll_path.replace_extension(".dll");
dll_path.replace_extension(DynamicLibrary::PlatformExtension);
mod.code_handle = std::make_unique<NativeCodeHandle>(dll_path, *mod.recompiler_context);
if (!mod.code_handle->good()) {
mod.code_handle.reset();
@ -477,15 +558,37 @@ recomp::mods::ModLoadError recomp::mods::ModContext::load_mod_code(recomp::mods:
return ModLoadError::FailedToLoadNativeCode;
}
// Populate the mod's export map.
std::string cur_error_param;
ModLoadError cur_error = mod.populate_exports(cur_error_param);
ModLoadError cur_error = validate_api_version(mod.code_handle->get_api_version(), cur_error_param);
if (cur_error != ModLoadError::Good) {
if (cur_error_param.empty()) {
error_param = dll_path.filename().string();
}
else {
error_param = dll_path.filename().string() + ":" + std::move(cur_error_param);
}
return cur_error;
}
// Populate the mod's export map.
cur_error = mod.populate_exports(cur_error_param);
if (cur_error != ModLoadError::Good) {
error_param = std::move(cur_error_param);
return cur_error;
}
// Load any native libraries specified by the mod and validate/register the expors.
std::filesystem::path parent_path = mod.manifest.mod_root_path.parent_path();
for (const recomp::mods::NativeLibraryManifest& cur_lib_manifest: mod.manifest.native_libraries) {
cur_error = mod.load_native_library(cur_lib_manifest, cur_error_param);
if (cur_error != ModLoadError::Good) {
error_param = std::move(cur_error_param);
return cur_error;
}
}
// Populate the mod's event map and set its base event index.
cur_error = mod.populate_events(num_events, cur_error_param);