Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove selector/space stuff #2458

Open
wants to merge 21 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# Release 0.37.0

## Breaking changes

0.37 removes the old Gibbs constructors deprecated in 0.36.

# Release 0.36.0

## Breaking changes
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.36.0"
version = "0.37.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
28 changes: 2 additions & 26 deletions src/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,18 +91,6 @@ abstract type Hamiltonian <: InferenceAlgorithm end
abstract type StaticHamiltonian <: Hamiltonian end
abstract type AdaptiveHamiltonian <: Hamiltonian end

# TODO(mhauru) Remove the below function once all the space/Selector stuff has been removed.
"""
drop_space(alg::InferenceAlgorithm)

Return an `InferenceAlgorithm` like `alg`, but with all space information removed.
"""
function drop_space end

function drop_space(sampler::Sampler)
return Sampler(drop_space(sampler.alg), sampler.selector)
end

include("repeat_sampler.jl")

"""
Expand Down Expand Up @@ -146,11 +134,6 @@ struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrain
end
end

# External samplers don't have notion of space to begin with.
drop_space(x::ExternalSampler) = x

DynamicPPL.getspace(::ExternalSampler) = ()

"""
requires_unconstrained_space(sampler::ExternalSampler)

Expand Down Expand Up @@ -217,8 +200,6 @@ Algorithm for sampling from the prior.
"""
struct Prior <: InferenceAlgorithm end

drop_space(x::Prior) = x

function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
Expand Down Expand Up @@ -592,13 +573,6 @@ include("emcee.jl")
# Typing tools #
################

for alg in (:SMC, :PG, :MH, :IS, :ESS, :Emcee)
@eval DynamicPPL.getspace(::$alg{space}) where {space} = space
end
for alg in (:HMC, :HMCDA, :NUTS, :SGLD, :SGHMC)
@eval DynamicPPL.getspace(::$alg{<:Any,space}) where {space} = space
end

function DynamicPPL.get_matching_type(
spl::Sampler{<:Union{PG,SMC}}, vi, ::Type{TV}
) where {T,N,TV<:Array{T,N}}
Expand All @@ -609,6 +583,8 @@ end
# Utilities #
##############

# TODO(mhauru) Remove this once DynamicPPL has removed all its Selector stuff.
DynamicPPL.getspace(::InferenceAlgorithm) = ()
DynamicPPL.getspace(spl::Sampler) = getspace(spl.alg)
DynamicPPL.inspace(vn::VarName, spl::Sampler) = inspace(vn, getspace(spl.alg))

Expand Down
6 changes: 2 additions & 4 deletions src/mcmc/emcee.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Foreman-Mackey, D., Hogg, D. W., Lang, D., & Goodman, J. (2013).
emcee: The MCMC Hammer. Publications of the Astronomical Society of the
Pacific, 125 (925), 306. https://doi.org/10.1086/670067
"""
struct Emcee{space,E<:AMH.Ensemble} <: InferenceAlgorithm
struct Emcee{E<:AMH.Ensemble} <: InferenceAlgorithm
ensemble::E
end

Expand All @@ -23,11 +23,9 @@ function Emcee(n_walkers::Int, stretch_length=2.0)
# ensemble sampling.
prop = AMH.StretchProposal(nothing, stretch_length)
ensemble = AMH.Ensemble(n_walkers, prop)
return Emcee{(),typeof(ensemble)}(ensemble)
return Emcee{typeof(ensemble)}(ensemble)
end

drop_space(alg::Emcee{space,E}) where {space,E} = Emcee{(),E}(alg.ensemble)

