diff --git a/trio/_core/__init__.py b/trio/_core/__init__.py index 4b3a088d1b..22169b60a0 100644 --- a/trio/_core/__init__.py +++ b/trio/_core/__init__.py @@ -12,7 +12,8 @@ from ._multierror import MultiError from ._ki import ( - enable_ki_protection, disable_ki_protection, currently_ki_protected + enable_ki_protection, disable_ki_protection, mark_ki_unsafe_as_leaf, + ki_allowed_if_safe, ki_forbidden, currently_ki_protected ) # Imports that always exist diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py index fe63a6ee0c..05e82b5e14 100644 --- a/trio/_core/_generated_io_epoll.py +++ b/trio/_core/_generated_io_epoll.py @@ -2,26 +2,26 @@ # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._ki import enable_ki_protection - + +@enable_ki_protection async def wait_readable(fd): - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) except AttributeError: raise RuntimeError('must be called from async context') +@enable_ki_protection async def wait_writable(fd): - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) except AttributeError: raise RuntimeError('must be called from async context') +@enable_ki_protection def notify_closing(fd): - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) except AttributeError: diff --git a/trio/_core/_generated_io_kqueue.py b/trio/_core/_generated_io_kqueue.py index 059a8a95d1..94b436253a 100644 --- a/trio/_core/_generated_io_kqueue.py +++ b/trio/_core/_generated_io_kqueue.py @@ -2,47 +2,47 @@ # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._ki import enable_ki_protection - + +@enable_ki_protection def current_kqueue(): - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.current_kqueue() except AttributeError: raise RuntimeError('must be called from async context') +@enable_ki_protection def monitor_kevent(ident, filter): - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_kevent(ident, filter) except AttributeError: raise RuntimeError('must be called from async context') +@enable_ki_protection async def wait_kevent(ident, filter, abort_func): - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_kevent(ident, filter, abort_func) except AttributeError: raise RuntimeError('must be called from async context') +@enable_ki_protection async def wait_readable(fd): - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) except AttributeError: raise RuntimeError('must be called from async context') +@enable_ki_protection async def wait_writable(fd): - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) except AttributeError: raise RuntimeError('must be called from async context') +@enable_ki_protection def notify_closing(fd): - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) except AttributeError: diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py index 78dd30db19..7e9fcda94a 100644 --- a/trio/_core/_generated_io_windows.py +++ b/trio/_core/_generated_io_windows.py @@ -2,68 +2,68 @@ # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._ki import enable_ki_protection - + +@enable_ki_protection async def wait_readable(sock): - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(sock) except AttributeError: raise RuntimeError('must be called from async context') +@enable_ki_protection async def wait_writable(sock): - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(sock) except AttributeError: raise RuntimeError('must be called from async context') +@enable_ki_protection def notify_closing(handle): - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(handle) except AttributeError: raise RuntimeError('must be called from async context') +@enable_ki_protection def register_with_iocp(handle): - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.register_with_iocp(handle) except AttributeError: raise RuntimeError('must be called from async context') +@enable_ki_protection async def wait_overlapped(handle, lpOverlapped): - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_overlapped(handle, lpOverlapped) except AttributeError: raise RuntimeError('must be called from async context') +@enable_ki_protection async def write_overlapped(handle, data, file_offset=0): - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.write_overlapped(handle, data, file_offset) except AttributeError: raise RuntimeError('must be called from async context') +@enable_ki_protection async def readinto_overlapped(handle, buffer, file_offset=0): - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.readinto_overlapped(handle, buffer, file_offset) except AttributeError: raise RuntimeError('must be called from async context') +@enable_ki_protection def current_iocp(): - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.current_iocp() except AttributeError: raise RuntimeError('must be called from async context') +@enable_ki_protection def monitor_completion_key(): - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_completion_key() except AttributeError: diff --git a/trio/_core/_generated_run.py b/trio/_core/_generated_run.py index 75f61bfdc5..4ee695bc20 100644 --- a/trio/_core/_generated_run.py +++ b/trio/_core/_generated_run.py @@ -2,10 +2,11 @@ # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._ki import enable_ki_protection - + +@enable_ki_protection def current_statistics(): """Returns an object containing run-loop-level debugging information. @@ -29,12 +30,12 @@ def current_statistics(): other attributes vary between backends. """ - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.current_statistics() except AttributeError: raise RuntimeError('must be called from async context') +@enable_ki_protection def current_time(): """Returns the current time according to Trio's internal clock. @@ -45,34 +46,34 @@ def current_time(): RuntimeError: if not inside a call to :func:`trio.run`. """ - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.current_time() except AttributeError: raise RuntimeError('must be called from async context') +@enable_ki_protection def current_clock(): """Returns the current :class:`~trio.abc.Clock`. """ - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.current_clock() except AttributeError: raise RuntimeError('must be called from async context') +@enable_ki_protection def current_root_task(): """Returns the current root :class:`Task`. This is the task that is the ultimate parent of all other tasks. """ - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.current_root_task() except AttributeError: raise RuntimeError('must be called from async context') +@enable_ki_protection def reschedule(task, next_send=_NO_SEND): """Reschedule the given task with the given :class:`outcome.Outcome`. @@ -91,12 +92,12 @@ def reschedule(task, next_send=_NO_SEND): raise) from :func:`wait_task_rescheduled`. """ - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.reschedule(task, next_send) except AttributeError: raise RuntimeError('must be called from async context') +@enable_ki_protection def spawn_system_task(async_fn, *args, name=None): """Spawn a "system" task. @@ -136,23 +137,23 @@ def spawn_system_task(async_fn, *args, name=None): Task: the newly spawned task """ - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.spawn_system_task(async_fn, *args, name=name) except AttributeError: raise RuntimeError('must be called from async context') +@enable_ki_protection def current_trio_token(): """Retrieve the :class:`TrioToken` for the current call to :func:`trio.run`. """ - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.current_trio_token() except AttributeError: raise RuntimeError('must be called from async context') +@enable_ki_protection async def wait_all_tasks_blocked(cushion=0.0, tiebreaker=0): """Block until there are no runnable tasks. @@ -213,12 +214,12 @@ async def test_lock_fairness(): print("FAIL") """ - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.wait_all_tasks_blocked(cushion, tiebreaker) except AttributeError: raise RuntimeError('must be called from async context') +@enable_ki_protection def add_instrument(instrument): """Start instrumenting the current run loop with the given instrument. @@ -228,12 +229,12 @@ def add_instrument(instrument): If ``instrument`` is already active, does nothing. """ - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.add_instrument(instrument) except AttributeError: raise RuntimeError('must be called from async context') +@enable_ki_protection def remove_instrument(instrument): """Stop instrumenting the current run loop with the given instrument. @@ -247,7 +248,6 @@ def remove_instrument(instrument): deactivated. """ - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.remove_instrument(instrument) except AttributeError: diff --git a/trio/_core/_ki.py b/trio/_core/_ki.py index a3f64c5dca..70d033408d 100644 --- a/trio/_core/_ki.py +++ b/trio/_core/_ki.py @@ -1,10 +1,14 @@ -import inspect +import contextlib +import contextvars +import enum import signal import sys +import weakref from contextlib import contextmanager -from functools import wraps +from functools import partial import async_generator +import outcome from .._util import is_main_thread @@ -15,6 +19,8 @@ __all__ = [ "enable_ki_protection", "disable_ki_protection", + "ki_allowed_if_safe", + "ki_forbidden", "currently_ki_protected", ] @@ -66,9 +72,31 @@ # # Solution: # -# Mark *stack frames* as being interrupt-safe or interrupt-unsafe, and from -# the signal handler check which kind of frame we're currently in when -# deciding whether to raise or schedule the exception. +# Mark *stack frames* as being interrupt-safe or interrupt-unsafe, and +# from the signal handler check which kind of frame we're currently in +# when deciding whether to raise or schedule the exception. +# +# Stack frames don't have much associated metadata, so this "marking" +# is a bit tricky to implement. A given function definition is either +# interrupt-safe or not, based on how it's written and what invariants +# it needs to uphold, so we track the most basic interrupt-safety +# information in a dictionary keyed by code object. (The code object +# is accessible when traversing the stack, while the function object +# itself is not.) Once we know we're in an interrupt-safe context from +# the perspective of Trio's guts, we can use a context variable to +# implement the policy decision of which tasks want to be interrupted +# and which don't. (Using *only* a contextvar runs into the problems +# described in the last bullet above, at least on Python versions that +# don't have a C contextvars module.) +# +# (Historical note: Previously we used a fake local variable to track the +# protection state, i.e., an entry in f_locals with a uniqued name. This +# did great at allowing us to accurately answer "is it OK to raise +# KeyboardInterrupt right now?", but had some unfortunate side effects. +# In particular, it incurred runtime overhead at each protection +# boundary, introduced extra stack frames in tracebacks, and meant +# that a @enable_ki_protection'd async function would not look like an +# async function to tools like inspect.iscoroutinefunction().) # # There are still some cases where this can fail, like if someone hits # control-C while the process is in the event loop, and then it immediately @@ -79,25 +107,183 @@ # but in general the solution is to kill the process some other way, just like # for any Python program that's written to catch and ignore # KeyboardInterrupt.) +# +# Terminology: We say a function is "KI safe" (KI for +# KeyboardInterrupt) if it is OK to deliver arbitrary +# KeyboardInterrupts there, and "KI unsafe" if KeyboardInterrupts may +# only be delivered at Trio checkpoints. The user decision of whether +# they want KIs in a particular context is called "KI allowed" or "KI +# forbidden". A KeyboardInterrupt will be delivered directly from the +# signal handler if the context is both KI-allowed and KI-safe, and +# scheduled for later delivery if not. + + +class KISafetyNote(enum.Enum): + """Information about the interrupt-safety of a stack frame.""" + + # It is safe to deliver KIs at arbitrary points in this frame and its + # transitive callees, except for those callees that are marked unsafe. + SAFE = 1 + + # It is not safe to deliver KIs at arbitrary points in this frame or + # its transitive callees, except for those callees that are marked safe. + UNSAFE = 2 + + # It is not safe to deliver KIs at arbitrary points in this frame, + # but the determination doesn't affect callees (except those + # marked transparent): callees inherit the KI safety or unsafety + # of this frame's parent. (This is used to mark the internals of + # common helpers like Context.run(), @contextmanager, and so on + # as KI-unsafe, while still potentially allowing KI in the various + # user-provided functions that they might execute.) + UNSAFE_AS_LEAF = 3 + + # Pretend this frame doesn't exist at all when determining KI + # safety. For example, if the innermost frame has note + # TRANSPARENT, and the second-innermost is UNSAFE_IF_LEAF, then + # the context is considered KI-unsafe. + TRANSPARENT = 4 + + +# Maps a code object to a KI safety note about all frames that use it. +# Frames whose code objects are not in this dictionary inherit the KI +# safety state of their caller. +ki_safety_note = weakref.WeakKeyDictionary() + + +@async_generator.async_generator +async def example_oldstyle_asyncgen(): + pass -# We use this special string as a unique key into the frame locals dictionary. -# The @ ensures it is not a valid identifier and can't clash with any possible -# real local name. See: https://github.com/python-trio/trio/issues/469 -LOCALS_KEY_KI_PROTECTION_ENABLED = '@TRIO_KI_PROTECTION_ENABLED' + +def _ki_safety_decorator(note: KISafetyNote, name: str) -> "Callable[[F], F]": + def decorator(fn): + if hasattr(fn, "__wrapped__"): + # We don't want to add some common wrapper to our KI notes + # dictionary; if we did, we would be applying the KI + # safety change to every function that uses that wrapper, + # since the safety notes are per code object. + if fn.__code__ is example_oldstyle_asyncgen.__code__: + # Special support for @async_generator, which required that + # the KI safety decorator go on top under the previous + # approach. + decorator(fn.__wrapped__) + return fn + + raise RuntimeError( + f"@{decorator.__name__} must be at the bottom of the decorator " + f"stack (closest to the function definition) since it only " + f"applies to the function itself, not any wrappers added by " + f"other decorators." + ) + + if fn.__code__.co_name != fn.__name__: + raise RuntimeError( + f"{fn.__name__}'s code object is named {fn.__code__.co_name} " + f"which looks like it might be some common helper from a " + f"decorator or similar. Make sure @{decorator.__name__} is " + f"at the bottom of the decorator stack (closest to the " + f"function definition). Otherwise you're affecting the " + f"KeyboardInterrupt safety status of the common helper, " + f"not the function you wrote." + ) + + ki_safety_note[fn.__code__] = note + return fn + + decorator.__name__ = decorator.__qualname__ = name + return decorator + + +enable_ki_protection = _ki_safety_decorator( + KISafetyNote.UNSAFE, "enable_ki_protection" +) +disable_ki_protection = _ki_safety_decorator( + KISafetyNote.SAFE, "disable_ki_protection" +) +mark_ki_unsafe_as_leaf = _ki_safety_decorator( + KISafetyNote.UNSAFE_AS_LEAF, "mark_ki_unsafe_as_leaf" +) +mark_ki_safety_transparent = _ki_safety_decorator( + KISafetyNote.TRANSPARENT, "mark_ki_safety_transparent" +) + + +def setup_default_protections(): + # Protect a handful of functions we don't control. + if sys.version_info < (3, 7): + # The pure-Python contextvars backport -- we want Context.run() to be + # KI protected, but the function it calls to not be + mark_ki_unsafe_as_leaf(contextvars.Context.run) + enable_ki_protection(contextvars._get_context) + enable_ki_protection(contextvars._set_context) + + # Protect outcome capture, but not the function whose result is being + # captured + mark_ki_unsafe_as_leaf(outcome.capture) + mark_ki_unsafe_as_leaf(outcome.acapture) + + # Protect the machinery of @contextmanager and @asynccontextmanager that + # isn't part of the user-supplied function + @contextlib.contextmanager + def sync_cm(): + yield + + mark_ki_unsafe_as_leaf(type(sync_cm()).__enter__) + mark_ki_unsafe_as_leaf(type(sync_cm()).__exit__) + + async def agen(): + yield + + acm_makers = [async_generator.asynccontextmanager] + if sys.version_info >= (3, 7): + acm_makers.append(contextlib.asynccontextmanager) + for maker in acm_makers: + mark_ki_unsafe_as_leaf(type(maker(agen)()).__aenter__) + mark_ki_unsafe_as_leaf(type(maker(agen)()).__aexit__) + + +setup_default_protections() + +# This is a contextvar used to track whether the current task is allowed to be +# interrupted by KeyboardInterrupt at all. False is used for system tasks, +# which by default don't receive KI regardless of the KI protection status. +# True means defer to the KI protection status implied by the call stack. +# If the contextvar is not set, we assume we're not in any Trio task (e.g. +# we're in run() or the IO manager or something) and so do not allow KI. +ki_allowed_cvar = contextvars.ContextVar("trio_ki_allowed") # NB: according to the signal.signal docs, 'frame' can be None on entry to # this function: -def ki_protection_enabled(frame): +def is_ki_safe(frame): + leaf = True while frame is not None: - if LOCALS_KEY_KI_PROTECTION_ENABLED in frame.f_locals: - return frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] if frame.f_code.co_name == "__del__": - return True + return False + note = ki_safety_note.get(frame.f_code) + if note is not None: + if note == KISafetyNote.SAFE: + return True + elif note == KISafetyNote.UNSAFE: + return False + elif note == KISafetyNote.UNSAFE_AS_LEAF: + if leaf: + return False + elif note == KISafetyNote.TRANSPARENT: + frame = frame.f_back + continue + else: + raise RuntimeError( + f"Internal error: invalid KI protection note {note!r} " + f"for code object {frame.f_code!r}" + ) frame = frame.f_back - return False + leaf = False + return True +@mark_ki_safety_transparent def currently_ki_protected(): r"""Check whether the calling code has :exc:`KeyboardInterrupt` protection enabled. @@ -111,73 +297,20 @@ def currently_ki_protected(): bool: True if protection is enabled, and False otherwise. """ - return ki_protection_enabled(sys._getframe()) - - -def _ki_protection_decorator(enabled): - def decorator(fn): - # In some version of Python, isgeneratorfunction returns true for - # coroutine functions, so we have to check for coroutine functions - # first. - if inspect.iscoroutinefunction(fn): - - @wraps(fn) - def wrapper(*args, **kwargs): - # See the comment for regular generators below - coro = fn(*args, **kwargs) - coro.cr_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED - ] = enabled - return coro - - return wrapper - elif inspect.isgeneratorfunction(fn): - - @wraps(fn) - def wrapper(*args, **kwargs): - # It's important that we inject this directly into the - # generator's locals, as opposed to setting it here and then - # doing 'yield from'. The reason is, if a generator is - # throw()n into, then it may magically pop to the top of the - # stack. And @contextmanager generators in particular are a - # case where we often want KI protection, and which are often - # thrown into! See: - # https://bugs.python.org/issue29590 - gen = fn(*args, **kwargs) - gen.gi_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED - ] = enabled - return gen - - return wrapper - elif async_generator.isasyncgenfunction(fn): - - @wraps(fn) - def wrapper(*args, **kwargs): - # See the comment for regular generators above - agen = fn(*args, **kwargs) - agen.ag_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED - ] = enabled - return agen - - return wrapper - else: - - @wraps(fn) - def wrapper(*args, **kwargs): - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled - return fn(*args, **kwargs) + return not (is_ki_safe(sys._getframe()) and ki_allowed_cvar.get(True)) - return wrapper - - return decorator +@contextmanager +def ki_allowed_context(allowed): + token = ki_allowed_cvar.set(allowed) + try: + yield + finally: + ki_allowed_cvar.reset(token) -enable_ki_protection = _ki_protection_decorator(True) # type: Callable[[F], F] -enable_ki_protection.__name__ = "enable_ki_protection" -disable_ki_protection = _ki_protection_decorator( - False -) # type: Callable[[F], F] -disable_ki_protection.__name__ = "disable_ki_protection" +ki_allowed_if_safe = partial(ki_allowed_context, True) +ki_forbidden = partial(ki_allowed_context, False) @contextmanager @@ -191,15 +324,22 @@ def ki_manager(deliver_cb, restrict_keyboard_interrupt_to_checkpoints): def handler(signum, frame): assert signum == signal.SIGINT - protection_enabled = ki_protection_enabled(frame) - if protection_enabled or restrict_keyboard_interrupt_to_checkpoints: - deliver_cb() - else: + + ki_allowed = ( + ki_allowed_cvar.get(True) + and not restrict_keyboard_interrupt_to_checkpoints + ) + if is_ki_safe(frame) and ki_allowed: raise KeyboardInterrupt + else: + deliver_cb() signal.signal(signal.SIGINT, handler) try: - yield + # This ensures that KIs delivered outside of any task context will + # be deferred. That covers run(), run_impl(), handle_io(), etc. + with ki_forbidden(): + yield finally: if signal.getsignal(signal.SIGINT) is handler: signal.signal(signal.SIGINT, signal.default_int_handler) diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 5904e682fd..fd92404d23 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -1,3 +1,5 @@ +# coding: utf-8 + import functools import itertools import logging @@ -23,9 +25,7 @@ from ._entry_queue import EntryQueue, TrioToken from ._exceptions import (TrioInternalError, RunFinishedError, Cancelled) -from ._ki import ( - LOCALS_KEY_KI_PROTECTION_ENABLED, ki_manager, enable_ki_protection -) +from ._ki import ki_manager, enable_ki_protection, ki_allowed_cvar from ._multierror import MultiError from ._traps import ( Abort, @@ -1342,16 +1342,7 @@ def _return_value_looks_like_wrong_library(value): context = self.system_context.copy() else: context = copy_context() - - if not hasattr(coro, "cr_frame"): - # This async function is implemented in C or Cython - async def python_wrapper(orig_coro): - return await orig_coro - - coro = python_wrapper(coro) - coro.cr_frame.f_locals.setdefault( - LOCALS_KEY_KI_PROTECTION_ENABLED, system_task - ) + context.run(ki_allowed_cvar.set, True) task = Task._create( coro=coro, @@ -1438,9 +1429,9 @@ def spawn_system_task(self, async_fn, *args, name=None): * System tasks are automatically cancelled when the main task exits. - * By default, system tasks have :exc:`KeyboardInterrupt` protection - *enabled*. If you want your task to be interruptible by control-C, - then you need to use :func:`disable_ki_protection` explicitly (and + * By default, system tasks have :exc:`KeyboardInterrupt` forbidden. + If you want (part of) your task to be interruptible by control-C, + then you need to wrap it in :func:`ki_allowed_if_safe` (and come up with some plan for what to do with a :exc:`KeyboardInterrupt`, given that system tasks aren't allowed to raise exceptions). @@ -1756,6 +1747,7 @@ def run( io_manager = TheIOManager() system_context = copy_context() system_context.run(current_async_library_cvar.set, "trio") + system_context.run(ki_allowed_cvar.set, False) runner = Runner( clock=clock, instruments=instruments, @@ -1763,7 +1755,6 @@ def run( system_context=system_context, ) GLOBAL_RUN_CONTEXT.runner = runner - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True # KI handling goes outside the core try/except/finally to avoid a window # where KeyboardInterrupt would be allowed and converted into an diff --git a/trio/_threads.py b/trio/_threads.py index 811bc526a0..0e6540a637 100644 --- a/trio/_threads.py +++ b/trio/_threads.py @@ -1,3 +1,5 @@ +# coding: utf-8 + import threading import queue as stdlib_queue from itertools import count @@ -8,7 +10,10 @@ import trio from ._sync import CapacityLimiter -from ._core import enable_ki_protection, disable_ki_protection, RunVar, TrioToken +from ._core import ( + enable_ki_protection, disable_ki_protection, mark_ki_unsafe_as_leaf, + ki_allowed_if_safe, RunVar, TrioToken +) # Global due to Threading API, thread local storage for trio token TOKEN_LOCAL = threading.local() @@ -378,14 +383,17 @@ def from_thread_run(afn, *args, trio_token=None): to enter Trio. """ def callback(q, afn, args): - @disable_ki_protection - async def unprotected_afn(): - return await afn(*args) - - async def await_in_trio_thread_task(): - q.put_nowait(await outcome.acapture(unprotected_afn)) - - trio.lowlevel.spawn_system_task(await_in_trio_thread_task, name=afn) + @mark_ki_unsafe_as_leaf + async def await_in_trio_thread(): + with ki_allowed_if_safe(): + res = await outcome.acapture(afn, *args) + q.put_nowait(res) + + task = trio.lowlevel.spawn_system_task(await_in_trio_thread, name=afn) + # Normally, system tasks are permanently non-KIable. We want the user + # code to be KIable though, so we make the task interruptible + # and use enable/disable protection decorators to protect our own glue. + task.keyboard_interruptible = True return _run_fn_as_system_task(callback, afn, *args, trio_token=trio_token) @@ -422,12 +430,10 @@ def from_thread_run_sync(fn, *args, trio_token=None): "foreign" thread, spawned using some other framework, and still want to enter Trio. """ + @mark_ki_unsafe_as_leaf def callback(q, fn, args): - @disable_ki_protection - def unprotected_fn(): - return fn(*args) - - res = outcome.capture(unprotected_fn) + with ki_allowed_if_safe(): + res = outcome.capture(fn, *args) q.put_nowait(res) return _run_fn_as_system_task(callback, fn, *args, trio_token=trio_token) diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py index 1340d049a2..da31f4105c 100755 --- a/trio/_tools/gen_exports.py +++ b/trio/_tools/gen_exports.py @@ -19,13 +19,12 @@ # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._ki import enable_ki_protection + - """ -TEMPLATE = """locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True -try: +TEMPLATE = """try: return {} GLOBAL_RUN_CONTEXT.{}.{} except AttributeError: raise RuntimeError('must be called from async context') @@ -122,7 +121,7 @@ def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str: ) # Assemble function definition arguments and body - snippet = func + indent(template, ' ' * 4) + snippet = "@enable_ki_protection\n" + func + indent(template, ' ' * 4) # Append the snippet to the corresponding module generated.append(snippet) diff --git a/trio/lowlevel.py b/trio/lowlevel.py index 5fe32c03d9..f20b9790ca 100644 --- a/trio/lowlevel.py +++ b/trio/lowlevel.py @@ -14,8 +14,9 @@ # Generally available symbols from ._core import ( cancel_shielded_checkpoint, Abort, wait_task_rescheduled, - enable_ki_protection, disable_ki_protection, currently_ki_protected, Task, - checkpoint, current_task, ParkingLot, UnboundedQueue, RunVar, TrioToken, + enable_ki_protection, disable_ki_protection, currently_ki_protected, + mark_ki_unsafe_as_leaf, ki_allowed_if_safe, ki_forbidden, Task, checkpoint, + current_task, ParkingLot, UnboundedQueue, RunVar, TrioToken, current_trio_token, temporarily_detach_coroutine_object, permanently_detach_coroutine_object, reattach_detached_coroutine_object, current_statistics, reschedule, remove_instrument, add_instrument,