Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework data pattern matching to use default cases #5557

Merged
merged 3 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 117 additions & 35 deletions unison-runtime/src/Unison/Runtime/Machine.hs
Original file line number Diff line number Diff line change
Expand Up @@ -683,14 +683,13 @@ eval env !denv !activeThreads !stk !k r (Match i br) = do
n <- peekOffN stk i
eval env denv activeThreads stk k r $ selectBranch n br
eval env !denv !activeThreads !stk !k r (DMatch mr i br) = do
(t, stk) <- dumpDataNoTag mr stk =<< peekOff stk i
eval env denv activeThreads stk k r $
selectBranch (maskTags t) br
(nx, stk) <- dataBranch mr stk br =<< bpeekOff stk i
eval env denv activeThreads stk k r nx
eval env !denv !activeThreads !stk !k r (NMatch _mr i br) = do
n <- peekOffN stk i
eval env denv activeThreads stk k r $ selectBranch n br
eval env !denv !activeThreads !stk !k r (RMatch i pu br) = do
(t, stk) <- dumpDataNoTag Nothing stk =<< peekOff stk i
(t, stk) <- dumpDataValNoTag stk =<< peekOff stk i
if t == TT.pureEffectTag
then eval env denv activeThreads stk k r pu
else case ANF.unpackTags t of
Expand Down Expand Up @@ -1000,46 +999,41 @@ buildData !stk !r !t (VArgV i) = do
l = fsize stk - i
{-# INLINE buildData #-}

-- Dumps the fields of a boxed data value onto the stack and returns the
-- closure's packed tag alongside the resulting stack. The tag itself is
-- not written to the stack. An unboxed value is rejected with `die`.
dumpDataValNoTag ::
  Stack ->
  Val ->
  IO (PackedTag, Stack)
dumpDataValNoTag stk val = case val of
  BoxedVal clo -> do
    -- No expected type reference is available here, hence `Nothing`.
    stk' <- dumpDataNoTag Nothing stk clo
    pure (closureTag clo, stk')
  unboxed ->
    die ("dumpDataValNoTag: unboxed val: " ++ show unboxed)
{-# inline dumpDataValNoTag #-}

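-- Dumps the fields of a data type closure to the stack without writing its tag.
-- The tag is not returned here; callers that need it for case analysis read it
-- from the closure separately (see `closureTag`, as used by `dumpDataValNoTag`).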
dumpDataNoTag ::
Maybe Reference ->
Stack ->
Val ->
IO (PackedTag, Stack)
Closure ->
IO Stack
dumpDataNoTag !mr !stk = \case
-- Normally we want to avoid dumping unboxed values since it's unnecessary, but sometimes we don't know the type of
-- the incoming value and end up dumping unboxed values, so we just push them back to the stack as-is. e.g. in type-casts/coercions
val@(UnboxedVal _ t) -> do
Enum _ _ -> pure stk
Data1 _ _ x -> do
stk <- bump stk
poke stk val
pure (unboxedPackedTag t, stk)
BoxedVal clos -> case clos of
(Enum _ t) -> pure (t, stk)
(Data1 _ t x) -> do
stk <- bump stk
poke stk x
pure (t, stk)
(Data2 _ t x y) -> do
stk <- bumpn stk 2
pokeOff stk 1 y
poke stk x
pure (t, stk)
(DataG _ t seg) -> do
stk <- dumpSeg stk seg S
pure (t, stk)
clo ->
die $
"dumpDataNoTag: bad closure: "
++ show clo
++ maybe "" (\r -> "\nexpected type: " ++ show r) mr
where
unboxedPackedTag :: UnboxedTypeTag -> PackedTag
unboxedPackedTag = \case
CharTag -> TT.charTag
FloatTag -> TT.floatTag
IntTag -> TT.intTag
NatTag -> TT.natTag
poke stk x
pure stk
Data2 _ _ x y -> do
stk <- bumpn stk 2
pokeOff stk 1 y
stk <$ poke stk x
DataG _ _ seg -> dumpSeg stk seg S
clo ->
die $
"dumpDataNoTag: bad closure: "
++ show clo
++ maybe "" (\r -> "\nexpected type: " ++ show r) mr
{-# INLINE dumpDataNoTag #-}

-- Note: although the representation allows it, it is impossible
Expand Down Expand Up @@ -1995,6 +1989,94 @@ selectBranch t (TestW df cs) = lookupWithDefault df t cs
selectBranch _ (TestT {}) = error "impossible"
{-# INLINE selectBranch #-}

-- Combined branch selection and field dumping for data types.
--
-- Given the branch table and a data closure, this picks the section to
-- run next and returns it together with the stack to run it on. The
-- closure's fields are pushed only when one of the explicitly tested
-- constructors matches. The default section is returned with the stack
-- untouched, since a default may stand for many constructors of
-- differing arity and so uniformly expects _no_ extra values.
--
-- A closure that is not `Enum`/`Data1`/`Data2`/`DataG` is reported via
-- `dataBranchClosureError` (mentioning the expected type reference when
-- one was supplied); a branch that is not `Test1`/`Test2`/`TestW` is
-- reported via `dataBranchBranchError`.
--
-- Three arguments are kept on the left-hand side, with the closure taken
-- by an inner lambda, matching the shape of the original definition.
dataBranch ::
  Maybe Reference -> Stack -> MBranch -> Closure -> IO (MSection, Stack)
dataBranch mrf stk br = case br of
  -- Single tested constructor `u`.
  Test1 u cu df -> \case
    Enum _ t
      | maskTags t == u -> enter cu
      | otherwise -> enter df
    Data1 _ t x
      | maskTags t == u -> enter1 cu x
      | otherwise -> enter df
    Data2 _ t x y
      | maskTags t == u -> enter2 cu x y
      | otherwise -> enter df
    DataG _ t seg
      | maskTags t == u -> enterSeg cu seg
      | otherwise -> enter df
    clo -> dataBranchClosureError mrf clo
  -- Two tested constructors, `u` checked before `v`.
  Test2 u cu v cv df -> \case
    Enum _ t
      | maskTags t == u -> enter cu
      | maskTags t == v -> enter cv
      | otherwise -> enter df
    Data1 _ t x
      | maskTags t == u -> enter1 cu x
      | maskTags t == v -> enter1 cv x
      | otherwise -> enter df
    Data2 _ t x y
      | maskTags t == u -> enter2 cu x y
      | maskTags t == v -> enter2 cv x y
      | otherwise -> enter df
    DataG _ t seg
      | maskTags t == u -> enterSeg cu seg
      | maskTags t == v -> enterSeg cv seg
      | otherwise -> enter df
    clo -> dataBranchClosureError mrf clo
  -- Table of tested constructors, looked up by masked tag.
  TestW df bs -> \case
    Enum _ t
      | Just ca <- EC.lookup (maskTags t) bs -> enter ca
      | otherwise -> enter df
    Data1 _ t x
      | Just ca <- EC.lookup (maskTags t) bs -> enter1 ca x
      | otherwise -> enter df
    Data2 _ t x y
      | Just ca <- EC.lookup (maskTags t) bs -> enter2 ca x y
      | otherwise -> enter df
    DataG _ t seg
      | Just ca <- EC.lookup (maskTags t) bs -> enterSeg ca seg
      | otherwise -> enter df
    clo -> dataBranchClosureError mrf clo
  _ -> \_ -> dataBranchBranchError br
  where
    -- Continue with a section without pushing anything.
    enter sec = pure (sec, stk)

    -- Push the single field, then continue with the section.
    enter1 sec x = do
      stk' <- bump stk
      poke stk' x
      pure (sec, stk')

    -- Push two fields: `y` is written at offset 1, then `x` on top.
    enter2 sec x y = do
      stk' <- bumpn stk 2
      pokeOff stk' 1 y
      poke stk' x
      pure (sec, stk')

    -- Push a general field segment, then continue with the section.
    enterSeg sec seg = (sec,) <$> dumpSeg stk seg S
{-# inline dataBranch #-}

-- Aborts via `die` when `dataBranch` is handed a closure that is not a
-- data closure. The message shows the offending closure and, when a type
-- reference was supplied, the expected type on a second line.
dataBranchClosureError :: Maybe Reference -> Closure -> IO a
dataBranchClosureError mrf clo = die (headline ++ show clo ++ expected)
  where
    headline = "dataBranch: bad closure: "
    expected = case mrf of
      Nothing -> ""
      Just r -> "\nexpected type: " ++ show r

-- Aborts via `die` when `dataBranch` is handed a branch form it has no
-- equation for; the message shows the offending branch.
dataBranchBranchError :: MBranch -> IO a
dataBranchBranchError br = die message
  where
    message = "dataBranch: unexpected branch: " ++ show br

-- Splits off a portion of the continuation up to a given prompt.
--
-- The main procedure walks along the 'code' stack `k`, keeping track of how
Expand Down
69 changes: 64 additions & 5 deletions unison-runtime/src/Unison/Runtime/Pattern.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ module Unison.Runtime.Pattern
where

import Control.Monad.State (State, evalState, modify, runState, state)
import Data.Containers.ListUtils (nubOrd)
import Data.List (transpose)
import Data.Map.Strict
( fromListWith,
Expand Down Expand Up @@ -92,6 +93,11 @@ builtinDataSpec = Map.fromList decls
| (_, x, y) <- builtinEffectDecls
]

-- Looks up the first pattern in a row whose location (`loc`) is the given
-- variable. Yields `Nothing` when the row has no pattern for it.
findPattern :: Eq v => v -> PatternRow v -> Maybe (Pattern v)
findPattern v (PR ms _ _) =
  case dropWhile (not . (== v) . loc) ms of
    p : _ -> Just p
    [] -> Nothing

-- A pattern compilation matrix is just a list of rows. There is
-- no need for the rows to have uniform length; the variable
-- annotations on the patterns in the rows keep track of what
Expand Down Expand Up @@ -125,8 +131,11 @@ refutable (P.Unbound _) = False
refutable (P.Var _) = False
refutable _ = True

rowIrrefutable :: PatternRow v -> Bool
rowIrrefutable (PR ps _ _) = null ps
-- True when the row has no patterns left to match against.
noMatches :: PatternRow v -> Bool
noMatches row = case row of
  PR [] _ _ -> True
  PR (_ : _) _ _ -> False

-- A row can fail to apply if it still has patterns to match or carries a
-- guard; it is irrefutable only when it has neither.
rowRefutable :: PatternRow v -> Bool
rowRefutable (PR ps g _) = not (null ps && isNothing g)

firstRow :: ([P.Pattern v] -> Maybe v) -> Heuristic v
firstRow f (PM (r : _)) = f $ matches r
Expand Down Expand Up @@ -481,6 +490,19 @@ splitMatrix v rf cons (PM rs) =
where
mmap = fmap (\(t, fs) -> (t, splitRow v rf t fs =<< rs)) cons

-- Eliminates a variable from a matrix by discarding every row that has a
-- pattern for that variable. What remains are the rows that are _not_
-- specific matches on it, i.e. the rows that could apply in a default
-- case. Row order is preserved.
antiSplitMatrix ::
  (Var v) =>
  v ->
  PatternMatrix v ->
  PatternMatrix v
antiSplitMatrix v (PM rs) = PM (filter lacksPatternFor rs)
  where
    -- a row survives only if it has no pattern located at v
    lacksPatternFor = isNothing . findPattern v

-- Monad for pattern preparation. It is a state monad carrying a fresh
-- variable source, the list of variables bound the pattern being
-- prepared, and a variable renaming mapping.
Expand Down Expand Up @@ -596,7 +618,7 @@ compile _ _ (PM []) = apps' bu [text () "pattern match failure"]
where
bu = ref () (Builtin "bug")
compile spec ctx m@(PM (r : rs))
| rowIrrefutable r =
| noMatches r =
case guard r of
Nothing -> body r
Just g -> iff mempty g (body r) $ compile spec ctx (PM rs)
Expand All @@ -614,8 +636,11 @@ compile spec ctx m@(PM (r : rs))
case lookupData rf spec of
Right cons ->
match () (var () v) $
buildCase spec rf False cons ctx
<$> splitMatrix v (Just rf) (numberCons cons) m
(buildCase spec rf False cons ctx
<$> splitMatrix v (Just rf) ncons m)
++ buildDefaultCase spec False needDefault ctx dm
where
needDefault = length ncons < length cons
Left err -> internalBug err
| PReq rfs <- ty =
match () (var () v) $
Expand All @@ -631,7 +656,29 @@ compile spec ctx m@(PM (r : rs))
internalBug "unknown pattern compilation type"
where
v = choose heuristics m
ncons = relevantConstructors m v
ty = Map.findWithDefault Unknown v ctx
dm = antiSplitMatrix v m

-- Calculates the data constructors—with their arities—that should be
-- matched on when splitting a matrix on a given variable. This includes
-- every constructor that some row matches the variable against, paired
-- with the number of sub-patterns in that match. Boolean literals are
-- included as arity-0 constructors, tagged 1 for true and 0 for false.
--
-- Rows are scanned in order and the scan stops at the first row that is
-- not refutable. Rows with no pattern for the variable contribute
-- nothing. Duplicates are removed with `nubOrd`. Any other kind of
-- pattern on the variable is reported as an internal bug.
relevantConstructors :: Ord v => PatternMatrix v -> v -> [(Int, Int)]
relevantConstructors (PM allRows) v = nubOrd (reverse (collect [] allRows))
  where
    -- `found` accumulates (tag, arity) pairs most-recent-first.
    collect found [] = found
    collect found (row : rest)
      -- an irrefutable row ends the scan
      | not (rowRefutable row) = found
      | otherwise = case findPattern v row of
          -- no pattern on v in this row, so no relevant constructor
          Nothing -> collect found rest
          Just (P.Constructor _ (ConstructorReference _ t) sps) ->
            collect ((fromIntegral t, length sps) : found) rest
          Just (P.Boolean _ b) ->
            collect ((if b then 1 else 0, 0) : found) rest
          Just p ->
            internalBug $ "unexpected data pattern: " ++ show p

buildCaseBuiltin ::
(Var v) =>
Expand Down Expand Up @@ -677,6 +724,18 @@ buildCase spec r eff cons ctx0 (t, vts, m) =
vs = ((),) . fst <$> vts
ctx = Map.fromList vts <> ctx0

-- Builds the optional default (wildcard) case of a compiled match.
-- When a default is needed, the result is a single unguarded case with
-- an unbound pattern whose body is the compilation of the given matrix;
-- otherwise the result is empty. The effect flag is currently unused.
buildDefaultCase ::
  (Var v) =>
  DataSpec ->
  Bool ->
  Bool ->
  Ctx v ->
  PatternMatrix v ->
  [MatchCase () (Term v)]
buildDefaultCase spec _eff needed ctx pm =
  [MatchCase (Unbound ()) Nothing (compile spec ctx pm) | needed]

mkRow ::
(Var v) =>
v ->
Expand Down
12 changes: 11 additions & 1 deletion unison-runtime/src/Unison/Runtime/Stack.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ module Unison.Runtime.Stack
BlackHole,
UnboxedTypeTag
),
closureTag,
UnboxedTypeTag (..),
unboxedTypeTagToInt,
unboxedTypeTagFromInt,
Expand Down Expand Up @@ -153,7 +154,7 @@ module Unison.Runtime.Stack
)
where

import Control.Exception (throwIO)
import Control.Exception (throw, throwIO)
import Control.Monad.Primitive
import Data.Char qualified as Char
import Data.IORef (IORef)
Expand Down Expand Up @@ -371,6 +372,15 @@ splitData = \case
(DataG r t seg) -> Just (r, t, segToList seg)
_ -> Nothing

-- Extracts the packed tag stored in a data closure (`Enum`, `Data1`,
-- `Data2` or `DataG`). Any other closure throws a `Panic` carrying the
-- offending closure.
closureTag :: Closure -> PackedTag
closureTag clo = case clo of
  Enum _ t -> t
  Data1 _ t _ -> t
  Data2 _ t _ _ -> t
  DataG _ t _ -> t
  other ->
    throw (Panic "closureTag: unexpected closure" (Just (BoxedVal other)))
{-# inline closureTag #-}

-- | Converts a list of integers representing an unboxed segment back into the
-- appropriate segment. Segments are stored backwards in the runtime, so this
-- reverses the list.
Expand Down
Loading