diff --git a/build-windows-visual-studio/sm64ex.vcxproj b/build-windows-visual-studio/sm64ex.vcxproj index adc32307d..cc8641b5b 100644 --- a/build-windows-visual-studio/sm64ex.vcxproj +++ b/build-windows-visual-studio/sm64ex.vcxproj @@ -71,6 +71,7 @@ true + ../;../include/;../src/;$(IncludePath) true @@ -3953,6 +3954,7 @@ + diff --git a/build-windows-visual-studio/sm64ex.vcxproj.filters b/build-windows-visual-studio/sm64ex.vcxproj.filters index 3392431b4..6c370cc37 100644 --- a/build-windows-visual-studio/sm64ex.vcxproj.filters +++ b/build-windows-visual-studio/sm64ex.vcxproj.filters @@ -14970,6 +14970,9 @@ Source Files\src\pc\network\packets + + Source Files\src\pc\network\packets + diff --git a/src/pc/network/network.c b/src/pc/network/network.c index b6abec54a..86584e580 100644 --- a/src/pc/network/network.c +++ b/src/pc/network/network.c @@ -68,7 +68,9 @@ void network_send(struct Packet* p) { network_remember_reliable(p); - int rc = sendto(gSocket, p->buffer, p->cursor, 0, (SOCKADDR *)& txAddr, sizeof(txAddr)); + u32 hash = packet_hash(p); + memcpy(&p->buffer[p->dataLength], &hash, sizeof(u32)); + int rc = sendto(gSocket, p->buffer, p->cursor + sizeof(u32), 0, (SOCKADDR *)& txAddr, sizeof(txAddr)); if (rc == SOCKET_ERROR) { wprintf(L"%s sendto failed with error: %d\n", NETWORKTYPESTR, WSAGetLastError()); return; @@ -104,6 +106,11 @@ void network_update(void) { } if (rc == 0) { break; } + p.dataLength = rc - sizeof(u32); + if (!packet_check_hash(&p)) { + printf("Invalid packet!\n"); + } + switch (p.buffer[0]) { case PACKET_ACK: network_receive_ack(&p); break; case PACKET_PLAYER: network_receive_player(&p); break; diff --git a/src/pc/network/network.h b/src/pc/network/network.h index 173793fdb..3e8f93970 100644 --- a/src/pc/network/network.h +++ b/src/pc/network/network.h @@ -27,6 +27,7 @@ enum PacketType { }; struct Packet { + int dataLength; int cursor; bool error; bool reliable; @@ -67,6 +68,8 @@ void network_shutdown(void); void packet_init(struct Packet* packet, enum PacketType packetType, bool reliable); void packet_write(struct Packet* packet, void* data, int length); void packet_read(struct Packet* packet, void* data, int length); +u32 packet_hash(struct Packet* packet); +bool packet_check_hash(struct Packet* packet); // packet headers void network_send_ack(struct Packet* p); diff --git a/src/pc/network/packets/packet_read_write.c b/src/pc/network/packets/packet_read_write.c index 7033897ee..c5e716807 100644 --- a/src/pc/network/packets/packet_read_write.c +++ b/src/pc/network/packets/packet_read_write.c @@ -10,6 +10,7 @@ void packet_init(struct Packet* packet, enum PacketType packetType, bool reliabl nextSeqNum++; if (nextSeqNum == 0) { nextSeqNum++; } } + packet->dataLength = 3; packet->cursor = 3; packet->error = false; packet->reliable = reliable; @@ -19,6 +20,7 @@ void packet_init(struct Packet* packet, enum PacketType packetType, bool reliabl void packet_write(struct Packet* packet, void* data, int length) { if (data == NULL) { packet->error = true; return; } memcpy(&packet->buffer[packet->cursor], data, length); + packet->dataLength += length; packet->cursor += length; } @@ -28,3 +30,20 @@ void packet_read(struct Packet* packet, void* data, int length) { memcpy(data, &packet->buffer[cursor], length); packet->cursor = cursor + length; } + +u32 packet_hash(struct Packet* packet) { + u32 hash = 0; + int byte = 0; + for (int i = 0; i < packet->dataLength; i++) { + hash ^= ((u32)packet->buffer[i]) << (8 * byte); + byte = (byte + 1) % sizeof(u32); + } + return hash; +} + +bool packet_check_hash(struct Packet* packet) { + u32 localHash = packet_hash(packet); + u32 packetHash = 0; + memcpy(&packetHash, &packet->buffer[packet->dataLength], sizeof(u32)); + return localHash == packetHash; +}