diff --git a/src/Control/Monad/Bayes/Class.hs b/src/Control/Monad/Bayes/Class.hs index b11f3f86..eb1945d9 100644 --- a/src/Control/Monad/Bayes/Class.hs +++ b/src/Control/Monad/Bayes/Class.hs @@ -79,9 +79,9 @@ import Control.Monad (replicateM, when) import Control.Monad.Cont (ContT) import Control.Monad.Except (ExceptT, lift) import Control.Monad.Identity (IdentityT) -import Control.Monad.List (ListT) import Control.Monad.Reader (ReaderT) import Control.Monad.State (StateT) +import Control.Monad.Trans.Free.Ap (FreeT) import Control.Monad.Writer (WriterT) import Data.Histogram qualified as H import Data.Histogram.Fill qualified as H @@ -390,15 +390,15 @@ instance MonadFactor m => MonadFactor (StateT s m) where instance MonadMeasure m => MonadMeasure (StateT s m) -instance MonadDistribution m => MonadDistribution (ListT m) where +instance (Applicative f, MonadDistribution m) => MonadDistribution (FreeT f m) where random = lift random bernoulli = lift . bernoulli categorical = lift . categorical -instance MonadFactor m => MonadFactor (ListT m) where +instance (Applicative f, MonadFactor m) => MonadFactor (FreeT f m) where score = lift . score -instance MonadMeasure m => MonadMeasure (ListT m) +instance (Applicative f, MonadMeasure m) => MonadMeasure (FreeT f m) instance MonadDistribution m => MonadDistribution (ContT r m) where random = lift random diff --git a/src/Control/Monad/Bayes/Population.hs b/src/Control/Monad/Bayes/Population.hs index be670df2..4e305223 100644 --- a/src/Control/Monad/Bayes/Population.hs +++ b/src/Control/Monad/Bayes/Population.hs @@ -53,7 +53,9 @@ import Control.Monad.Bayes.Weighted weighted, withWeight, ) -import Control.Monad.List (ListT (..), MonadIO, MonadTrans (..)) +import Control.Monad.IO.Class +import Control.Monad.Trans +import Control.Monad.Trans.Free.Ap import Data.List (unfoldr) import Data.List qualified import Data.Maybe (catMaybes) @@ -64,7 +66,7 @@ import Numeric.Log qualified as Log import Prelude hiding (all, sum) -- | A collection of weighted samples, or particles. -newtype Population m a = Population (Weighted (ListT m) a) +newtype Population m a = Population (Weighted (FreeT [] m) a) deriving newtype (Functor, Applicative, Monad, MonadIO, MonadDistribution, MonadFactor, MonadMeasure) instance MonadTrans Population where @@ -72,19 +74,19 @@ instance MonadTrans Population where -- | Explicit representation of the weighted sample with weights in the log -- domain. -population, runPopulation :: Population m a -> m [(a, Log Double)] -population (Population m) = runListT $ weighted m +population, runPopulation :: Monad m => Population m a -> m [(a, Log Double)] +population (Population m) = iterT ((fmap concat . sequence)) $ fmap pure $ weighted m -- | deprecated synonym runPopulation = population -- | Explicit representation of the weighted sample. -explicitPopulation :: Functor m => Population m a -> m [(a, Double)] +explicitPopulation :: Monad m => Population m a -> m [(a, Double)] explicitPopulation = fmap (map (second (exp . ln))) . population -- | Initialize 'Population' with a concrete weighted sample. fromWeightedList :: Monad m => m [(a, Log Double)] -> Population m a -fromWeightedList = Population . withWeight . ListT +fromWeightedList = Population . withWeight . FreeT . fmap (Free . fmap pure) -- | Increase the sample size by a given factor. -- The weights are adjusted such that their sum is preserved. @@ -269,7 +271,7 @@ popAvg f p = do -- | Applies a transformation to the inner monad. hoist :: - Monad n => + (Monad m, Monad n) => (forall x. m x -> n x) -> Population m a -> Population n a