Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-96704: Add task.get_context(), use it in call_exception_handler() #96756

Merged
merged 7 commits into from Oct 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
@@ -1271,6 +1271,15 @@ Allows customizing how exceptions are handled in the event loop.
(see :meth:`call_exception_handler` documentation for details
about context).

If the handler is called on behalf of a :class:`~asyncio.Task` or
:class:`~asyncio.Handle`, it is run in the
:class:`contextvars.Context` of that task or callback handle.

.. versionchanged:: 3.12

The handler may be called in the :class:`~contextvars.Context`
of the task or handle where the exception originated.

.. method:: loop.get_exception_handler()

Return the current exception handler, or ``None`` if no custom
@@ -1474,6 +1483,13 @@ Callback Handles
A callback wrapper object returned by :meth:`loop.call_soon`,
:meth:`loop.call_soon_threadsafe`.

.. method:: get_context()

Return the :class:`contextvars.Context` object
associated with the handle.

.. versionadded:: 3.12

.. method:: cancel()

Cancel the callback. If the callback has already been canceled
@@ -1097,6 +1097,13 @@ Task Object

.. versionadded:: 3.8

.. method:: get_context()

Return the :class:`contextvars.Context` object
associated with the task.

.. versionadded:: 3.12

.. method:: get_name()

Return the name of the Task.
@@ -1808,7 +1808,22 @@ def call_exception_handler(self, context):
exc_info=True)
else:
try:
self._exception_handler(self, context)
ctx = None
thing = context.get("task")
if thing is None:
# Even though Futures don't have a context,
# Task is a subclass of Future,
# and sometimes the 'future' key holds a Task.
thing = context.get("future")
if thing is None:
# Handles also have a context.
thing = context.get("handle")
if thing is not None and hasattr(thing, "get_context"):
ctx = thing.get_context()
if ctx is not None and hasattr(ctx, "run"):
ctx.run(self._exception_handler, self, context)
else:
self._exception_handler(self, context)
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as exc:
@@ -61,6 +61,9 @@ def __repr__(self):
info = self._repr_info()
return '<{}>'.format(' '.join(info))

def get_context(self):
return self._context

def cancel(self):
if not self._cancelled:
self._cancelled = True
@@ -139,6 +139,9 @@ def __repr__(self):
def get_coro(self):
return self._coro

def get_context(self):
return self._context

def get_name(self):
return self._name

@@ -1,5 +1,6 @@
# IsolatedAsyncioTestCase based tests
import asyncio
import contextvars
import traceback
import unittest
from asyncio import tasks
@@ -27,6 +28,46 @@ async def raise_exc():
else:
self.fail('TypeError was not raised')

async def test_task_exc_handler_correct_context(self):
# see https://github.com/python/cpython/issues/96704
name = contextvars.ContextVar('name', default='foo')
exc_handler_called = False

def exc_handler(*args):
self.assertEqual(name.get(), 'bar')
nonlocal exc_handler_called
exc_handler_called = True

async def task():
name.set('bar')
1/0

loop = asyncio.get_running_loop()
loop.set_exception_handler(exc_handler)
self.cls(task())
await asyncio.sleep(0)
self.assertTrue(exc_handler_called)

async def test_handle_exc_handler_correct_context(self):
# see https://github.com/python/cpython/issues/96704
name = contextvars.ContextVar('name', default='foo')
exc_handler_called = False

def exc_handler(*args):
self.assertEqual(name.get(), 'bar')
nonlocal exc_handler_called
exc_handler_called = True

def callback():
name.set('bar')
1/0

loop = asyncio.get_running_loop()
loop.set_exception_handler(exc_handler)
loop.call_soon(callback)
await asyncio.sleep(0)
self.assertTrue(exc_handler_called)

@unittest.skipUnless(hasattr(tasks, '_CTask'),
'requires the C _asyncio module')
class CFutureTests(FutureTests, unittest.IsolatedAsyncioTestCase):
@@ -2482,6 +2482,17 @@ def test_get_coro(self):
finally:
loop.close()

def test_get_context(self):
loop = asyncio.new_event_loop()
coro = coroutine_function()
context = contextvars.copy_context()
try:
task = self.new_task(loop, coro, context=context)
loop.run_until_complete(task)
self.assertIs(task.get_context(), context)
finally:
loop.close()


def add_subclass_tests(cls):
BaseTask = cls.Task
@@ -0,0 +1 @@
Pass the correct ``contextvars.Context`` when a ``asyncio`` exception handler is called on behalf of a task or callback handle. This adds a new ``Task`` method, ``get_context``, and also a new ``Handle`` method with the same name. If this method is not found on a task object (perhaps because it is a third-party library that does not yet provide this method), the context prevailing at the time the exception handler is called is used.
@@ -2409,6 +2409,18 @@ _asyncio_Task_get_coro_impl(TaskObj *self)
return self->task_coro;
}

/*[clinic input]
_asyncio.Task.get_context
[clinic start generated code]*/

static PyObject *
_asyncio_Task_get_context_impl(TaskObj *self)
/*[clinic end generated code: output=6996f53d3dc01aef input=87c0b209b8fceeeb]*/
{
Py_INCREF(self->task_context);
return self->task_context;
}

/*[clinic input]
_asyncio.Task.get_name
[clinic start generated code]*/
@@ -2536,6 +2548,7 @@ static PyMethodDef TaskType_methods[] = {
_ASYNCIO_TASK_GET_NAME_METHODDEF
_ASYNCIO_TASK_SET_NAME_METHODDEF
_ASYNCIO_TASK_GET_CORO_METHODDEF
_ASYNCIO_TASK_GET_CONTEXT_METHODDEF
{"__class_getitem__", Py_GenericAlias, METH_O|METH_CLASS, PyDoc_STR("See PEP 585")},
{NULL, NULL} /* Sentinel */
};

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.