Skip to content

Commit

Permalink
Merge pull request #503 from CliMA/aj/gpu_workaround_testing
Browse files Browse the repository at this point in the history
MNWE for cloud sedimentation GPU errors
  • Loading branch information
trontrytel authored Jan 14, 2025
2 parents 384c393 + 5727654 commit a690880
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@ Thermodynamics = "b60c26fb-14c3-4610-9d3e-2d17fe7ff00c"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"

[compat]
Dierckx = "0.5.3"
Documenter = "1.1"
DocumenterCitations = "1.3.3"
13 changes: 13 additions & 0 deletions src/parameters/Parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,17 @@ include("MicrophysicsP3.jl")
# Terminal velocity parameters (can be used with different microph. schemes)
include("TerminalVelocity.jl")

for T in (
Chen2022VelTypeRain,
Chen2022VelTypeSmallIce,
Chen2022VelTypeLargeIce,
Chen2022VelType,
CloudLiquid,
)
@eval Base.Broadcast.broadcastable(x::$T) = x
@eval Base.ndims(::Type{<:$T}) = 0
@eval Base.size(::$T) = ()
@eval Base.@propagate_inbounds Base.getindex(x::$T, i) = x
end

end # module
6 changes: 3 additions & 3 deletions test/gpu_clima_core_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import CloudMicrophysics.MicrophysicsNonEq as CMN
function make_column(::Type{FT}) where {FT}

context = ClimaComms.SingletonCommsContext(ClimaComms.CUDADevice())
#context = ClimaComms.context()

vert_domain = CC.Domains.IntervalDomain(
CC.Geometry.ZPoint{FT}(FT(0)),
Expand All @@ -32,6 +33,7 @@ end
function make_extruded_sphere(::Type{FT}) where {FT}

context = ClimaComms.SingletonCommsContext(ClimaComms.CUDADevice())
#context = ClimaComms.context()

# Define vertical
# domain
Expand Down Expand Up @@ -145,6 +147,7 @@ function main_3d(::Type{FT}) where {FT}
space_3d_w = make_extruded_sphere(FT)

ρq = CC.Fields.ones(space_3d_ρq) .* FT(1e-3)

ρ = CC.Fields.ones(space_3d_ρ)
w = CC.Fields.zeros(space_3d_w)

Expand All @@ -161,15 +164,12 @@ using Test
@testset "GPU inference failure 1D Float64" begin
main_1d(Float64)
end

@testset "GPU inference failure 3D Float64" begin
main_3d(Float64)
end

@testset "GPU inference failure 1D Float32" begin
main_1d(Float32)
end

@testset "GPU inference failure 3D Float32" begin
main_3d(Float32)
end

0 comments on commit a690880

Please sign in to comment.