Skip to content

Commit

Permalink
[WebConnection] Fix race condition between Close and BeginWrite (mono…
Browse files Browse the repository at this point in the history
…#4693)

* [WebConnection] Make ReadDone and InitRead instance methods to avoid passing cnc around

* [WebConnection] Inline only call to InitConnection

* [WebConnection] Fix race condition between Close and BeginWrite
  • Loading branch information
luhenry authored and martinpotter committed Jun 30, 2017
1 parent e820cbc commit fe4c92e
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 48 deletions.
79 changes: 35 additions & 44 deletions mcs/class/System/System.Net/WebConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ class WebConnection
WaitCallback initConn;
bool keepAlive;
byte [] buffer;
static AsyncCallback readDoneDelegate = new AsyncCallback (ReadDone);
EventHandler abortHandler;
AbortHelper abortHelper;
internal WebConnectionData Data;
Expand Down Expand Up @@ -100,11 +99,6 @@ public WebConnection (IWebConnectionState wcs, ServicePoint sPoint)
this.sPoint = sPoint;
buffer = new byte [4096];
Data = new WebConnectionData ();
initConn = new WaitCallback (state => {
try {
InitConnection (state);
} catch {}
});
queue = wcs.Group.Queue;
abortHelper = new AbortHelper ();
abortHelper.Connection = this;
Expand Down Expand Up @@ -460,13 +454,12 @@ void HandleError (WebExceptionStatus st, Exception e, string where)
}
}

static void ReadDone (IAsyncResult result)
void ReadDone (IAsyncResult result)
{
WebConnection cnc = (WebConnection)result.AsyncState;
WebConnectionData data = cnc.Data;
Stream ns = cnc.nstream;
WebConnectionData data = Data;
Stream ns = nstream;
if (ns == null) {
cnc.Close (true);
Close (true);
return;
}

Expand All @@ -479,84 +472,84 @@ static void ReadDone (IAsyncResult result)
if (e.InnerException is ObjectDisposedException)
return;

cnc.HandleError (WebExceptionStatus.ReceiveFailure, e, "ReadDone1");
HandleError (WebExceptionStatus.ReceiveFailure, e, "ReadDone1");
return;
}

if (nread == 0) {
cnc.HandleError (WebExceptionStatus.ReceiveFailure, null, "ReadDone2");
HandleError (WebExceptionStatus.ReceiveFailure, null, "ReadDone2");
return;
}

if (nread < 0) {
cnc.HandleError (WebExceptionStatus.ServerProtocolViolation, null, "ReadDone3");
HandleError (WebExceptionStatus.ServerProtocolViolation, null, "ReadDone3");
return;
}

int pos = -1;
nread += cnc.position;
nread += position;
if (data.ReadState == ReadState.None) {
Exception exc = null;
try {
pos = GetResponse (data, cnc.sPoint, cnc.buffer, nread);
pos = GetResponse (data, sPoint, buffer, nread);
} catch (Exception e) {
exc = e;
}

if (exc != null || pos == -1) {
cnc.HandleError (WebExceptionStatus.ServerProtocolViolation, exc, "ReadDone4");
HandleError (WebExceptionStatus.ServerProtocolViolation, exc, "ReadDone4");
return;
}
}

if (data.ReadState == ReadState.Aborted) {
cnc.HandleError (WebExceptionStatus.RequestCanceled, null, "ReadDone");
HandleError (WebExceptionStatus.RequestCanceled, null, "ReadDone");
return;
}

if (data.ReadState != ReadState.Content) {
int est = nread * 2;
int max = (est < cnc.buffer.Length) ? cnc.buffer.Length : est;
int max = (est < buffer.Length) ? buffer.Length : est;
byte [] newBuffer = new byte [max];
Buffer.BlockCopy (cnc.buffer, 0, newBuffer, 0, nread);
cnc.buffer = newBuffer;
cnc.position = nread;
Buffer.BlockCopy (buffer, 0, newBuffer, 0, nread);
buffer = newBuffer;
position = nread;
data.ReadState = ReadState.None;
InitRead (cnc);
InitRead ();
return;
}

cnc.position = 0;
position = 0;

WebConnectionStream stream = new WebConnectionStream (cnc, data);
WebConnectionStream stream = new WebConnectionStream (this, data);
bool expect_content = ExpectContent (data.StatusCode, data.request.Method);
string tencoding = null;
if (expect_content)
tencoding = data.Headers ["Transfer-Encoding"];