struct EmceeState{V<:AbstractVarInfo,S}
vi::V
states::S
Expand Down
40 changes: 11 additions & 29 deletions src/mcmc/ess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,7 @@ Mean
│ 1 │ m │ 0.824853 │
```
"""
struct ESS{space} <: InferenceAlgorithm end

ESS() = ESS{()}()
ESS(space::Symbol) = ESS{(space,)}()

drop_space(alg::ESS) = ESS()
struct ESS <: InferenceAlgorithm end

# always accept in the first step
function DynamicPPL.initialstep(
Expand All @@ -35,7 +30,7 @@ function DynamicPPL.initialstep(
vns = _getvns(vi, spl)
length(vns) == 1 ||
error("[ESS] does only support one variable ($(length(vns)) variables specified)")
for vn in vns[1]
for vn in only(vns)
dist = getdist(vi, vn)
EllipticalSliceSampling.isgaussian(typeof(dist)) ||
error("[ESS] only supports Gaussian prior distributions")
Expand All @@ -48,7 +43,7 @@ function AbstractMCMC.step(
rng::AbstractRNG, model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs...
)
# obtain previous sample
f = vi[spl]
f = vi[:]

# define previous sampler state
# (do not use cache to avoid in-place sampling from prior)
Expand Down Expand Up @@ -129,13 +124,11 @@ function (ℓ::ESSLogLikelihood)(f::AbstractVector)
end

function DynamicPPL.tilde_assume(
rng::Random.AbstractRNG, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, vn, vi
rng::Random.AbstractRNG, ::DefaultContext, ::Sampler{<:ESS}, right, vn, vi
)
return if inspace(vn, sampler)
DynamicPPL.tilde_assume(rng, LikelihoodContext(), SampleFromPrior(), right, vn, vi)
else
DynamicPPL.tilde_assume(rng, ctx, SampleFromPrior(), right, vn, vi)
end
return DynamicPPL.tilde_assume(
rng, LikelihoodContext(), SampleFromPrior(), right, vn, vi
)
end

function DynamicPPL.tilde_observe(
Expand All @@ -145,22 +138,11 @@ function DynamicPPL.tilde_observe(
end

function DynamicPPL.dot_tilde_assume(
rng::Random.AbstractRNG,
ctx::DefaultContext,
sampler::Sampler{<:ESS},
right,
left,
vns,
vi,
rng::Random.AbstractRNG, ::DefaultContext, ::Sampler{<:ESS}, right, left, vns, vi
)
# TODO: Or should we do `all(Base.Fix2(inspace, sampler), vns)`?
return if inspace(first(vns), sampler)
DynamicPPL.dot_tilde_assume(
rng, LikelihoodContext(), SampleFromPrior(), right, left, vns, vi
)
else
DynamicPPL.dot_tilde_assume(rng, ctx, SampleFromPrior(), right, left, vns, vi)
end
return DynamicPPL.dot_tilde_assume(
rng, LikelihoodContext(), SampleFromPrior(), right, left, vns, vi
)
end

function DynamicPPL.dot_tilde_observe(
Expand Down
45 changes: 1 addition & 44 deletions src/mcmc/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ struct Gibbs{N,V<:NTuple{N,AbstractVector{<:VarName}},A<:NTuple{N,Any}} <:

# Ensure that samplers have the same selector, and that varnames are lists of
# VarNames.
samplers = tuple(map(set_selector ∘ drop_space, samplers)...)
samplers = tuple(map(set_selector, samplers)...)
varnames = tuple(map(to_varname_list, varnames)...)
return new{length(samplers),typeof(varnames),typeof(samplers)}(varnames, samplers)
end
Expand All @@ -355,49 +355,6 @@ function Gibbs(algs::Pair...)
return Gibbs(map(first, algs), map(last, algs))
end

# The below two constructors only provide backwards compatibility with the constructor of
# the old Gibbs sampler. They are deprecated and will be removed in the future.
function Gibbs(alg1::InferenceAlgorithm, other_algs::InferenceAlgorithm...)
algs = [alg1, other_algs...]
varnames = map(algs) do alg
space = getspace(alg)
if (space isa VarName)
space
elseif (space isa Symbol)
VarName{space}()
else
tuple((s isa Symbol ? VarName{s}() : s for s in space)...)
end
end
msg = (
"Specifying which sampler to use with which variable using syntax like " *
"`Gibbs(NUTS(:x), MH(:y))` is deprecated and will be removed in the future. " *
"Please use `Gibbs(; x=NUTS(), y=MH())` instead. If you want different iteration " *
"counts for different subsamplers, use e.g. " *
"`Gibbs(@varname(x) => RepeatSampler(NUTS(), 2), @varname(y) => MH())`"
)
Base.depwarn(msg, :Gibbs)
return Gibbs(varnames, map(set_selector ∘ drop_space, algs))
end

function Gibbs(
alg_with_iters1::Tuple{<:InferenceAlgorithm,Int},
other_algs_with_iters::Tuple{<:InferenceAlgorithm,Int}...,
)
algs_with_iters = [alg_with_iters1, other_algs_with_iters...]
algs = Iterators.map(first, algs_with_iters)
iters = Iterators.map(last, algs_with_iters)
algs_duplicated = Iterators.flatten((
Iterators.repeated(alg, iter) for (alg, iter) in zip(algs, iters)
))
# This calls the other deprecated constructor from above, hence no need for a depwarn
# here.
return Gibbs(algs_duplicated...)
end

# TODO: Remove when no longer needed.
DynamicPPL.getspace(::Gibbs) = ()

struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S}
vi::V
states::S
Expand Down
Loading
Loading