Skip to content

Commit

Permalink
gh-113538: Add asycio.Server.{close,abort}_clients (#114432)
Browse files Browse the repository at this point in the history
These give applications the option of more forcefully terminating client
connections for asyncio servers. Useful when terminating a service and
there is limited time to wait for clients to finish up their work.
  • Loading branch information
CendioOssman authored Mar 11, 2024
1 parent 872c071 commit 1d0d49a
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 20 deletions.
25 changes: 25 additions & 0 deletions Doc/library/asyncio-eventloop.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1641,6 +1641,31 @@ Do not instantiate the :class:`Server` class directly.
coroutine to wait until the server is closed (and no more
connections are active).

.. method:: close_clients()

Close all existing incoming client connections.

Calls :meth:`~asyncio.BaseTransport.close` on all associated
transports.

:meth:`close` should be called before :meth:`close_clients` when
closing the server to avoid races with new clients connecting.

.. versionadded:: 3.13

.. method:: abort_clients()

Close all existing incoming client connections immediately,
without waiting for pending operations to complete.

Calls :meth:`~asyncio.WriteTransport.abort` on all associated
transports.

:meth:`close` should be called before :meth:`abort_clients` when
closing the server to avoid races with new clients connecting.

.. versionadded:: 3.13

.. method:: get_loop()

Return the event loop associated with the server object.
Expand Down
5 changes: 5 additions & 0 deletions Doc/whatsnew/3.13.rst
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,11 @@ asyncio
the buffer size.
(Contributed by Jamie Phan in :gh:`115199`.)

* Add :meth:`asyncio.Server.close_clients` and
:meth:`asyncio.Server.abort_clients` methods which allow to more
forcefully close an asyncio server.
(Contributed by Pierre Ossman in :gh:`113538`.)

base64
---

Expand Down
25 changes: 17 additions & 8 deletions Lib/asyncio/base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,9 @@ def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog,
ssl_handshake_timeout, ssl_shutdown_timeout=None):
self._loop = loop
self._sockets = sockets
self._active_count = 0
# Weak references so we don't break Transport's ability to
# detect abandoned transports
self._clients = weakref.WeakSet()
self._waiters = []
self._protocol_factory = protocol_factory
self._backlog = backlog
Expand All @@ -292,14 +294,13 @@ def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog,
def __repr__(self):
return f'<{self.__class__.__name__} sockets={self.sockets!r}>'

def _attach(self):
def _attach(self, transport):
assert self._sockets is not None
self._active_count += 1
self._clients.add(transport)

def _detach(self):
assert self._active_count > 0
self._active_count -= 1
if self._active_count == 0 and self._sockets is None:
def _detach(self, transport):
self._clients.discard(transport)
if len(self._clients) == 0 and self._sockets is None:
self._wakeup()

def _wakeup(self):
Expand Down Expand Up @@ -348,9 +349,17 @@ def close(self):
self._serving_forever_fut.cancel()
self._serving_forever_fut = None

if self._active_count == 0:
if len(self._clients) == 0:
self._wakeup()

def close_clients(self):
for transport in self._clients.copy():
transport.close()

def abort_clients(self):
for transport in self._clients.copy():
transport.abort()

async def start_serving(self):
self._start_serving()
# Skip one loop iteration so that all 'loop.add_reader'
Expand Down
8 changes: 8 additions & 0 deletions Lib/asyncio/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,14 @@ def close(self):
"""Stop serving. This leaves existing connections open."""
raise NotImplementedError

def close_clients(self):
"""Close all active connections."""
raise NotImplementedError

def abort_clients(self):
"""Close all active connections immediately."""
raise NotImplementedError

def get_loop(self):
"""Get the event loop the Server object is attached to."""
raise NotImplementedError
Expand Down
4 changes: 2 additions & 2 deletions Lib/asyncio/proactor_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(self, loop, sock, protocol, waiter=None,
self._called_connection_lost = False
self._eof_written = False
if self._server is not None:
self._server._attach()
self._server._attach(self)
self._loop.call_soon(self._protocol.connection_made, self)
if waiter is not None:
# only wake up the waiter when connection_made() has been called
Expand Down Expand Up @@ -167,7 +167,7 @@ def _call_connection_lost(self, exc):
self._sock = None
server = self._server
if server is not None:
server._detach()
server._detach(self)
self._server = None
self._called_connection_lost = True

Expand Down
6 changes: 4 additions & 2 deletions Lib/asyncio/selector_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,7 @@ def __init__(self, loop, sock, protocol, extra=None, server=None):
self._paused = False # Set when pause_reading() called

if self._server is not None:
self._server._attach()
self._server._attach(self)
loop._transports[self._sock_fd] = self

def __repr__(self):
Expand Down Expand Up @@ -868,6 +868,8 @@ def __del__(self, _warn=warnings.warn):
if self._sock is not None:
_warn(f"unclosed transport {self!r}", ResourceWarning, source=self)
self._sock.close()
if self._server is not None:
self._server._detach(self)

def _fatal_error(self, exc, message='Fatal error on transport'):
# Should be called from exception handler only.
Expand Down Expand Up @@ -906,7 +908,7 @@ def _call_connection_lost(self, exc):
self._loop = None
server = self._server
if server is not None:
server._detach()
server._detach(self)
self._server = None

def get_write_buffer_size(self):
Expand Down
96 changes: 88 additions & 8 deletions Lib/test/test_asyncio/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,12 @@ async def main(srv):
class TestServer2(unittest.IsolatedAsyncioTestCase):

async def test_wait_closed_basic(self):
async def serve(*args):
pass
async def serve(rd, wr):
try:
await rd.read()
finally:
wr.close()
await wr.wait_closed()

srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0)
self.addCleanup(srv.close)
Expand All @@ -137,7 +141,8 @@ async def serve(*args):
self.assertFalse(task1.done())