cnc.chunkedRead = (tencoding != null && tencoding.IndexOf ("chunked", StringComparison.OrdinalIgnoreCase) != -1);
if (!cnc.chunkedRead) {
stream.ReadBuffer = cnc.buffer;
chunkedRead = (tencoding != null && tencoding.IndexOf ("chunked", StringComparison.OrdinalIgnoreCase) != -1);
if (!chunkedRead) {
stream.ReadBuffer = buffer;
stream.ReadBufferOffset = pos;
stream.ReadBufferSize = nread;
try {
stream.CheckResponseInBuffer ();
} catch (Exception e) {
cnc.HandleError (WebExceptionStatus.ReceiveFailure, e, "ReadDone7");
HandleError (WebExceptionStatus.ReceiveFailure, e, "ReadDone7");
}
} else if (cnc.chunkStream == null) {
} else if (chunkStream == null) {
try {
cnc.chunkStream = new ChunkStream (cnc.buffer, pos, nread, data.Headers);
chunkStream = new ChunkStream (buffer, pos, nread, data.Headers);
} catch (Exception e) {
cnc.HandleError (WebExceptionStatus.ServerProtocolViolation, e, "ReadDone5");
HandleError (WebExceptionStatus.ServerProtocolViolation, e, "ReadDone5");
return;
}
} else {
cnc.chunkStream.ResetBuffer ();
chunkStream.ResetBuffer ();
try {
cnc.chunkStream.Write (cnc.buffer, pos, nread);
chunkStream.Write (buffer, pos, nread);
} catch (Exception e) {
cnc.HandleError (WebExceptionStatus.ServerProtocolViolation, e, "ReadDone6");
HandleError (WebExceptionStatus.ServerProtocolViolation, e, "ReadDone6");
return;
}
}
Expand All @@ -576,16 +569,15 @@ static bool ExpectContent (int statusCode, string method)
return (statusCode >= 200 && statusCode != 204 && statusCode != 304);
}

internal static void InitRead (object state)
internal void InitRead ()
{
WebConnection cnc = (WebConnection) state;
Stream ns = cnc.nstream;
Stream ns = nstream;

try {
int size = cnc.buffer.Length - cnc.position;
ns.BeginRead (cnc.buffer, cnc.position, size, readDoneDelegate, cnc);
int size = buffer.Length - position;
ns.BeginRead (buffer, position, size, ReadDone, null);
} catch (Exception e) {
cnc.HandleError (WebExceptionStatus.ReceiveFailure, e, "InitRead");
HandleError (WebExceptionStatus.ReceiveFailure, e, "InitRead");
}
}

Expand Down Expand Up @@ -709,9 +701,8 @@ static int GetResponse (WebConnectionData data, ServicePoint sPoint,
return -1;
}

void InitConnection (object state)
void InitConnection (HttpWebRequest request)
{
HttpWebRequest request = (HttpWebRequest) state;
request.WebConnection = this;
if (request.ReuseConnection)
request.StoredConnection = this;
Expand Down Expand Up @@ -773,7 +764,7 @@ internal EventHandler SendRequest (HttpWebRequest request)
lock (this) {
if (state.TrySetBusy ()) {
status = WebExceptionStatus.Success;
ThreadPool.QueueUserWorkItem (initConn, request);
ThreadPool.QueueUserWorkItem (o => { try { InitConnection ((HttpWebRequest) o); } catch {} }, request);
} else {
lock (queue) {
#if MONOTOUCH
Expand Down
8 changes: 4 additions & 4 deletions mcs/class/System/System.Net/WebConnectionStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ void WriteAsyncCB (IAsyncResult r)
result.SetCompleted (false, 0);
if (!initRead) {
initRead = true;
WebConnection.InitRead (cnc);
cnc.InitRead ();
}
} catch (Exception e) {
KillBuffer ();
Expand Down Expand Up @@ -666,7 +666,7 @@ bool SetHeadersAsync (SimpleAsyncResult result, bool setInternalLength)
cnc.EndWrite (request, true, r);
if (!initRead) {
initRead = true;
WebConnection.InitRead (cnc);
cnc.InitRead ();
}
var cl = request.ContentLength;
if (!sendChunked && cl == 0)
Expand Down Expand Up @@ -730,7 +730,7 @@ internal bool WriteRequestAsync (SimpleAsyncResult result)

if (!initRead) {
initRead = true;
WebConnection.InitRead (cnc);
cnc.InitRead ();
}

if (length == 0) {
Expand Down Expand Up @@ -800,7 +800,7 @@ public override void Close ()
complete_request_written = true;
if (!initRead) {
initRead = true;
WebConnection.InitRead (cnc);
cnc.InitRead ();
}
return;
}
Expand Down

0 comments on commit fe4c92e

Please sign in to comment.