Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

streams: Prevent RST_STREAM from being sent multiple times #1267

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 61 additions & 18 deletions src/h2/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1486,16 +1486,17 @@ def _receive_frame(self, frame):
# I don't love using __class__ here, maybe reconsider it.
frames, events = self._frame_dispatch_table[frame.__class__](frame)
except StreamClosedError as e:
# If the stream was closed by RST_STREAM, we just send a RST_STREAM
# to the remote peer. Otherwise, this is a connection error, and so
# we will re-raise to trigger one.
if self._stream_is_closed_by_reset(e.stream_id):
if e._connection_error:
raise
else:
# A StreamClosedError is raised when a stream wants to send a
# RST_STREAM frame. Since the H2Stream is the authoritative source
# of its own state, we always respect its wishes here.

f = RstStreamFrame(e.stream_id)
f.error_code = e.error_code
self._prepare_for_sending([f])
events = e._events
else:
raise
except StreamIDTooLowError as e:
# The stream ID seems invalid. This may happen when the closed
# stream has been cleaned up, or when the remote peer has opened a
Expand All @@ -1506,10 +1507,18 @@ def _receive_frame(self, frame):
# is either a stream error or a connection error.
if self._stream_is_closed_by_reset(e.stream_id):
# Closed by RST_STREAM is a stream error.
f = RstStreamFrame(e.stream_id)
f.error_code = ErrorCodes.STREAM_CLOSED
self._prepare_for_sending([f])
events = []
if self._stream_is_closed_by_peer_reset(e.stream_id):
self._closed_streams[e.stream_id] = StreamClosedBy.SEND_RST_STREAM

f = RstStreamFrame(e.stream_id)
f.error_code = ErrorCodes.STREAM_CLOSED
self._prepare_for_sending([f])
events = []
else:
# Stream was closed by a local reset. A stream SHOULD NOT
# send additional RST_STREAM frames. Ignore.
events = []
pass
elif self._stream_is_closed_by_end(e.stream_id):
# Closed by END_STREAM is a connection error.
raise StreamClosedError(e.stream_id)
Expand Down Expand Up @@ -1655,13 +1664,32 @@ def _handle_data_on_closed_stream(self, events, exc, frame):
"auto-emitted a WINDOW_UPDATE by %d",
frame.stream_id, conn_increment
)
f = RstStreamFrame(exc.stream_id)
f.error_code = exc.error_code
frames.append(f)
self.config.logger.debug(
"Stream %d already CLOSED or cleaned up - "
"auto-emitted a RST_FRAME" % frame.stream_id
)

send_rst_frame = False

if frame.stream_id in self._closed_streams:
closed_by = self._closed_streams[frame.stream_id]

if closed_by == StreamClosedBy.RECV_RST_STREAM:
self._closed_streams[frame.stream_id] = StreamClosedBy.SEND_RST_STREAM
send_rst_frame = True
elif closed_by == StreamClosedBy.SEND_RST_STREAM:
# Do not send additional RST_STREAM frames
pass
else:
# Protocol error
raise StreamClosedError(frame.stream_id)
else:
send_rst_frame = True

if send_rst_frame:
f = RstStreamFrame(exc.stream_id)
f.error_code = exc.error_code
frames.append(f)
self.config.logger.debug(
"Stream %d already CLOSED or cleaned up - "
"auto-emitted a RST_FRAME" % frame.stream_id
)
return frames, events + exc._events

def _receive_data_frame(self, frame):
Expand All @@ -1677,6 +1705,8 @@ def _receive_data_frame(self, frame):
flow_controlled_length
)

stream = None

try:
stream = self._get_stream_by_id(frame.stream_id)
frames, stream_events = stream.receive_data(
Expand All @@ -1685,6 +1715,11 @@ def _receive_data_frame(self, frame):
flow_controlled_length
)
except StreamClosedError as e:
# If this exception originated from a yet-to-be clenaed up stream,
# check if it should be a connection error
if stream is not None and e._connection_error:
raise

