Skip to content

Commit 700a19a

Browse files
penelopeysmyebai
andauthored
Implement getstepsize() for NoAdaptation samplers (#2405)
Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
1 parent 365fe16 commit 700a19a

File tree

3 files changed

+20
-1
lines changed

3 files changed

+20
-1
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.35.3"
3+
version = "0.35.4"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/mcmc/hmc.jl

+6
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,12 @@ end
463463

464464
getstepsize(sampler::Sampler{<:Hamiltonian}, state) = sampler.alg.ϵ
465465
getstepsize(sampler::Sampler{<:AdaptiveHamiltonian}, state) = AHMC.getϵ(state.adaptor)
466+
function getstepsize(
467+
sampler::Sampler{<:AdaptiveHamiltonian},
468+
state::HMCState{TV,TKernel,THam,PhType,AHMC.Adaptation.NoAdaptation},
469+
) where {TV,TKernel,THam,PhType}
470+
return state.kernel.τ.integrator.ϵ
471+
end
466472

467473
gen_metric(dim::Int, spl::Sampler{<:Hamiltonian}, state) = AHMC.UnitEuclideanMetric(dim)
468474
function gen_metric(dim::Int, spl::Sampler{<:AdaptiveHamiltonian}, state)

test/mcmc/hmc.jl

+13
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,19 @@ using Turing
329329
@test pvalue(ApproximateTwoSampleKSTest(vec(results), vec(results_prior))) > 0.001
330330
end
331331

332+
@testset "getstepsize: Turing.jl#2400" begin
333+
algs = [HMC(0.1, 10), HMCDA(0.8, 0.75), NUTS(0.5), NUTS(0, 0.5)]
334+
@testset "$(alg)" for alg in algs
335+
# Construct a HMC state by taking a single step
336+
spl = Sampler(alg, gdemo_default)
337+
hmc_state = DynamicPPL.initialstep(
338+
Random.default_rng(), gdemo_default, spl, DynamicPPL.VarInfo(gdemo_default)
339+
)[2]
340+
# Check that we can obtain the current step size
341+
@test Turing.Inference.getstepsize(spl, hmc_state) isa Float64
342+
end
343+
end
344+
332345
@testset "Check ADType" begin
333346
alg = HMC(0.1, 10; adtype=adbackend)
334347
m = DynamicPPL.contextualize(

0 commit comments

Comments
 (0)