diff --git a/src/pc/lua/smlua.c b/src/pc/lua/smlua.c index f93e4d22c..f4e17d372 100644 --- a/src/pc/lua/smlua.c +++ b/src/pc/lua/smlua.c @@ -7,7 +7,7 @@ static void smlua_exec_file(char* path) { lua_State* L = gLuaState; if (luaL_dofile(L, path) != LUA_OK) { LOG_LUA("LUA: Failed to load lua file '%s'.", path); - puts(lua_tostring(L, lua_gettop(L))); + puts(smlua_to_string(L, lua_gettop(L))); } lua_pop(L, lua_gettop(L)); } @@ -16,7 +16,7 @@ static void smlua_exec_str(char* str) { lua_State* L = gLuaState; if (luaL_dostring(L, str) != LUA_OK) { LOG_LUA("LUA: Failed to load lua string."); - puts(lua_tostring(L, lua_gettop(L))); + puts(smlua_to_string(L, lua_gettop(L))); } lua_pop(L, lua_gettop(L)); } @@ -25,7 +25,7 @@ static void smlua_load_script(char* path) { lua_State* L = gLuaState; if (luaL_loadfile(L, path) != LUA_OK) { LOG_LUA("LUA: Failed to load lua script '%s'.", path); - puts(lua_tostring(L, lua_gettop(L))); + puts(smlua_to_string(L, lua_gettop(L))); return; } @@ -45,7 +45,7 @@ static void smlua_load_script(char* path) { // run chunks if (lua_pcall(L, 0, LUA_MULTRET, 0) != LUA_OK) { LOG_LUA("LUA: Failed to execute lua script '%s'.", path); - puts(lua_tostring(L, lua_gettop(L))); + puts(smlua_to_string(L, lua_gettop(L))); smlua_dump_stack(); return; } diff --git a/src/pc/lua/smlua_cobject.c b/src/pc/lua/smlua_cobject.c index 2e32ab4af..0be73cdad 100644 --- a/src/pc/lua/smlua_cobject.c +++ b/src/pc/lua/smlua_cobject.c @@ -299,9 +299,14 @@ static struct LuaObjectField* smlua_get_object_field(struct LuaObjectTable* ot, } static int smlua__get_field(lua_State* L) { - enum LuaObjectType lot = lua_tointeger(L, -3); - u64 pointer = lua_tointeger(L, -2); - const char* key = lua_tostring(L, -1); + enum LuaObjectType lot = smlua_to_integer(L, 1); + if (!gSmLuaConvertSuccess) { return 0; } + + u64 pointer = smlua_to_integer(L, 2); + if (!gSmLuaConvertSuccess) { return 0; } + + const char* key = smlua_to_string(L, 3); + if (!gSmLuaConvertSuccess) { return 0; } if (pointer == 0) { LOG_LUA("_get_field on null pointer"); @@ -344,9 +349,14 @@ static int smlua__get_field(lua_State* L) { } static int smlua__set_field(lua_State* L) { - enum LuaObjectType lot = lua_tointeger(L, -4); - u64 pointer = lua_tointeger(L, -3); - const char* key = lua_tostring(L, -2); + enum LuaObjectType lot = smlua_to_integer(L, 1); + if (!gSmLuaConvertSuccess) { return 0; } + + u64 pointer = smlua_to_integer(L, 2); + if (!gSmLuaConvertSuccess) { return 0; } + + const char* key = smlua_to_string(L, 3); + if (!gSmLuaConvertSuccess) { return 0; } if (pointer == 0) { LOG_LUA("_set_field on null pointer"); diff --git a/src/pc/lua/smlua_hooks.c b/src/pc/lua/smlua_hooks.c index ba7a54551..4c890600e 100644 --- a/src/pc/lua/smlua_hooks.c +++ b/src/pc/lua/smlua_hooks.c @@ -11,7 +11,9 @@ static struct LuaHookedEvent sHookedEvents[HOOK_MAX] = { 0 }; int smlua_hook_event(lua_State* L) { if (L == NULL) { return 0; } - u16 hookType = lua_tointeger(L, -2); + u16 hookType = smlua_to_integer(L, -2); + if (!gSmLuaConvertSuccess) { return 0; } + if (hookType >= HOOK_MAX) { LOG_LUA("LUA: Hook Type: %d exceeds max!", hookType); return 0; @@ -88,8 +90,9 @@ int smlua_hook_mario_action(lua_State* L) { } struct LuaHookedMarioAction* hooked = &sHookedMarioActions[sHookedMarioActionsCount]; - hooked->action = lua_tointeger(L, -2); + hooked->action = smlua_to_integer(L, -2); hooked->reference = luaL_ref(L, LUA_REGISTRYINDEX); + if (!gSmLuaConvertSuccess) { return 0; } sHookedMarioActionsCount++; return 1; @@ -116,9 +119,11 @@ bool smlua_call_action_hook(struct MarioState* m, s32* returnValue) { } // output the return value - *returnValue = lua_tointeger(L, -1); + *returnValue = smlua_to_integer(L, -1); lua_pop(L, 1); + if (!gSmLuaConvertSuccess) { return false; } + return true; } } diff --git a/src/pc/lua/smlua_utils.c b/src/pc/lua/smlua_utils.c index 72a61a7b6..d54a08e4d 100644 --- a/src/pc/lua/smlua_utils.c +++ b/src/pc/lua/smlua_utils.c @@ -41,7 +41,9 @@ void smlua_logline(void) { ////////////////////////////////////////////// lua_Integer smlua_to_integer(lua_State* L, int index) { - if (lua_type(L, index) != LUA_TNUMBER) { + if (lua_type(L, index) == LUA_TBOOLEAN) { + return lua_toboolean(L, index) ? 1 : 0; + } else if (lua_type(L, index) != LUA_TNUMBER) { LOG_LUA("LUA: smlua_to_integer received improper type '%d'", lua_type(L, index)); smlua_logline(); gSmLuaConvertSuccess = false; @@ -63,6 +65,17 @@ lua_Number smlua_to_number(lua_State* L, int index) { return lua_tonumber(L, index); } +const char* smlua_to_string(lua_State* L, int index) { + if (lua_type(L, index) != LUA_TSTRING) { + LOG_LUA("LUA: smlua_to_string received improper type '%d'", lua_type(L, index)); + smlua_logline(); + gSmLuaConvertSuccess = false; + return 0; + } + gSmLuaConvertSuccess = true; + return lua_tostring(L, index); +} + void* smlua_to_cobject(lua_State* L, int index, enum LuaObjectType lot) { if (lua_type(L, index) != LUA_TTABLE) { LOG_LUA("LUA: smlua_to_cobject received improper type '%d'", lua_type(L, index)); diff --git a/src/pc/lua/smlua_utils.h b/src/pc/lua/smlua_utils.h index 2a92ac1e5..404c7f3f0 100644 --- a/src/pc/lua/smlua_utils.h +++ b/src/pc/lua/smlua_utils.h @@ -11,6 +11,7 @@ void smlua_logline(void); lua_Integer smlua_to_integer(lua_State* L, int index); lua_Number smlua_to_number(lua_State* L, int index); +const char* smlua_to_string(lua_State* L, int index); void* smlua_to_cobject(lua_State* L, int index, enum LuaObjectType lot); void smlua_push_object(lua_State* L, enum LuaObjectType lot, void* p);