# This stream is either marked as CLOSED or already gone from our
# internal state.
return self._handle_data_on_closed_stream(events, e, frame)
Expand Down Expand Up @@ -1962,7 +1997,7 @@ def _stream_closed_by(self, stream_id):
before opening this one.
"""
if stream_id in self.streams:
return self.streams[stream_id].closed_by
return self.streams[stream_id].closed_by # pragma: no cover
if stream_id in self._closed_streams:
return self._closed_streams[stream_id]
return None
Expand All @@ -1976,6 +2011,14 @@ def _stream_is_closed_by_reset(self, stream_id):
StreamClosedBy.RECV_RST_STREAM, StreamClosedBy.SEND_RST_STREAM
)

def _stream_is_closed_by_peer_reset(self, stream_id):
"""
Returns ``True`` if the stream was closed by sending or receiving a
RST_STREAM frame. Returns ``False`` otherwise.
"""
return (self._stream_closed_by(stream_id) ==
StreamClosedBy.RECV_RST_STREAM)

def _stream_is_closed_by_end(self, stream_id):
"""
Returns ``True`` if the stream was closed by sending or receiving an
Expand Down
8 changes: 7 additions & 1 deletion src/h2/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class StreamClosedError(NoSuchStreamError):
that the stream has since been closed, and that all state relating to that
stream has been removed.
"""
def __init__(self, stream_id):
def __init__(self, stream_id, connection_error=True):
#: The stream ID corresponds to the nonexistent stream.
self.stream_id = stream_id

Expand All @@ -115,6 +115,12 @@ def __init__(self, stream_id):
# external users that may receive a StreamClosedError.
self._events = []

# If this is a connection error or a stream error. This exception
# is used to send a `RST_STREAM` frame on stream errors. If
# connection_error is false, H2Connection will suppress this
# exception after sending the reset frame.
self._connection_error = connection_error


class InvalidSettingsValueError(ProtocolError, ValueError):
"""
Expand Down
80 changes: 53 additions & 27 deletions src/h2/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def reset_stream_on_error(self, previous_state):
"""
self.stream_closed_by = StreamClosedBy.SEND_RST_STREAM

error = StreamClosedError(self.stream_id)
error = StreamClosedError(self.stream_id, connection_error=False)

event = StreamReset()
event.stream_id = self.stream_id
Expand All @@ -334,8 +334,31 @@ def recv_on_closed_stream(self, previous_state):
a stream error or connection error with type STREAM_CLOSED, depending
on the specific frame. The error handling is done at a higher level:
this just raises the appropriate error.
"""
raise StreamClosedError(self.stream_id)

RFC:
Normally, an endpoint SHOULD NOT send more than one RST_STREAM
frame for any stream. However, an endpoint MAY send additional
RST_STREAM frames if it receives frames on a closed stream after
more than a round-trip time. This behavior is permitted to deal
with misbehaving implementations.

Implementation:
Raising StreamClosedError causes the RST_STREAM frame to be sent.
If the stream closed_by value is SEND_RST_STREAM, ignore this
instead of raising, such that only one RST_STREAM frame is sent.
There is currently now latency tracking, and as such measuring
round-trip time for allowed additional RST_STREAM frames which
MAY be sent cannot be implemented.
"""

if self.stream_closed_by == StreamClosedBy.RECV_RST_STREAM:
self.stream_closed_by = StreamClosedBy.SEND_RST_STREAM
raise StreamClosedError(self.stream_id, connection_error=False)
elif self.stream_closed_by in (StreamClosedBy.RECV_END_STREAM,
StreamClosedBy.SEND_END_STREAM):
raise StreamClosedError(self.stream_id)

return []

def send_on_closed_stream(self, previous_state):
"""
Expand Down Expand Up @@ -1040,23 +1063,24 @@ def receive_headers(self, headers, end_stream, header_encoding):

events = self.state_machine.process_input(input_)

if end_stream:
es_events = self.state_machine.process_input(
StreamInputs.RECV_END_STREAM
)
events[0].stream_ended = es_events[0]
events += es_events
if len(events) > 0:
if end_stream:
es_events = self.state_machine.process_input(
StreamInputs.RECV_END_STREAM
)
events[0].stream_ended = es_events[0]
events += es_events

self._initialize_content_length(headers)
self._initialize_content_length(headers)

if isinstance(events[0], TrailersReceived):
if not end_stream:
raise ProtocolError("Trailers must have END_STREAM set")
if isinstance(events[0], TrailersReceived):
if not end_stream:
raise ProtocolError("Trailers must have END_STREAM set")

hdr_validation_flags = self._build_hdr_validation_flags(events)
events[0].headers = self._process_received_headers(
headers, hdr_validation_flags, header_encoding
)
hdr_validation_flags = self._build_hdr_validation_flags(events)
events[0].headers = self._process_received_headers(
headers, hdr_validation_flags, header_encoding
)
return [], events

def receive_data(self, data, end_stream, flow_control_len):
Expand All @@ -1068,18 +1092,20 @@ def receive_data(self, data, end_stream, flow_control_len):
"set to %d", self, end_stream, flow_control_len
)
events = self.state_machine.process_input(StreamInputs.RECV_DATA)
self._inbound_window_manager.window_consumed(flow_control_len)
self._track_content_length(len(data), end_stream)

if end_stream:
es_events = self.state_machine.process_input(
StreamInputs.RECV_END_STREAM
)
events[0].stream_ended = es_events[0]
events.extend(es_events)
if len(events) > 0:
self._inbound_window_manager.window_consumed(flow_control_len)
self._track_content_length(len(data), end_stream)

if end_stream:
es_events = self.state_machine.process_input(
StreamInputs.RECV_END_STREAM
)
events[0].stream_ended = es_events[0]
events.extend(es_events)

events[0].data = data
events[0].flow_controlled_length = flow_control_len
events[0].data = data
events[0].flow_controlled_length = flow_control_len
return [], events

def receive_window_update(self, increment):
Expand Down
Loading