diff --git a/src/waitress/task.py b/src/waitress/task.py index 956c0c0f..4f7f0344 100644 --- a/src/waitress/task.py +++ b/src/waitress/task.py @@ -180,6 +180,19 @@ def has_body(self): or self.status.startswith("304") ) + def set_close_on_finish(self) -> None: + # if headers have not been written yet, tell the remote + # client we are closing the connection + if not self.wrote_header: + connection_close_header = None + for (headername, headerval) in self.response_headers: + headername = "-".join([x.capitalize() for x in headername.split("-")]) + if headername == "Connection": + connection_close_header = headerval.lower() + if connection_close_header is None: + self.response_headers.append(("Connection", "close")) + self.close_on_finish = True + def build_response_header(self): version = self.version # Figure out whether the connection should be closed. @@ -188,7 +201,6 @@ def build_response_header(self): content_length_header = None date_header = None server_header = None - connection_close_header = None for (headername, headerval) in self.response_headers: headername = "-".join([x.capitalize() for x in headername.split("-")]) @@ -205,8 +217,6 @@ def build_response_header(self): if headername == "Server": server_header = headerval - if headername == "Connection": - connection_close_header = headerval.lower() # replace with properly capitalized version response_headers.append((headername, headerval)) @@ -218,23 +228,18 @@ def build_response_header(self): content_length_header = str(self.content_length) response_headers.append(("Content-Length", content_length_header)) - def close_on_finish(): - if connection_close_header is None: - response_headers.append(("Connection", "close")) - self.close_on_finish = True - if version == "1.0": if connection == "keep-alive": if not content_length_header: - close_on_finish() + self.set_close_on_finish() else: response_headers.append(("Connection", "Keep-Alive")) else: - close_on_finish() + self.set_close_on_finish() elif version == "1.1": if connection == "close": - close_on_finish() + self.set_close_on_finish() if not content_length_header: # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length @@ -244,8 +249,8 @@ def close_on_finish(): response_headers.append(("Transfer-Encoding", "chunked")) self.chunked_response = True - if not self.close_on_finish: - close_on_finish() + if not self.self.set_close_on_finish: + self.set_close_on_finish() # under HTTP 1.1 keep-alive is default, no need to set the header else: @@ -350,11 +355,7 @@ def execute(self): status, headers, body = e.to_response(ident) self.status = status self.response_headers.extend(headers) - # We need to explicitly tell the remote client we are closing the - # connection, because self.close_on_finish is set, and we are going to - # slam the door in the clients face. - self.response_headers.append(("Connection", "close")) - self.close_on_finish = True + self.set_close_on_finish() self.content_length = len(body) self.write(body) @@ -478,7 +479,7 @@ def start_response(status, headers, exc_info=None): # close the connection so the client isn't sitting around # waiting for more data when there are too few bytes # to service content-length - self.close_on_finish = True + self.set_close_on_finish() if self.request.command != "HEAD": self.logger.warning( "application returned too few bytes (%s) "