Skip to content

Commit

Permalink
Merge pull request #37 from arhik/main
Browse files Browse the repository at this point in the history
[docs] update ops/clamp.jl
  • Loading branch information
arhik authored Apr 13, 2024
2 parents 290aed1 + c627cfb commit 5621e39
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion src/ops/clamp.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,24 @@
export clamp_kernel, clamp

"""
    clamp_kernel(x::WgpuArray{T, N}, out::WgpuArray{T, N}, minval::T, maxval::T) where {T, N}

Clamp compute kernel. Each invocation reads one element of `x`, clamps it to the range
`[minval, maxval]`, and stores the result at the same index of `out`.

End users are not supposed to call this function like a regular Julia function. Instead, it
needs to be passed to the `@wgpukernel` macro, which transforms it into `WGSL` shader code.

# Arguments
- `x`: input `WgpuArray`.
- `out`: output `WgpuArray` with the same element type and dimensionality as `x`; it may be
  uninitialized, since the kernel only writes to it.
- `minval`: lower clamp bound, of the element type `T`.
- `maxval`: upper clamp bound, of the element type `T`.
"""
function clamp_kernel(x::WgpuArray{T, N}, out::WgpuArray{T, N}, minval::T, maxval::T) where {T, N}
    # Linear index for this invocation: `globalId.y` scaled by `xDims.x`, plus `globalId.x`.
    # NOTE(review): `xDims` and `globalId` are not defined in this function; presumably they
    # are supplied by the `@wgpukernel` transformation -- confirm.
    gId = xDims.x*globalId.y + globalId.x
    value = x[gId]
    out[gId] = clamp(value, minval, maxval)
end


"""
clamp(x::WgpuArray{T, N}, minValue::T, maxValue::T) where {T, N}
Clamp operator for `WgpuArray`s: takes an input array `x` together with lower bound and upper bound
values, and clamps the elements of `x` to these bounds.
"""
function clamp(x::WgpuArray{T, N}, minValue::T, maxValue::T) where {T, N}
y = similar(x)
@wgpukernel launch=true workgroupSizes=(4, 4) workgroupCount=(2, 2) shmem=() clamp_kernel(x, y, minValue, maxValue)
Expand Down

0 comments on commit 5621e39

Please sign in to comment.