diff --git a/ki/src/Ki/Internal/IO.hs b/ki/src/Ki/Internal/IO.hs index 719de51..97f4df7 100644 --- a/ki/src/Ki/Internal/IO.hs +++ b/ki/src/Ki/Internal/IO.hs @@ -10,7 +10,9 @@ module Ki.Internal.IO unexceptionalTryEither, -- * Exception utils - isAsyncException, + assertIO, + assertM, + exceptionIs, interruptiblyMasked, uninterruptiblyMasked, tryEitherSTM, @@ -29,6 +31,7 @@ import GHC.Base (maskAsyncExceptions#, maskUninterruptible#) import GHC.Conc (STM, ThreadId (ThreadId), catchSTM) import GHC.Exts (Int (I#), fork#, forkOn#) import GHC.IO (IO (IO)) +import System.IO.Unsafe (unsafePerformIO) import Prelude -- A little promise that this IO action cannot throw an exception (*including* async exceptions, which you normally @@ -68,9 +71,22 @@ unexceptionalTryEither onFailure onSuccess action = (coerce @_ @(a -> IO b) onSuccess <$> action) (pure . coerce @_ @(SomeException -> IO b) onFailure) -isAsyncException :: SomeException -> Bool -isAsyncException = - isJust . fromException @SomeAsyncException +-- | Make an assertion in a IO that requires IO. +assertIO :: IO Bool -> IO () +assertIO b = + assert (unsafePerformIO b) (pure ()) +{-# INLINE assertIO #-} + +-- | Make an assertion in a monad. +assertM :: (Applicative m) => Bool -> m () +assertM b = + assert b (pure ()) +{-# INLINE assertM #-} + +-- | @exceptionIs \@e exception@ returns whether @exception@ is an instance of @e@. +exceptionIs :: forall e. (Exception e) => SomeException -> Bool +exceptionIs = + isJust . fromException @e -- | Call an action with asynchronous exceptions interruptibly masked. interruptiblyMasked :: forall a. IO a -> IO a diff --git a/ki/src/Ki/Internal/Propagating.hs b/ki/src/Ki/Internal/Propagating.hs index 1b9a113..1c0067a 100644 --- a/ki/src/Ki/Internal/Propagating.hs +++ b/ki/src/Ki/Internal/Propagating.hs @@ -9,17 +9,19 @@ where import Control.Concurrent (ThreadId) import Control.Exception (Exception (..), SomeException, asyncExceptionFromException, asyncExceptionToException, throwTo) --- Internal exception type thrown by a child thread to its parent, if it fails unexpectedly. +-- Internal exception type thrown by a child thread to its parent, if the child fails unexpectedly. data Propagating = Propagating { childId :: {-# UNPACK #-} !Tid, exception :: !SomeException } - deriving stock (Show) instance Exception Propagating where toException = asyncExceptionToException fromException = asyncExceptionFromException +instance Show Propagating where + show _ = "<>" + pattern PropagatingFrom :: Tid -> SomeException pattern PropagatingFrom childId <- (fromException -> Just Propagating {childId}) diff --git a/ki/src/Ki/Internal/Scope.hs b/ki/src/Ki/Internal/Scope.hs index bce7b7f..c3d1908 100644 --- a/ki/src/Ki/Internal/Scope.hs +++ b/ki/src/Ki/Internal/Scope.hs @@ -16,8 +16,8 @@ import Control.Concurrent.MVar (MVar, newEmptyMVar, tryPutMVar, tryTakeMVar) import Control.Exception ( Exception (fromException, toException), MaskingState (..), + SomeAsyncException, SomeException, - assert, asyncExceptionFromException, asyncExceptionToException, throwIO, @@ -30,7 +30,6 @@ import Data.Foldable (for_) import Data.Functor (void) import Data.IntMap (IntMap) import qualified Data.IntMap.Lazy as IntMap.Lazy -import Data.Maybe (isJust) import Data.Void (Void, absurd) import GHC.Conc ( STM, @@ -53,8 +52,9 @@ import Ki.Internal.ByteCount (byteCountToInt64) import Ki.Internal.IO ( IOResult (..), UnexceptionalIO (..), + assertM, + exceptionIs, interruptiblyMasked, - isAsyncException, unexceptionalTry, unexceptionalTryEither, uninterruptiblyMasked, @@ -97,7 +97,7 @@ data Scope = Scope } -- The scope status: either open (allowing new threads to be created), closing (disallowing new threads to be --- created, and in the process of killing living children), or closed (at the very end of `scoped`) +-- created, and in the process of killing running children), or closed (at the very end of `scoped`) type ScopeStatus = Int -- The number of child threads that are guaranteed to be about to start, in the sense that only the GHC scheduler @@ -117,26 +117,20 @@ pattern Closed = -2 {-# COMPLETE Open, Closing, Closed #-} -- Internal async exception thrown by a parent thread to its children when the scope is closing. +-- +-- In various places we trust without verifying that any 'ScopeClosing' exception, which is not exported by this module, +-- was indeed thrown to a thread by its parent. It is possible to write a program that violates this (just catch the +-- async exception and throw it to some other thread)... but who would do that? data ScopeClosing = ScopeClosing instance Show ScopeClosing where - show _ = "ScopeClosing" + show _ = "<>" instance Exception ScopeClosing where toException = asyncExceptionToException fromException = asyncExceptionFromException --- Trust without verifying that any 'ScopeClosed' exception, which is not exported by this module, was indeed thrown to --- a thread by its parent. It is possible to write a program that violates this (just catch the async exception and --- throw it to some other thread)... but who would do that? -isScopeClosingException :: SomeException -> Bool -isScopeClosingException exception = - isJust (fromException @ScopeClosing exception) - -pattern IsScopeClosingException :: SomeException -pattern IsScopeClosingException <- (isScopeClosingException -> True) - -- | Open a scope, perform an IO action with it, then close the scope. -- -- ==== __👉 Details__ @@ -156,13 +150,14 @@ scoped action = do uninterruptibleMask \restore -> do result <- try (restore (action scope)) - !livingChildren <- do - livingChildren0 <- + !runningChildren <- do + runningChildren <- atomically do -- Block until we haven't committed to starting any threads. Without this, we may create a thread concurrently -- with closing its scope, and not grab its thread id to throw an exception to. starting <- readTVar statusVar - assert (starting >= 0) (guard (starting == 0)) + assertM (starting >= 0) + guard (starting == 0) -- Indicate that this scope is closing, so attempts to create a new thread within it will throw ScopeClosing -- (as if the calling thread was a parent of this scope, which it should be, and we threw it a ScopeClosing -- ourselves). @@ -174,14 +169,15 @@ scoped action = do -- If one of our children propagated an exception to us, then we know it's about to terminate, so we don't bother -- throwing an exception to it. pure case result of - Left (PropagatingFrom childId) -> IntMap.Lazy.delete childId livingChildren0 - _ -> livingChildren0 + Left (PropagatingFrom childId) -> IntMap.Lazy.delete childId runningChildren + _ -> runningChildren - -- Deliver a ScopeClosing exception to every living child. + -- Deliver a ScopeClosing exception to every running child. -- - -- This happens to throw in the order the children were created... but I think we decided this feature isn't very - -- useful in practice, so maybe we should simplify the internals and just keep a set of children? - for_ (IntMap.Lazy.elems livingChildren) \livingChild -> throwTo livingChild ScopeClosing + -- This happens to throw in the order the children were created, but that isn't an important/useful enough feature + -- to be worth documenting, so users shouldn't rely on it. It's definitely not the case that child 1 will completely + -- terminate before child 2 is delivered an exception: each child may delay arbitrarily while cleaning up. + for_ runningChildren \child -> throwTo child ScopeClosing atomically do -- Block until all children have terminated; this relies on children respecting the async exception, which they @@ -189,6 +185,7 @@ scoped action = do -- definitely not structured concurrency! children <- readTVar childrenVar guard (IntMap.Lazy.null children) + -- Record the scope as closed (from closing), so subsequent attempts to use it will throw a runtime exception writeTVar statusVar Closed @@ -233,19 +230,19 @@ spawn scope@Scope {childrenVar, statusVar} options action = do -- Record the thread as being about to start. Not allowed to retry. nonblockingAtomically do status <- nonblockingReadTVar statusVar - assert (status >= -2) do - case status of - Open -> nonblockingWriteTVar' statusVar (status + 1) - Closing -> nonblockingThrowSTM ScopeClosing - Closed -> nonblockingThrowSTM (ErrorCall "ki: scope closed") + assertM (status >= -2) + case status of + Open -> nonblockingWriteTVar' statusVar (status + 1) + Closing -> nonblockingThrowSTM ScopeClosing + Closed -> nonblockingThrowSTM (ErrorCall "ki: scope closed") childIds <- spawnChild scope options action -- Record the child as having started. Not allowed to retry. nonblockingAtomically do starting <- nonblockingReadTVar statusVar - assert (starting >= 1) do - nonblockingWriteTVar' statusVar (starting - 1) + assertM (starting >= 1) + nonblockingWriteTVar' statusVar (starting - 1) recordChild childrenVar childIds pure childIds @@ -309,7 +306,8 @@ awaitAll Scope {childrenVar, statusVar} = do children <- readTVar childrenVar guard (IntMap.Lazy.null children) status <- readTVar statusVar - assert (status >= -2) case status of + assertM (status >= -2) + case status of Open -> guard (status == 0) Closing -> retry -- block until closed Closed -> pure () @@ -336,7 +334,7 @@ forkWith scope opts action = do spawn scope opts \childId masking -> do unexceptionalTry (masking action) >>= \case Failure exception -> do - when (not (isScopeClosingException exception)) do + when (not (exceptionIs @ScopeClosing exception)) do propagateException scope childId exception -- even put async exceptions that we propagated. this isn't totally ideal because a caller awaiting this -- thread would not be able to distinguish between async exceptions delivered to this thread, or itself @@ -355,7 +353,10 @@ forkWith_ scope opts action = do _childThreadId <- spawn scope opts \childId masking -> unexceptionalTryEither - (\exception -> when (not (isScopeClosingException exception)) (propagateException scope childId exception)) + ( \exception -> + when (not (exceptionIs @ScopeClosing exception)) do + propagateException scope childId exception + ) absurd (masking action) pure () @@ -383,13 +384,12 @@ forkTryWith scope opts action = do result <- unexceptionalTry (masking action) case result of Failure exception -> do + -- then-branch explanation: if the user calls `forkTry @MyAsyncException` for some reason, we want to ignore + -- this request and propagate the async exception. `forkTry` can only be used to catch synchronous exceptions. let shouldPropagate = - if isScopeClosingException exception - then False - else case fromException @e exception of - Nothing -> True - -- if the user calls `forkTry @MyAsyncException`, we still want to propagate the async exception - Just _ -> isAsyncException exception + if exceptionIs @e exception + then exceptionIs @SomeAsyncException exception + else not (exceptionIs @ScopeClosing exception) when shouldPropagate (propagateException scope childId exception) done (BadResult exception) Success value -> done (GoodResult value) @@ -437,13 +437,16 @@ propagateException :: Scope -> Tid -> SomeException -> UnexceptionalIO () propagateException Scope {childExceptionVar, parentThreadId, statusVar} childId exception = UnexceptionalIO (readTVarIO statusVar) >>= \case Closing -> tryPutChildExceptionVar -- (A) or (B), we don't care which - status -> assert (status >= 0) loop -- we know status is Open here + status -> do + assertM (status >= 0) -- we know status is Open (0+) here; can't be Closed (-2) + loop where loop :: UnexceptionalIO () loop = unexceptionalTry (propagate exception childId parentThreadId) >>= \case - Failure IsScopeClosingException -> tryPutChildExceptionVar -- (C) - Failure _ -> loop -- (D) + Failure secondException + | exceptionIs @ScopeClosing secondException -> tryPutChildExceptionVar -- (C) + | otherwise -> loop -- (D) Success _ -> pure () tryPutChildExceptionVar :: UnexceptionalIO () diff --git a/ki/test/Tests.hs b/ki/test/Tests.hs index 690c1eb..348b00f 100644 --- a/ki/test/Tests.hs +++ b/ki/test/Tests.hs @@ -20,16 +20,16 @@ tests = scope <- Ki.scoped pure (atomically . Ki.await =<< Ki.fork scope (pure ())) `shouldThrow` ErrorCall "ki: scope closed" pure (), - testCase "`fork` throws ScopeClosing when the scope is closing" do + testCase "`fork` throws ScopeClosing to children when the scope is closing" do Ki.scoped \scope -> do _ <- Ki.forkWith scope Ki.defaultThreadOptions {Ki.maskingState = MaskedInterruptible} do -- Naughty: catch and ignore the ScopeClosing delivered to us result1 <- try @SomeException (threadDelay maxBound) - show result1 `shouldBe` "Left ScopeClosing" + show result1 `shouldBe` "Left <>" -- Try forking a new thread in the closing scope, and assert that (synchronously) throws ScopeClosing result2 <- try @SomeException (Ki.fork_ scope undefined) - show result2 `shouldBe` "Left ScopeClosing" + show result2 `shouldBe` "Left <>" pure (), testCase "`awaitAll` succeeds when no threads are alive" do Ki.scoped (atomically . Ki.awaitAll),