# active count != 0, not closed: should block
srv._attach()
addr = srv.sockets[0].getsockname()
(rd, wr) = await asyncio.open_connection(addr[0], addr[1])
task2 = asyncio.create_task(srv.wait_closed())
await asyncio.sleep(0)
self.assertFalse(task1.done())
Expand All @@ -152,7 +157,8 @@ async def serve(*args):
self.assertFalse(task2.done())
self.assertFalse(task3.done())

srv._detach()
wr.close()
await wr.wait_closed()
# active count == 0, closed: should unblock
await task1
await task2
Expand All @@ -161,22 +167,96 @@ async def serve(*args):

async def test_wait_closed_race(self):
# Test a regression in 3.12.0, should be fixed in 3.12.1
async def serve(*args):
pass
async def serve(rd, wr):
try:
await rd.read()
finally:
wr.close()
await wr.wait_closed()

srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0)
self.addCleanup(srv.close)

task = asyncio.create_task(srv.wait_closed())
await asyncio.sleep(0)
self.assertFalse(task.done())
srv._attach()
addr = srv.sockets[0].getsockname()
(rd, wr) = await asyncio.open_connection(addr[0], addr[1])
loop = asyncio.get_running_loop()
loop.call_soon(srv.close)
loop.call_soon(srv._detach)
loop.call_soon(wr.close)
await srv.wait_closed()

async def test_close_clients(self):
async def serve(rd, wr):
try:
await rd.read()
finally:
wr.close()
await wr.wait_closed()

srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0)
self.addCleanup(srv.close)

addr = srv.sockets[0].getsockname()
(rd, wr) = await asyncio.open_connection(addr[0], addr[1])
self.addCleanup(wr.close)

task = asyncio.create_task(srv.wait_closed())
await asyncio.sleep(0)
self.assertFalse(task.done())

srv.close()
srv.close_clients()
await asyncio.sleep(0)
await asyncio.sleep(0)
self.assertTrue(task.done())

async def test_abort_clients(self):
async def serve(rd, wr):
nonlocal s_rd, s_wr
s_rd = rd
s_wr = wr
await wr.wait_closed()

s_rd = s_wr = None
srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0)
self.addCleanup(srv.close)

addr = srv.sockets[0].getsockname()
(c_rd, c_wr) = await asyncio.open_connection(addr[0], addr[1], limit=4096)
self.addCleanup(c_wr.close)

# Limit the socket buffers so we can reliably overfill them
s_sock = s_wr.get_extra_info('socket')
s_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 65536)
c_sock = c_wr.get_extra_info('socket')
c_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 65536)

# Get the reader in to a paused state by sending more than twice
# the configured limit
s_wr.write(b'a' * 4096)
s_wr.write(b'a' * 4096)
s_wr.write(b'a' * 4096)
while c_wr.transport.is_reading():
await asyncio.sleep(0)

# Get the writer in a waiting state by sending data until the
# socket buffers are full on both server and client sockets and
# the kernel stops accepting more data
s_wr.write(b'a' * c_sock.getsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF))
s_wr.write(b'a' * s_sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF))
self.assertNotEqual(s_wr.transport.get_write_buffer_size(), 0)

task = asyncio.create_task(srv.wait_closed())
await asyncio.sleep(0)
self.assertFalse(task.done())

srv.close()
srv.abort_clients()
await asyncio.sleep(0)
await asyncio.sleep(0)
self.assertTrue(task.done())


# Test the various corner cases of Unix server socket removal
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Add :meth:`asyncio.Server.close_clients` and
:meth:`asyncio.Server.abort_clients` methods which allow to more forcefully
close an asyncio server.

0 comments on commit 1d0d49a

Please sign in to comment.