From 6c8050a564ca447f0b7cb577efa4694059128681 Mon Sep 17 00:00:00 2001 From: MysterD Date: Sat, 12 Sep 2020 17:56:42 -0700 Subject: [PATCH] Abstracted all socket code behind a NetworkSystem In preparation for other forms of communication, I have abstracted all of the socket code (which needs direct connections) behind a struct whose calls can be swapped out for other systems if desired. --- build-windows-visual-studio/sm64ex.vcxproj | 2 +- src/menu/file_select.c | 4 +- src/pc/debuglog.h | 11 +- src/pc/network/network.c | 136 +++++++++--------- src/pc/network/network.h | 10 +- .../network/packets/packet_inside_painting.c | 2 +- src/pc/network/socket/socket.c | 83 +++++++++-- src/pc/network/socket/socket.h | 9 +- src/pc/network/socket/socket_linux.c | 12 +- src/pc/network/socket/socket_linux.h | 3 + src/pc/network/socket/socket_windows.c | 15 +- src/pc/network/socket/socket_windows.h | 1 + src/pc/pc_main.c | 10 +- 13 files changed, 198 insertions(+), 100 deletions(-) diff --git a/build-windows-visual-studio/sm64ex.vcxproj b/build-windows-visual-studio/sm64ex.vcxproj index 419c30e53..560242945 100644 --- a/build-windows-visual-studio/sm64ex.vcxproj +++ b/build-windows-visual-studio/sm64ex.vcxproj @@ -87,7 +87,7 @@ Level3 true - _DEBUG;_CONSOLE;WINSOCK;%(PreprocessorDefinitions) + _DEBUG;_CONSOLE;WINSOCK;DEBUG;%(PreprocessorDefinitions) true diff --git a/src/menu/file_select.c b/src/menu/file_select.c index 7560b2e9e..b3bddb904 100644 --- a/src/menu/file_select.c +++ b/src/menu/file_select.c @@ -440,7 +440,7 @@ void join_server_as_client(void) { keyboard_stop_text_input(); joinVersionMismatch = FALSE; - network_init(NT_CLIENT, configJoinIp, configJoinPort); + network_init(NT_CLIENT); } void joined_server_as_client(s16 fileIndex) { @@ -1341,7 +1341,7 @@ void load_main_menu_save_file(struct Object *fileButton, s32 fileNum) { if (fileButton->oMenuButtonState == MENU_BUTTON_STATE_FULLSCREEN) { sSelectedFileNum = fileNum; configHostSaveSlot = fileNum; - network_init(NT_SERVER, "", configHostPort); + network_init(NT_SERVER); } } diff --git a/src/pc/debuglog.h b/src/pc/debuglog.h index b42361816..1801c6faa 100644 --- a/src/pc/debuglog.h +++ b/src/pc/debuglog.h @@ -13,6 +13,10 @@ printf(" [%s] ", NETWORKTYPESTR); } + static void debuglog_print_log_type(char* logType) { + printf("[%s] ", logType); + } + static void debuglog_print_short_filename(char* filename) { char* last = strrchr(filename, '/'); if (last != NULL) { @@ -22,13 +26,16 @@ } } - static void debuglog_print_log(char* filename) { + static void debuglog_print_log(char* logType, char* filename) { debuglog_print_timestamp(); debuglog_print_network_type(); + debuglog_print_log_type(logType); debuglog_print_short_filename(filename); } - #define LOG_INFO(...) ( debuglog_print_log(__FILE__), printf(__VA_ARGS__), printf("\n") ) + #define LOG_INFO(...) ( debuglog_print_log("INFO ", __FILE__), printf(__VA_ARGS__), printf("\n") ) + #define LOG_ERROR(...) ( debuglog_print_log("ERROR", __FILE__), printf(__VA_ARGS__), printf("\n") ) #else #define LOG_INFO(...) + #define LOG_ERROR(...) #endif diff --git a/src/pc/network/network.c b/src/pc/network/network.c index 6aec4456b..5a0da0d0a 100644 --- a/src/pc/network/network.c +++ b/src/pc/network/network.c @@ -4,13 +4,13 @@ #include "object_constants.h" #include "socket/socket.h" #include "pc/configfile.h" +#include "pc/debuglog.h" // Mario 64 specific externs extern s16 sCurrPlayMode; enum NetworkType gNetworkType = NT_NONE; -static SOCKET gSocket = 0; -struct sockaddr_in txAddr = { 0 }; +struct NetworkSystem* gNetworkSystem = &gNetworkSystemSocket; #define LOADING_LEVEL_THRESHOLD 10 u8 networkLoadingLevel = 0; @@ -21,16 +21,22 @@ struct ServerSettings gServerSettings = { .playerKnockbackStrength = 25, }; -void network_init(enum NetworkType inNetworkType, char* ip, unsigned int port) { +bool network_init(enum NetworkType inNetworkType) { + // sanity check network system + if (gNetworkSystem == NULL) { + LOG_ERROR("no network system attached"); + return false; + } + + // initialize the network system + int rc = gNetworkSystem->initialize(inNetworkType); + if (!rc) { + LOG_ERROR("failed to initialize network system"); + return false; + } + // set network type gNetworkType = inNetworkType; - if (gNetworkType == NT_NONE) { return; } - - // sanity check port - if (port == 0) { - port = (gNetworkType == NT_CLIENT) ? configJoinPort : configHostPort; - if (port == 0) { port = DEFAULT_PORT; } - } // set server settings if (gNetworkType == NT_SERVER) { @@ -39,26 +45,17 @@ void network_init(enum NetworkType inNetworkType, char* ip, unsigned int port) { gServerSettings.stayInLevelAfterStar = configStayInLevelAfterStar; } - // create a receiver socket to receive datagrams - gSocket = socket_initialize(); - if (gSocket == INVALID_SOCKET) { return; } - - // connect - if (gNetworkType == NT_SERVER) { - // bind the socket to any address and the specified port. - int rc = socket_bind(gSocket, port); - if (rc != NO_ERROR) { return; } - } else { - // save the port to send to - txAddr.sin_family = AF_INET; - txAddr.sin_port = htons(port); - txAddr.sin_addr.s_addr = inet_addr(ip); + // exit early if we're not really initializing the network + if (gNetworkType == NT_NONE) { + return true; } // send connection request if (gNetworkType == NT_CLIENT) { network_send_save_file_request(); } + + return true; } void network_on_init_level(void) { @@ -77,7 +74,8 @@ void network_on_loaded_level(void) { void network_send(struct Packet* p) { // sanity checks if (gNetworkType == NT_NONE) { return; } - if (p->error) { printf("%s packet error!\n", NETWORKTYPESTR); return; } + if (p->error) { LOG_ERROR("packet error!"); return; } + if (gNetworkSystem == NULL) { LOG_ERROR("no network system attached"); return; } // remember reliable packets network_remember_reliable(p); @@ -87,11 +85,51 @@ void network_send(struct Packet* p) { memcpy(&p->buffer[p->dataLength], &hash, sizeof(u32)); // send - int rc = socket_send(gSocket, &txAddr, p->buffer, p->cursor + sizeof(u32)); + int rc = gNetworkSystem->send(p->buffer, p->cursor + sizeof(u32)); if (rc != NO_ERROR) { return; } p->sent = true; } +void network_receive(u8* data, u16 dataLength) { + // receive packet + struct Packet p = { + .cursor = 3, + .buffer = { 0 }, + .dataLength = dataLength, + }; + memcpy(p.buffer, data, dataLength); + + // subtract and check hash + p.dataLength -= sizeof(u32); + if (!packet_check_hash(&p)) { + LOG_ERROR("invalid packet hash!"); + return; + } + + // execute packet + switch ((u8)p.buffer[0]) { + case PACKET_ACK: network_receive_ack(&p); break; + case PACKET_PLAYER: network_receive_player(&p); break; + case PACKET_OBJECT: network_receive_object(&p); break; + case PACKET_SPAWN_OBJECTS: network_receive_spawn_objects(&p); break; + case PACKET_SPAWN_STAR: network_receive_spawn_star(&p); break; + case PACKET_LEVEL_WARP: network_receive_level_warp(&p); break; + case PACKET_INSIDE_PAINTING: network_receive_inside_painting(&p); break; + case PACKET_COLLECT_STAR: network_receive_collect_star(&p); break; + case PACKET_COLLECT_COIN: network_receive_collect_coin(&p); break; + case PACKET_COLLECT_ITEM: network_receive_collect_item(&p); break; + case PACKET_RESERVATION_REQUEST: network_receive_reservation_request(&p); break; + case PACKET_RESERVATION: network_receive_reservation(&p); break; + case PACKET_SAVE_FILE_REQUEST: network_receive_save_file_request(&p); break; + case PACKET_SAVE_FILE: network_receive_save_file(&p); break; + case PACKET_CUSTOM: network_receive_custom(&p); break; + default: LOG_ERROR("received unknown packet: %d", p.buffer[0]); + } + + // send an ACK if requested + network_send_ack(&p); +} + void network_update(void) { if (gNetworkType == NT_NONE) { return; } @@ -110,50 +148,18 @@ void network_update(void) { } // receive packets - do { - // receive packet - struct Packet p = { .cursor = 3 }; - int rc = socket_receive(gSocket, &txAddr, p.buffer, PACKET_LENGTH, &p.dataLength); - if (rc != NO_ERROR) { break; } - - // subtract and check hash - p.dataLength -= sizeof(u32); - if (!packet_check_hash(&p)) { - printf("Invalid packet!\n"); - continue; - } - - // execute packet - switch ((u8)p.buffer[0]) { - case PACKET_ACK: network_receive_ack(&p); break; - case PACKET_PLAYER: network_receive_player(&p); break; - case PACKET_OBJECT: network_receive_object(&p); break; - case PACKET_SPAWN_OBJECTS: network_receive_spawn_objects(&p); break; - case PACKET_SPAWN_STAR: network_receive_spawn_star(&p); break; - case PACKET_LEVEL_WARP: network_receive_level_warp(&p); break; - case PACKET_INSIDE_PAINTING: network_receive_inside_painting(&p); break; - case PACKET_COLLECT_STAR: network_receive_collect_star(&p); break; - case PACKET_COLLECT_COIN: network_receive_collect_coin(&p); break; - case PACKET_COLLECT_ITEM: network_receive_collect_item(&p); break; - case PACKET_RESERVATION_REQUEST: network_receive_reservation_request(&p); break; - case PACKET_RESERVATION: network_receive_reservation(&p); break; - case PACKET_SAVE_FILE_REQUEST: network_receive_save_file_request(&p); break; - case PACKET_SAVE_FILE: network_receive_save_file(&p); break; - case PACKET_CUSTOM: network_receive_custom(&p); break; - default: printf("%s received unknown packet: %d\n", NETWORKTYPESTR, p.buffer[0]); - } - - // send an ACK if requested - network_send_ack(&p); - - } while (1); + if (gNetworkSystem != NULL) { + gNetworkSystem->update(); + } + // update reliable packets network_update_reliable(); } void network_shutdown(void) { if (gNetworkType == NT_NONE) { return; } - // close down socket - socket_close(gSocket); gNetworkType = NT_NONE; + + if (gNetworkSystem == NULL) { LOG_ERROR("no network system attached"); return; } + gNetworkSystem->shutdown(); } diff --git a/src/pc/network/network.h b/src/pc/network/network.h index 634c4dc29..16972fee1 100644 --- a/src/pc/network/network.h +++ b/src/pc/network/network.h @@ -18,6 +18,13 @@ extern struct MarioState gMarioStates[]; #define PACKET_LENGTH 1024 #define NETWORKTYPESTR (gNetworkType == NT_CLIENT ? "Client" : "Server") +struct NetworkSystem { + bool (*initialize)(enum NetworkType); + void (*update)(void); + int (*send)(u8* data, u16 dataLength); + void (*shutdown)(void); +}; + enum PacketType { PACKET_ACK, PACKET_PLAYER, @@ -85,10 +92,11 @@ extern struct SyncObject gSyncObjects[]; extern struct ServerSettings gServerSettings; // network.c -void network_init(enum NetworkType inNetworkType, char* ip, unsigned int port); +bool network_init(enum NetworkType inNetworkType); void network_on_init_level(void); void network_on_loaded_level(void); void network_send(struct Packet* p); +void network_receive(u8* data, u16 dataLength); void network_update(void); void network_shutdown(void); diff --git a/src/pc/network/packets/packet_inside_painting.c b/src/pc/network/packets/packet_inside_painting.c index 2e44b3e6e..d032ee062 100644 --- a/src/pc/network/packets/packet_inside_painting.c +++ b/src/pc/network/packets/packet_inside_painting.c @@ -57,7 +57,7 @@ void network_receive_inside_painting(struct Packet* p) { // two-player hack: gControlledWarp is a bool instead of an index if (gControlledWarp) { - LOG_INFO("this should never happen, received inside_painting when gControlledWarp"); + LOG_ERROR("this should never happen, received inside_painting when gControlledWarp"); return; } diff --git a/src/pc/network/socket/socket.c b/src/pc/network/socket/socket.c index 2e9953199..d33827724 100644 --- a/src/pc/network/socket/socket.c +++ b/src/pc/network/socket/socket.c @@ -1,40 +1,44 @@ #include -#include "../network.h" #include "socket.h" +#include "pc/configfile.h" +#include "pc/debuglog.h" -int socket_bind(SOCKET sock, unsigned int port) { +static SOCKET curSocket = INVALID_SOCKET; +struct sockaddr_in txAddr = { 0 }; + +static int socket_bind(SOCKET socket, unsigned int port) { struct sockaddr_in rxAddr; rxAddr.sin_family = AF_INET; rxAddr.sin_port = htons(port); rxAddr.sin_addr.s_addr = htonl(INADDR_ANY); - int rc = bind(sock, (SOCKADDR*)&rxAddr, sizeof(rxAddr)); + int rc = bind(socket, (SOCKADDR*)&rxAddr, sizeof(rxAddr)); if (rc != 0) { - printf("%s bind failed with error %d\n", NETWORKTYPESTR, SOCKET_LAST_ERROR); + LOG_ERROR("bind failed with error %d", SOCKET_LAST_ERROR); } return rc; } -int socket_send(SOCKET sock, struct sockaddr_in* txAddr, u8* buffer, u16 bufferLength) { - int txAddrSize = sizeof(struct sockaddr_in); - int rc = sendto(sock, (char*)buffer, bufferLength, 0, (struct sockaddr*)txAddr, txAddrSize); +static int socket_send(SOCKET socket, struct sockaddr_in* addr, u8* buffer, u16 bufferLength) { + int addrSize = sizeof(struct sockaddr_in); + int rc = sendto(socket, (char*)buffer, bufferLength, 0, (struct sockaddr*)addr, addrSize); if (rc == SOCKET_ERROR) { - printf("%s sendto failed with error: %d\n", NETWORKTYPESTR, SOCKET_LAST_ERROR); + LOG_ERROR("sendto failed with error: %d", SOCKET_LAST_ERROR); } return rc; } -int socket_receive(SOCKET sock, struct sockaddr_in* rxAddr, u8* buffer, u16 bufferLength, u16* receiveLength) { +static int socket_receive(SOCKET socket, struct sockaddr_in* rxAddr, u8* buffer, u16 bufferLength, u16* receiveLength) { *receiveLength = 0; int rxAddrSize = sizeof(struct sockaddr_in); - int rc = recvfrom(sock, (char*)buffer, bufferLength, 0, (struct sockaddr*)rxAddr, &rxAddrSize); + int rc = recvfrom(socket, (char*)buffer, bufferLength, 0, (struct sockaddr*)rxAddr, &rxAddrSize); if (rc == SOCKET_ERROR) { int error = SOCKET_LAST_ERROR; if (error != SOCKET_EWOULDBLOCK && error != SOCKET_ECONNRESET) { - printf("%s recvfrom failed with error %d\n", NETWORKTYPESTR, SOCKET_LAST_ERROR); + LOG_ERROR("recvfrom failed with error %d", SOCKET_LAST_ERROR); } return rc; } @@ -42,3 +46,60 @@ int socket_receive(SOCKET sock, struct sockaddr_in* rxAddr, u8* buffer, u16 buff *receiveLength = rc; return NO_ERROR; } + +static bool ns_socket_initialize(enum NetworkType networkType) { + // sanity check port + unsigned int port = (networkType == NT_CLIENT) ? configJoinPort : configHostPort; + if (port == 0) { port = DEFAULT_PORT; } + + // create a receiver socket to receive datagrams + curSocket = socket_initialize(); + if (curSocket == INVALID_SOCKET) { return false; } + + // connect + if (networkType == NT_SERVER) { + // bind the socket to any address and the specified port. + int rc = socket_bind(curSocket, port); + if (rc != NO_ERROR) { return false; } + LOG_INFO("bound to port %u", port); + } else { + // save the port to send to + txAddr.sin_family = AF_INET; + txAddr.sin_port = htons(port); + txAddr.sin_addr.s_addr = inet_addr(configJoinIp); + LOG_INFO("connecting to %s %u", configJoinIp, port); + } + + LOG_INFO("initialized"); + + // success + return true; +} + +static void ns_socket_update(void) { + do { + // receive packet + u8 data[PACKET_LENGTH]; + u16 dataLength = 0; + int rc = socket_receive(curSocket, &txAddr, data, PACKET_LENGTH, &dataLength); + if (rc != NO_ERROR) { break; } + network_receive(data, dataLength); + } while (true); +} + +static int ns_socket_send(u8* data, u16 dataLength) { + return socket_send(curSocket, &txAddr, data, dataLength); +} + +static void ns_socket_shutdown(void) { + socket_shutdown(curSocket); + curSocket = INVALID_SOCKET; + LOG_INFO("shutdown"); +} + +struct NetworkSystem gNetworkSystemSocket = { + .initialize = ns_socket_initialize, + .update = ns_socket_update, + .send = ns_socket_send, + .shutdown = ns_socket_shutdown, +}; diff --git a/src/pc/network/socket/socket.h b/src/pc/network/socket/socket.h index de295a71c..9729e885f 100644 --- a/src/pc/network/socket/socket.h +++ b/src/pc/network/socket/socket.h @@ -1,16 +1,17 @@ #ifndef SOCKET_H #define SOCKET_H +#include "../network.h" + #ifdef WINSOCK #include "socket_windows.h" #else #include "socket_linux.h" #endif +extern struct NetworkSystem gNetworkSystemSocket; + SOCKET socket_initialize(void); -int socket_bind(SOCKET sock, unsigned int port); -int socket_send(SOCKET sock, struct sockaddr_in* txAddr, u8* buffer, u16 bufferLength); -int socket_receive(SOCKET sock, struct sockaddr_in* rxAddr, u8* buffer, u16 bufferLength, u16* receiveLength); -void socket_close(SOCKET sock); +void socket_shutdown(SOCKET socket); #endif diff --git a/src/pc/network/socket/socket_linux.c b/src/pc/network/socket/socket_linux.c index b66706e97..f608e5c9c 100644 --- a/src/pc/network/socket/socket_linux.c +++ b/src/pc/network/socket/socket_linux.c @@ -1,29 +1,31 @@ #ifndef WINSOCK #include "socket_linux.h" #include "../network.h" +#include "pc/debuglog.h" SOCKET socket_initialize(void) { // initialize socket SOCKET sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); if (sock == INVALID_SOCKET) { - printf("%s socket failed with error %d\n", NETWORKTYPESTR, SOCKET_LAST_ERROR); + LOG_ERROR("socket failed with error %d", SOCKET_LAST_ERROR); return INVALID_SOCKET; } // set non-blocking mode int rc = fcntl(sock, F_SETFL, fcntl(sock, F_GETFL, 0) | O_NONBLOCK); if (rc == INVALID_SOCKET) { - printf("%s fcntl failed with error: %d\n", NETWORKTYPESTR, rc); + LOG_ERROR("fcntl failed with error: %d", rc); return INVALID_SOCKET; } return sock; } -void socket_close(SOCKET sock) { - int rc = closesocket(sock); +void socket_shutdown(SOCKET socket) { + if (socket == INVALID_SOCKET) { return; } + int rc = closesocket(socket); if (rc == SOCKET_ERROR) { - printf("%s closesocket failed with error %d\n", NETWORKTYPESTR, SOCKET_LAST_ERROR); + LOG_ERROR("closesocket failed with error %d\n", SOCKET_LAST_ERROR); } } diff --git a/src/pc/network/socket/socket_linux.h b/src/pc/network/socket/socket_linux.h index 5bb91da8c..68b870453 100644 --- a/src/pc/network/socket/socket_linux.h +++ b/src/pc/network/socket/socket_linux.h @@ -1,11 +1,13 @@ #ifndef SOCKET_LINUX_H #define SOCKET_LINUX_H +#ifndef WINSOCK #include #include #include #include #include +#include "socket.h" #define SOCKET unsigned int #define INVALID_SOCKET (unsigned int)(-1) @@ -18,3 +20,4 @@ #define SOCKET_ECONNRESET ECONNRESET #endif +#endif diff --git a/src/pc/network/socket/socket_windows.c b/src/pc/network/socket/socket_windows.c index 0a2c2639b..a9a172209 100644 --- a/src/pc/network/socket/socket_windows.c +++ b/src/pc/network/socket/socket_windows.c @@ -1,21 +1,21 @@ #ifdef WINSOCK #include #include "socket_windows.h" -#include "../network.h" +#include "pc/debuglog.h" SOCKET socket_initialize(void) { // start up winsock WSADATA wsaData; int rc = WSAStartup(MAKEWORD(2, 2), &wsaData); if (rc != NO_ERROR) { - printf("%s WSAStartup failed with error %d\n", NETWORKTYPESTR, rc); + LOG_ERROR("WSAStartup failed with error %d", rc); return INVALID_SOCKET; } // initialize socket SOCKET sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); if (sock == INVALID_SOCKET) { - printf("%s socket failed with error %d\n", NETWORKTYPESTR, SOCKET_LAST_ERROR); + LOG_ERROR("socket failed with error %d", SOCKET_LAST_ERROR); return INVALID_SOCKET; } @@ -23,17 +23,18 @@ SOCKET socket_initialize(void) { u_long iMode = 1; rc = ioctlsocket(sock, FIONBIO, &iMode); if (rc != NO_ERROR) { - printf("%s ioctlsocket failed with error: %d\n", NETWORKTYPESTR, rc); + LOG_ERROR("ioctlsocket failed with error: %d", rc); return INVALID_SOCKET; } return sock; } -void socket_close(SOCKET sock) { - int rc = closesocket(sock); +void socket_shutdown(SOCKET socket) { + if (socket == INVALID_SOCKET) { return; } + int rc = closesocket(socket); if (rc == SOCKET_ERROR) { - printf("%s closesocket failed with error %d\n", NETWORKTYPESTR, SOCKET_LAST_ERROR); + LOG_ERROR("closesocket failed with error %d", SOCKET_LAST_ERROR); } WSACleanup(); } diff --git a/src/pc/network/socket/socket_windows.h b/src/pc/network/socket/socket_windows.h index 50dad0cfe..4fd8f4a15 100644 --- a/src/pc/network/socket/socket_windows.h +++ b/src/pc/network/socket/socket_windows.h @@ -3,6 +3,7 @@ #include #include +#include "socket.h" #define SOCKET_LAST_ERROR WSAGetLastError() #define SOCKET_EWOULDBLOCK WSAEWOULDBLOCK diff --git a/src/pc/pc_main.c b/src/pc/pc_main.c index d064887ed..fc99f4f53 100644 --- a/src/pc/pc_main.c +++ b/src/pc/pc_main.c @@ -261,7 +261,15 @@ void main_func(void) { audio_api = &audio_null; } - network_init(gCLIOpts.Network, gCLIOpts.JoinIp, gCLIOpts.NetworkPort); + if (gCLIOpts.Network == NT_CLIENT) { + strncpy(configJoinIp, gCLIOpts.JoinIp, IP_MAX_LEN); + configJoinPort = gCLIOpts.NetworkPort; + network_init(NT_CLIENT); + } else if (gCLIOpts.Network == NT_SERVER) { + configHostPort = gCLIOpts.NetworkPort; + network_init(NT_SERVER); + } + audio_init(); sound_init();