Skip to content

Commit

Permalink
DEV9: Improve validation of received sequence numbers
Browse files Browse the repository at this point in the history
  • Loading branch information
TheLastRar authored and stenzek committed Apr 21, 2024
1 parent 6d8a906 commit 1e09409
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 23 deletions.
6 changes: 3 additions & 3 deletions pcsx2/DEV9/Sessions/TCP_Session/TCP_Session.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ namespace Sessions
enum struct NumCheckResult
{
OK,
GotOldData,
OldSeq,
Bad
};

Expand Down Expand Up @@ -98,10 +98,10 @@ namespace Sessions
void ResetMyNumbers();

NumCheckResult CheckRepeatSYNNumbers(PacketReader::IP::TCP::TCP_Packet* tcp);
NumCheckResult CheckNumbers(PacketReader::IP::TCP::TCP_Packet* tcp);
NumCheckResult CheckNumbers(PacketReader::IP::TCP::TCP_Packet* tcp, bool rejectOldSeq = false);
s32 GetDelta(u32 a, u32 b); //Returns a - b
//Returns true if errored
bool ErrorOnNonEmptyPacket(PacketReader::IP::TCP::TCP_Packet* tcp);
bool ValidateEmptyPacket(PacketReader::IP::TCP::TCP_Packet* tcp, bool ignoreOld = true);

//PS2 sent SYN
PacketReader::IP::TCP::TCP_Packet* ConnectTCPComplete(bool success);
Expand Down
44 changes: 24 additions & 20 deletions pcsx2/DEV9/Sessions/TCP_Session/TCP_Session_Out.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,11 +311,7 @@ namespace Sessions
//Check if we already have some of the data sent
const uint delta = GetDelta(expectedSeqNumber, tcp->sequenceNumber);
pxAssert(delta >= 0);
//if (Result == NumCheckResult::GotOldData)
//{
// DevCon.WriteLn("[PS2] New Data Offset: %d bytes", delta);
// DevCon.WriteLn("[PS2] New Data Length: %d bytes", ((uint)tcp->GetPayload()->GetLength() - delta));
//}

if (Result == NumCheckResult::Bad)
{
CloseByRemoteRST();
Expand All @@ -324,6 +320,11 @@ namespace Sessions
}
if (tcp->GetPayload()->GetLength() != 0)
{
//if (Result == NumCheckResult::OldSeq)
//{
// DevCon.WriteLn("[PS2] New Data Offset: %d bytes", delta);
// DevCon.WriteLn("[PS2] New Data Length: %d bytes", ((uint)tcp->GetPayload()->GetLength() - delta));
//}
if (tcp->GetPayload()->GetLength() - delta > 0)
{
DevCon.WriteLn("DEV9: TCP: [PS2] Sending: %d bytes", tcp->GetPayload()->GetLength());
Expand Down Expand Up @@ -396,7 +397,7 @@ namespace Sessions
}
}

ErrorOnNonEmptyPacket(tcp);
ValidateEmptyPacket(tcp);

return true;
}
Expand All @@ -414,7 +415,7 @@ namespace Sessions
return NumCheckResult::OK;
}

TCP_Session::NumCheckResult TCP_Session::CheckNumbers(TCP_Packet* tcp)
TCP_Session::NumCheckResult TCP_Session::CheckNumbers(TCP_Packet* tcp, bool rejectOldSeq)
{
u32 seqNum;
std::vector<u32> oldSeqNums;
Expand Down Expand Up @@ -446,17 +447,23 @@ namespace Sessions

if (tcp->sequenceNumber != expectedSeqNumber)
{
if (tcp->GetPayload()->GetLength() == 0)
if (rejectOldSeq)
{
Console.Error("DEV9: TCP: [PS2] Sent Unexpected Sequence Number, Got %u Expected %u", tcp->sequenceNumber, expectedSeqNumber);
return NumCheckResult::Bad;
}
else if (tcp->GetPayload()->GetLength() == 0)
{
Console.Error("DEV9: TCP: [PS2] Sent Unexpected Sequence Number From ACK Packet, Got %u Expected %u", tcp->sequenceNumber, expectedSeqNumber);
return NumCheckResult::OldSeq;
}
else
{
//Check if receivedPS2SeqNumbers contains tcp->sequenceNumber
if (std::find(receivedPS2SeqNumbers.begin(), receivedPS2SeqNumbers.end(), tcp->sequenceNumber) == receivedPS2SeqNumbers.end())
{
Console.Error("DEV9: TCP: [PS2] Sent an Old Seq Number on an Data packet, Got %u Expected %u", tcp->sequenceNumber, expectedSeqNumber);
return NumCheckResult::GotOldData;
return NumCheckResult::OldSeq;
}
else
{
Expand All @@ -468,13 +475,9 @@ namespace Sessions

return NumCheckResult::OK;
}
bool TCP_Session::ErrorOnNonEmptyPacket(TCP_Packet* tcp)
bool TCP_Session::ValidateEmptyPacket(TCP_Packet* tcp, bool ignoreOld)
{
NumCheckResult ResultFIN = CheckNumbers(tcp);
if (ResultFIN == NumCheckResult::GotOldData)
{
return false;
}
NumCheckResult ResultFIN = CheckNumbers(tcp, !ignoreOld);
if (ResultFIN == NumCheckResult::Bad)
{
CloseByRemoteRST();
Expand All @@ -484,7 +487,8 @@ namespace Sessions
if (tcp->GetPayload()->GetLength() > 0)
{
uint delta = GetDelta(expectedSeqNumber, tcp->sequenceNumber);
if (delta == 0)
//Check if packet contains only old data
if (delta >= tcp->GetPayload()->GetLength())
return false;

CloseByRemoteRST();
Expand All @@ -499,7 +503,7 @@ namespace Sessions
{
//Console.WriteLn("DEV9: TCP: PS2 has closed connection");

if (ErrorOnNonEmptyPacket(tcp)) //Sending FIN with data
if (ValidateEmptyPacket(tcp, false)) //Sending FIN with data
return true;

receivedPS2SeqNumbers.erase(receivedPS2SeqNumbers.begin());
Expand Down Expand Up @@ -532,7 +536,7 @@ namespace Sessions
//Close Part 4, Receive ACK from PS2
//Console.WriteLn("DEV9: TCP: Completed Close By PS2");

if (ErrorOnNonEmptyPacket(tcp))
if (ValidateEmptyPacket(tcp))
return true;

if (myNumberACKed.load())
Expand All @@ -551,7 +555,7 @@ namespace Sessions
{
//Console.WriteLn("DEV9: TCP: Completed Close By PS2");

if (ErrorOnNonEmptyPacket(tcp))
if (ValidateEmptyPacket(tcp))
return true;

if (myNumberACKed.load())
Expand All @@ -568,7 +572,7 @@ namespace Sessions
{
//Console.WriteLn("DEV9: TCP: PS2 has closed connection after remote");

if (ErrorOnNonEmptyPacket(tcp))
if (ValidateEmptyPacket(tcp, false)) //Sending FIN with data
return true;

receivedPS2SeqNumbers.erase(receivedPS2SeqNumbers.begin());
Expand Down

0 comments on commit 1e09409

Please sign in to comment.