-
Notifications
You must be signed in to change notification settings - Fork 630
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Michał Zientkiewicz <[email protected]>
- Loading branch information
Showing
4 changed files
with
396 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
// Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#ifndef DALI_C_API_2_MANAGED_HANDLE_H_ | ||
#define DALI_C_API_2_MANAGED_HANDLE_H_ | ||
|
||
#define DALI_ALLOW_NEW_C_API | ||
#include "dali/dali.h" | ||
#include "dali/core/unique_handle.h" | ||
|
||
namespace dali::c_api { | ||
|
||
template <typename HandleType, typename Actual> | ||
class RefCountedHandle { | ||
public: | ||
using handle_type = HandleType; | ||
static constexpr handle_type null_handle() { return 0; } | ||
|
||
constexpr RefCountedHandle() : handle_(Actual::null_handle()) {} | ||
constexpr explicit RefCountedHandle(handle_type h) : handle_(h) {} | ||
~RefCountedHandle() { reset(); } | ||
|
||
RefCountedHandle(const RefCountedHandle &h) { | ||
handle_ = h.handle_; | ||
if (*this) | ||
Actual::IncRef(handle_); | ||
} | ||
|
||
RefCountedHandle(RefCountedHandle &&h) noexcept { | ||
handle_ = h.handle_; | ||
h.handle_ = Actual::null_handle(); | ||
} | ||
|
||
RefCountedHandle &operator=(const RefCountedHandle &other) { | ||
if (other.handle_) { | ||
Actual::IncRef(other.handle_); | ||
} | ||
reset(); | ||
handle_ = other.handle_; | ||
return *this; | ||
} | ||
|
||
RefCountedHandle &operator=(RefCountedHandle &&other) noexcept { | ||
std::swap(handle_, other.handle_); | ||
other.reset(); | ||
return *this; | ||
} | ||
|
||
void reset() noexcept { | ||
if (*this) | ||
Actual::DecRef(handle_); | ||
handle_ = Actual::null_handle(); | ||
} | ||
|
||
[[nodiscard]] handle_type release() noexcept { | ||
auto h = handle_; | ||
handle_ = Actual::null_handle(); | ||
return h; | ||
} | ||
|
||
handle_type get() const noexcept { return handle_; } | ||
operator handle_type() const noexcept { return get(); } | ||
|
||
explicit operator bool() const noexcept { return handle_ != Actual::null_handle(); } | ||
|
||
private: | ||
handle_type handle_; | ||
}; | ||
|
||
#define DALI_C_UNIQUE_HANDLE(Resource) \ | ||
class Resource##Handle : public dali::UniqueHandle<dali##Resource##_h, Resource##Handle> { \ | ||
public: \ | ||
using UniqueHandle<dali##Resource##_h, Resource##Handle>::UniqueHandle; \ | ||
static void DestroyHandle(dali##Resource##_h h) { \ | ||
auto result = dali##Resource##Destroy(h); \ | ||
if (result != DALI_SUCCESS) { \ | ||
throw std::runtime_error(daliGetLastErrorMessage()); \ | ||
} \ | ||
} \ | ||
} | ||
|
||
#define DALI_C_REF_HANDLE(Resource) \ | ||
class Resource##Handle \ | ||
: public dali::c_api::RefCountedHandle<dali##Resource##_h, Resource##Handle> { \ | ||
public: \ | ||
using RefCountedHandle<dali##Resource##_h, Resource##Handle>::RefCountedHandle; \ | ||
static int IncRef(dali##Resource##_h h) { \ | ||
int ref = 0; \ | ||
auto result = dali##Resource##IncRef(h, &ref); \ | ||
if (result != DALI_SUCCESS) { \ | ||
throw std::runtime_error(daliGetLastErrorMessage()); \ | ||
} \ | ||
return ref; \ | ||
} \ | ||
static int DecRef(dali##Resource##_h h) { \ | ||
int ref = 0; \ | ||
auto result = dali##Resource##DecRef(h, &ref); \ | ||
if (result != DALI_SUCCESS) { \ | ||
throw std::runtime_error(daliGetLastErrorMessage()); \ | ||
} \ | ||
return ref; \ | ||
} \ | ||
} | ||
|
||
DALI_C_UNIQUE_HANDLE(Pipeline); | ||
DALI_C_UNIQUE_HANDLE(PipelineOutputs); | ||
DALI_C_REF_HANDLE(TensorList); | ||
DALI_C_REF_HANDLE(Tensor); | ||
|
||
|
||
} // namespace dali::c_api | ||
|
||
#endif // DALI_C_API_2_MANAGED_HANDLE_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
// Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#ifndef DALI_C_API_2_REF_COUNTING_H_ | ||
#define DALI_C_API_2_REF_COUNTING_H_ | ||
|
||
#include <atomic> | ||
#include <type_traits> | ||
#include <utility> | ||
|
||
namespace dali::c_api { | ||
|
||
class RefCountedObject { | ||
public: | ||
int IncRef() noexcept { | ||
return std::atomic_fetch_add_explicit(&ref_, 1, std::memory_order_relaxed) + 1; | ||
} | ||
|
||
int DecRef() noexcept { | ||
int ret = std::atomic_fetch_sub_explicit(&ref_, 1, std::memory_order_acq_rel) - 1; | ||
if (!ret) | ||
delete this; | ||
return ret; | ||
} | ||
|
||
int RefCount() const noexcept { | ||
return ref_.load(std::memory_order_relaxed); | ||
} | ||
|
||
virtual ~RefCountedObject() = default; | ||
private: | ||
std::atomic<int> ref_{1}; | ||
}; | ||
|
||
template <typename T> | ||
class RefCountedPtr { | ||
public: | ||
constexpr RefCountedPtr() noexcept = default; | ||
|
||
explicit RefCountedPtr(T *ptr, bool inc_ref = false) noexcept : ptr_(ptr) { | ||
if (inc_ref && ptr_) | ||
ptr_->IncRef(); | ||
} | ||
|
||
~RefCountedPtr() { | ||
reset(); | ||
} | ||
|
||
template <typename U, std::enable_if_t<std::is_convertible_v<U *, T *>, int> = 0> | ||
RefCountedPtr(const RefCountedPtr<U> &other) noexcept : ptr_(other.ptr_) { | ||
if (ptr_) | ||
ptr_->IncRef(); | ||
} | ||
|
||
template <typename U, std::enable_if_t<std::is_convertible_v<U *, T *>, int> = 0> | ||
RefCountedPtr(RefCountedPtr<U> &&other) noexcept : ptr_(other.ptr_) { | ||
other.ptr_ = nullptr; | ||
} | ||
|
||
template <typename U> | ||
std::enable_if_t<std::is_convertible_v<U *, T *>, RefCountedPtr> & | ||
operator=(const RefCountedPtr<U> &other) noexcept { | ||
if (ptr_ == other.ptr_) | ||
return *this; | ||
if (other.ptr_) | ||
other.ptr_->IncRef(); | ||
ptr_->DecRef(); | ||
ptr_ = other.ptr_; | ||
return *this; | ||
} | ||
|
||
template <typename U> | ||
std::enable_if_t<std::is_convertible_v<U *, T *>, RefCountedPtr> & | ||
operator=(RefCountedPtr &&other) noexcept { | ||
if (&other == this) | ||
return *this; | ||
std::swap(ptr_, other.ptr_); | ||
other.reset(); | ||
} | ||
|
||
void reset() noexcept { | ||
if (ptr_) | ||
ptr_->DecRef(); | ||
ptr_ = nullptr; | ||
} | ||
|
||
[[nodiscard]] T *release() noexcept { | ||
T *p = ptr_; | ||
ptr_ = nullptr; | ||
return p; | ||
} | ||
|
||
constexpr T *operator->() const & noexcept { return ptr_; } | ||
|
||
constexpr T &operator*() const & noexcept { return *ptr_; } | ||
|
||
constexpr T *get() const & noexcept { return ptr_; } | ||
|
||
private: | ||
template <typename U> | ||
friend class RefCountedPtr; | ||
T *ptr_ = nullptr; | ||
}; | ||
|
||
} // namespace dali::c_api | ||
|
||
#endif // DALI_C_API_2_REF_COUNTING_H_ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
// Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include <cuda_runtime_api.h> | ||
#include <stdexcept> | ||
#include "dali/c_api_2/validation.h" | ||
|
||
namespace dali::c_api { | ||
|
||
void ValidateDeviceId(int device_id, bool allow_cpu_only) { | ||
if (device_id == CPU_ONLY_DEVICE_ID && allow_cpu_only) | ||
return; | ||
|
||
static int dev_count = []() { | ||
int ndevs = 0; | ||
CUDA_CALL(cudaGetDeviceCount(&ndevs)); | ||
return ndevs; | ||
}(); | ||
|
||
if (dev_count < 1) | ||
throw std::runtime_error("No CUDA device found."); | ||
|
||
if (device_id < 0 || device_id >= dev_count) { | ||
throw std::out_of_range(make_string( | ||
"The device id ", device_id, " is invalid." | ||
" Valid device ids are [0..", dev_count-1, "].")); | ||
} | ||
} | ||
|
||
} // namespace dali::c_api |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
// Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#ifndef DALI_C_API_2_VALIDATION_H_ | ||
#define DALI_C_API_2_VALIDATION_H_ | ||
|
||
#include <stdexcept> | ||
#include <optional> | ||
#define DALI_ALLOW_NEW_C_API | ||
#include "dali/dali.h" | ||
#include "dali/core/format.h" | ||
#include "dali/core/span.h" | ||
#include "dali/core/tensor_shape_print.h" | ||
#include "dali/pipeline/data/types.h" | ||
|
||
namespace dali::c_api { | ||
|
||
inline void Validate(daliDataType_t dtype) { | ||
if (!TypeTable::TryGetTypeInfo(dtype)) | ||
throw std::invalid_argument(make_string("Invalid data type: ", dtype)); | ||
} | ||
|
||
inline void Validate(const TensorLayout &layout, int ndim, bool allow_empty = true) { | ||
if (layout.empty() && allow_empty) | ||
return; | ||
if (layout.ndim() != ndim) | ||
throw std::invalid_argument(make_string( | ||
"The layout '", layout, "' cannot describe ", ndim, "-dimensional data.")); | ||
} | ||
|
||
template <typename ShapeLike> | ||
void ValidateSampleShape( | ||
int sample_index, | ||
ShapeLike &&sample_shape, | ||
std::optional<int> expected_ndim = std::nullopt) { | ||
int ndim = std::size(sample_shape); | ||
if (expected_ndim.has_value() && ndim != *expected_ndim) | ||
throw std::invalid_argument(make_string( | ||
"Unexpected number of dimensions (", ndim, ") in sample ", sample_index, | ||
". Expected ", *expected_ndim, ".")); | ||
|
||
for (int j = 0; j < ndim; j++) | ||
if (sample_shape[j] < 0) | ||
throw std::invalid_argument(make_string( | ||
"Negative extent encountered in the shape of sample ", sample_index, ". Offending shape: ", | ||
TensorShape<-1>(sample_shape))); | ||
} | ||
|
||
inline void ValidateNumSamples(int num_samples) { | ||
if (num_samples < 0) | ||
throw std::invalid_argument("The number of samples must not be negative."); | ||
} | ||
|
||
inline void ValidateNDim(int ndim) { | ||
if (ndim < 0) | ||
throw std::invalid_argument("The number of dimensions must not be negative."); | ||
} | ||
|
||
|
||
inline void ValidateShape( | ||
int ndim, | ||
const int64_t *shape) { | ||
ValidateNDim(ndim); | ||
if (ndim > 0 && !shape) | ||
throw std::invalid_argument("The `shape` must not be NULL when ndim > 0."); | ||
|
||
for (int j = 0; j < ndim; j++) | ||
if (shape[j] < 0) | ||
throw std::invalid_argument(make_string( | ||
"The tensor shape must not contain negative extents. Got: ", | ||
TensorShape<-1>(make_cspan(shape, ndim)))); | ||
} | ||
|
||
inline void ValidateShape(int num_samples, int ndim, const int64_t *shapes) { | ||
ValidateNumSamples(num_samples); | ||
ValidateNDim(ndim); | ||
if (!shapes && num_samples > 0 && ndim > 0) | ||
throw std::invalid_argument("The `shapes` are required for non-scalar (ndim>=0) samples."); | ||
|
||
if (ndim > 0) { | ||
for (int i = 0; i < num_samples; i++) | ||
ValidateSampleShape(i, make_cspan(&shapes[i*ndim], ndim)); | ||
} | ||
} | ||
|
||
inline void Validate(daliStorageDevice_t device_type) { | ||
if (device_type != DALI_STORAGE_CPU && device_type != DALI_STORAGE_GPU) | ||
throw std::invalid_argument(make_string("Invalid storage device type: ", device_type)); | ||
} | ||
|
||
void ValidateDeviceId(int device_id, bool allow_cpu_only); | ||
|
||
inline void Validate(const daliBufferPlacement_t &placement) { | ||
Validate(placement.device_type); | ||
if (placement.device_type == DALI_STORAGE_GPU || placement.pinned) | ||
ValidateDeviceId(placement.device_id, placement.pinned); | ||
} | ||
|
||
} // namespace dali::c_api | ||
|
||
#endif // DALI_C_API_2_VALIDATION_H_ |