Skip to content

Commit

Permalink
more refactoring, cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
mitchellwrosen committed Nov 28, 2023
1 parent f5e9a45 commit cebd275
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 52 deletions.
24 changes: 20 additions & 4 deletions ki/src/Ki/Internal/IO.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ module Ki.Internal.IO
unexceptionalTryEither,

-- * Exception utils
isAsyncException,
assertIO,
assertM,
exceptionIs,
interruptiblyMasked,
uninterruptiblyMasked,
tryEitherSTM,
Expand All @@ -29,6 +31,7 @@ import GHC.Base (maskAsyncExceptions#, maskUninterruptible#)
import GHC.Conc (STM, ThreadId (ThreadId), catchSTM)
import GHC.Exts (Int (I#), fork#, forkOn#)
import GHC.IO (IO (IO))
import System.IO.Unsafe (unsafePerformIO)
import Prelude

-- A little promise that this IO action cannot throw an exception (*including* async exceptions, which you normally
Expand Down Expand Up @@ -68,9 +71,22 @@ unexceptionalTryEither onFailure onSuccess action =
(coerce @_ @(a -> IO b) onSuccess <$> action)
(pure . coerce @_ @(SomeException -> IO b) onFailure)

isAsyncException :: SomeException -> Bool
isAsyncException =
isJust . fromException @SomeAsyncException
-- | Make an assertion in a IO that requires IO.
assertIO :: IO Bool -> IO ()
assertIO b =
assert (unsafePerformIO b) (pure ())
{-# INLINE assertIO #-}

-- | Make an assertion in a monad.
assertM :: (Applicative m) => Bool -> m ()
assertM b =
assert b (pure ())
{-# INLINE assertM #-}

-- | @exceptionIs \@e exception@ returns whether @exception@ is an instance of @e@.
exceptionIs :: forall e. (Exception e) => SomeException -> Bool
exceptionIs =
isJust . fromException @e

-- | Call an action with asynchronous exceptions interruptibly masked.
interruptiblyMasked :: forall a. IO a -> IO a
Expand Down
6 changes: 4 additions & 2 deletions ki/src/Ki/Internal/Propagating.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,19 @@ where
import Control.Concurrent (ThreadId)
import Control.Exception (Exception (..), SomeException, asyncExceptionFromException, asyncExceptionToException, throwTo)

-- Internal exception type thrown by a child thread to its parent, if it fails unexpectedly.
-- Internal exception type thrown by a child thread to its parent, if the child fails unexpectedly.
data Propagating = Propagating
{ childId :: {-# UNPACK #-} !Tid,
exception :: !SomeException
}
deriving stock (Show)

instance Exception Propagating where
toException = asyncExceptionToException
fromException = asyncExceptionFromException

instance Show Propagating where
show _ = "<<internal ki exception: propagating>>"

pattern PropagatingFrom :: Tid -> SomeException
pattern PropagatingFrom childId <- (fromException -> Just Propagating {childId})

Expand Down
89 changes: 46 additions & 43 deletions ki/src/Ki/Internal/Scope.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ import Control.Concurrent.MVar (MVar, newEmptyMVar, tryPutMVar, tryTakeMVar)
import Control.Exception
( Exception (fromException, toException),
MaskingState (..),
SomeAsyncException,
SomeException,
assert,
asyncExceptionFromException,
asyncExceptionToException,
throwIO,
Expand All @@ -30,7 +30,6 @@ import Data.Foldable (for_)
import Data.Functor (void)
import Data.IntMap (IntMap)
import qualified Data.IntMap.Lazy as IntMap.Lazy
import Data.Maybe (isJust)
import Data.Void (Void, absurd)
import GHC.Conc
( STM,
Expand All @@ -53,8 +52,9 @@ import Ki.Internal.ByteCount (byteCountToInt64)
import Ki.Internal.IO
( IOResult (..),
UnexceptionalIO (..),
assertM,
exceptionIs,
interruptiblyMasked,
isAsyncException,
unexceptionalTry,
unexceptionalTryEither,
uninterruptiblyMasked,
Expand Down Expand Up @@ -97,7 +97,7 @@ data Scope = Scope
}

-- The scope status: either open (allowing new threads to be created), closing (disallowing new threads to be
-- created, and in the process of killing living children), or closed (at the very end of `scoped`)
-- created, and in the process of killing running children), or closed (at the very end of `scoped`)
type ScopeStatus = Int

-- The number of child threads that are guaranteed to be about to start, in the sense that only the GHC scheduler
Expand All @@ -117,26 +117,20 @@ pattern Closed = -2
{-# COMPLETE Open, Closing, Closed #-}

-- Internal async exception thrown by a parent thread to its children when the scope is closing.
--
-- In various places we trust without verifying that any 'ScopeClosing' exception, which is not exported by this module,
-- was indeed thrown to a thread by its parent. It is possible to write a program that violates this (just catch the
-- async exception and throw it to some other thread)... but who would do that?
data ScopeClosing
= ScopeClosing

instance Show ScopeClosing where
show _ = "ScopeClosing"
show _ = "<<internal ki exception: scope closing>>"

instance Exception ScopeClosing where
toException = asyncExceptionToException
fromException = asyncExceptionFromException

-- Trust without verifying that any 'ScopeClosed' exception, which is not exported by this module, was indeed thrown to
-- a thread by its parent. It is possible to write a program that violates this (just catch the async exception and
-- throw it to some other thread)... but who would do that?
isScopeClosingException :: SomeException -> Bool
isScopeClosingException exception =
isJust (fromException @ScopeClosing exception)

pattern IsScopeClosingException :: SomeException
pattern IsScopeClosingException <- (isScopeClosingException -> True)

-- | Open a scope, perform an IO action with it, then close the scope.
--
-- ==== __👉 Details__
Expand All @@ -156,13 +150,14 @@ scoped action = do
uninterruptibleMask \restore -> do
result <- try (restore (action scope))

!livingChildren <- do
livingChildren0 <-
!runningChildren <- do
runningChildren <-
atomically do
-- Block until we haven't committed to starting any threads. Without this, we may create a thread concurrently
-- with closing its scope, and not grab its thread id to throw an exception to.
starting <- readTVar statusVar
assert (starting >= 0) (guard (starting == 0))
assertM (starting >= 0)
guard (starting == 0)
-- Indicate that this scope is closing, so attempts to create a new thread within it will throw ScopeClosing
-- (as if the calling thread was a parent of this scope, which it should be, and we threw it a ScopeClosing
-- ourselves).
Expand All @@ -174,21 +169,23 @@ scoped action = do
-- If one of our children propagated an exception to us, then we know it's about to terminate, so we don't bother
-- throwing an exception to it.
pure case result of
Left (PropagatingFrom childId) -> IntMap.Lazy.delete childId livingChildren0
_ -> livingChildren0
Left (PropagatingFrom childId) -> IntMap.Lazy.delete childId runningChildren
_ -> runningChildren

-- Deliver a ScopeClosing exception to every living child.
-- Deliver a ScopeClosing exception to every running child.
--
-- This happens to throw in the order the children were created... but I think we decided this feature isn't very
-- useful in practice, so maybe we should simplify the internals and just keep a set of children?
for_ (IntMap.Lazy.elems livingChildren) \livingChild -> throwTo livingChild ScopeClosing
-- This happens to throw in the order the children were created, but that isn't an important/useful enough feature
-- to be worth documenting, so users shouldn't rely on it. It's definitely not the case that child 1 will completely
-- terminate before child 2 is delivered an exception: each child may delay arbitrarily while cleaning up.
for_ runningChildren \child -> throwTo child ScopeClosing

atomically do
-- Block until all children have terminated; this relies on children respecting the async exception, which they
-- must, for correctness. Otherwise, a thread could indeed outlive the scope in which it's created, which is
-- definitely not structured concurrency!
children <- readTVar childrenVar
guard (IntMap.Lazy.null children)

-- Record the scope as closed (from closing), so subsequent attempts to use it will throw a runtime exception
writeTVar statusVar Closed

Expand Down Expand Up @@ -233,19 +230,19 @@ spawn scope@Scope {childrenVar, statusVar} options action = do
-- Record the thread as being about to start. Not allowed to retry.
nonblockingAtomically do
status <- nonblockingReadTVar statusVar
assert (status >= -2) do
case status of
Open -> nonblockingWriteTVar' statusVar (status + 1)
Closing -> nonblockingThrowSTM ScopeClosing
Closed -> nonblockingThrowSTM (ErrorCall "ki: scope closed")
assertM (status >= -2)
case status of
Open -> nonblockingWriteTVar' statusVar (status + 1)
Closing -> nonblockingThrowSTM ScopeClosing
Closed -> nonblockingThrowSTM (ErrorCall "ki: scope closed")

childIds <- spawnChild scope options action

-- Record the child as having started. Not allowed to retry.
nonblockingAtomically do
starting <- nonblockingReadTVar statusVar
assert (starting >= 1) do
nonblockingWriteTVar' statusVar (starting - 1)
assertM (starting >= 1)
nonblockingWriteTVar' statusVar (starting - 1)
recordChild childrenVar childIds

pure childIds
Expand Down Expand Up @@ -309,7 +306,8 @@ awaitAll Scope {childrenVar, statusVar} = do
children <- readTVar childrenVar
guard (IntMap.Lazy.null children)
status <- readTVar statusVar
assert (status >= -2) case status of
assertM (status >= -2)
case status of
Open -> guard (status == 0)
Closing -> retry -- block until closed
Closed -> pure ()
Expand All @@ -336,7 +334,7 @@ forkWith scope opts action = do
spawn scope opts \childId masking -> do
unexceptionalTry (masking action) >>= \case
Failure exception -> do
when (not (isScopeClosingException exception)) do
when (not (exceptionIs @ScopeClosing exception)) do
propagateException scope childId exception
-- even put async exceptions that we propagated. this isn't totally ideal because a caller awaiting this
-- thread would not be able to distinguish between async exceptions delivered to this thread, or itself
Expand All @@ -355,7 +353,10 @@ forkWith_ scope opts action = do
_childThreadId <-
spawn scope opts \childId masking ->
unexceptionalTryEither
(\exception -> when (not (isScopeClosingException exception)) (propagateException scope childId exception))
( \exception ->
when (not (exceptionIs @ScopeClosing exception)) do
propagateException scope childId exception
)
absurd
(masking action)
pure ()
Expand Down Expand Up @@ -383,13 +384,12 @@ forkTryWith scope opts action = do
result <- unexceptionalTry (masking action)
case result of
Failure exception -> do
-- then-branch explanation: if the user calls `forkTry @MyAsyncException` for some reason, we want to ignore
-- this request and propagate the async exception. `forkTry` can only be used to catch synchronous exceptions.
let shouldPropagate =
if isScopeClosingException exception
then False
else case fromException @e exception of
Nothing -> True
-- if the user calls `forkTry @MyAsyncException`, we still want to propagate the async exception
Just _ -> isAsyncException exception
if exceptionIs @e exception
then exceptionIs @SomeAsyncException exception
else not (exceptionIs @ScopeClosing exception)
when shouldPropagate (propagateException scope childId exception)
done (BadResult exception)
Success value -> done (GoodResult value)
Expand Down Expand Up @@ -437,13 +437,16 @@ propagateException :: Scope -> Tid -> SomeException -> UnexceptionalIO ()
propagateException Scope {childExceptionVar, parentThreadId, statusVar} childId exception =
UnexceptionalIO (readTVarIO statusVar) >>= \case
Closing -> tryPutChildExceptionVar -- (A) or (B), we don't care which
status -> assert (status >= 0) loop -- we know status is Open here
status -> do
assertM (status >= 0) -- we know status is Open (0+) here; can't be Closed (-2)
loop
where
loop :: UnexceptionalIO ()
loop =
unexceptionalTry (propagate exception childId parentThreadId) >>= \case
Failure IsScopeClosingException -> tryPutChildExceptionVar -- (C)
Failure _ -> loop -- (D)
Failure secondException
| exceptionIs @ScopeClosing secondException -> tryPutChildExceptionVar -- (C)
| otherwise -> loop -- (D)
Success _ -> pure ()

tryPutChildExceptionVar :: UnexceptionalIO ()
Expand Down
6 changes: 3 additions & 3 deletions ki/test/Tests.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ tests =
scope <- Ki.scoped pure
(atomically . Ki.await =<< Ki.fork scope (pure ())) `shouldThrow` ErrorCall "ki: scope closed"
pure (),
testCase "`fork` throws ScopeClosing when the scope is closing" do
testCase "`fork` throws ScopeClosing to children when the scope is closing" do
Ki.scoped \scope -> do
_ <-
Ki.forkWith scope Ki.defaultThreadOptions {Ki.maskingState = MaskedInterruptible} do
-- Naughty: catch and ignore the ScopeClosing delivered to us
result1 <- try @SomeException (threadDelay maxBound)
show result1 `shouldBe` "Left ScopeClosing"
show result1 `shouldBe` "Left <<internal ki exception: scope closing>>"
-- Try forking a new thread in the closing scope, and assert that (synchronously) throws ScopeClosing
result2 <- try @SomeException (Ki.fork_ scope undefined)
show result2 `shouldBe` "Left ScopeClosing"
show result2 `shouldBe` "Left <<internal ki exception: scope closing>>"
pure (),
testCase "`awaitAll` succeeds when no threads are alive" do
Ki.scoped (atomically . Ki.awaitAll),
Expand Down

0 comments on commit cebd275

Please sign in to comment.