Skip to content

Commit

Permalink
C API utilities.
Browse files Browse the repository at this point in the history
Signed-off-by: Michał Zientkiewicz <[email protected]>
  • Loading branch information
mzient committed Feb 1, 2025
1 parent d159abf commit 2ac3150
Show file tree
Hide file tree
Showing 4 changed files with 396 additions and 0 deletions.
124 changes: 124 additions & 0 deletions dali/c_api_2/managed_handle.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef DALI_C_API_2_MANAGED_HANDLE_H_
#define DALI_C_API_2_MANAGED_HANDLE_H_

#define DALI_ALLOW_NEW_C_API
#include "dali/dali.h"
#include "dali/core/unique_handle.h"

namespace dali::c_api {

template <typename HandleType, typename Actual>
class RefCountedHandle {
public:
using handle_type = HandleType;
static constexpr handle_type null_handle() { return 0; }

constexpr RefCountedHandle() : handle_(Actual::null_handle()) {}
constexpr explicit RefCountedHandle(handle_type h) : handle_(h) {}
~RefCountedHandle() { reset(); }

RefCountedHandle(const RefCountedHandle &h) {
handle_ = h.handle_;
if (*this)
Actual::IncRef(handle_);
}

RefCountedHandle(RefCountedHandle &&h) noexcept {
handle_ = h.handle_;
h.handle_ = Actual::null_handle();
}

RefCountedHandle &operator=(const RefCountedHandle &other) {
if (other.handle_) {
Actual::IncRef(other.handle_);
}
reset();
handle_ = other.handle_;
return *this;
}

RefCountedHandle &operator=(RefCountedHandle &&other) noexcept {
std::swap(handle_, other.handle_);
other.reset();
return *this;
}

void reset() noexcept {
if (*this)
Actual::DecRef(handle_);
handle_ = Actual::null_handle();
}

[[nodiscard]] handle_type release() noexcept {
auto h = handle_;
handle_ = Actual::null_handle();
return h;
}

handle_type get() const noexcept { return handle_; }
operator handle_type() const noexcept { return get(); }

explicit operator bool() const noexcept { return handle_ != Actual::null_handle(); }

private:
handle_type handle_;
};

#define DALI_C_UNIQUE_HANDLE(Resource) \
class Resource##Handle : public dali::UniqueHandle<dali##Resource##_h, Resource##Handle> { \
public: \
using UniqueHandle<dali##Resource##_h, Resource##Handle>::UniqueHandle; \
static void DestroyHandle(dali##Resource##_h h) { \
auto result = dali##Resource##Destroy(h); \
if (result != DALI_SUCCESS) { \
throw std::runtime_error(daliGetLastErrorMessage()); \
} \
} \
}

#define DALI_C_REF_HANDLE(Resource) \
class Resource##Handle \
: public dali::c_api::RefCountedHandle<dali##Resource##_h, Resource##Handle> { \
public: \
using RefCountedHandle<dali##Resource##_h, Resource##Handle>::RefCountedHandle; \
static int IncRef(dali##Resource##_h h) { \
int ref = 0; \
auto result = dali##Resource##IncRef(h, &ref); \
if (result != DALI_SUCCESS) { \
throw std::runtime_error(daliGetLastErrorMessage()); \
} \
return ref; \
} \
static int DecRef(dali##Resource##_h h) { \
int ref = 0; \
auto result = dali##Resource##DecRef(h, &ref); \
if (result != DALI_SUCCESS) { \
throw std::runtime_error(daliGetLastErrorMessage()); \
} \
return ref; \
} \
}

DALI_C_UNIQUE_HANDLE(Pipeline);
DALI_C_UNIQUE_HANDLE(PipelineOutputs);
DALI_C_REF_HANDLE(TensorList);
DALI_C_REF_HANDLE(Tensor);


} // namespace dali::c_api

#endif // DALI_C_API_2_MANAGED_HANDLE_H_
119 changes: 119 additions & 0 deletions dali/c_api_2/ref_counting.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef DALI_C_API_2_REF_COUNTING_H_
#define DALI_C_API_2_REF_COUNTING_H_

#include <atomic>
#include <type_traits>
#include <utility>

namespace dali::c_api {

class RefCountedObject {
public:
int IncRef() noexcept {
return std::atomic_fetch_add_explicit(&ref_, 1, std::memory_order_relaxed) + 1;
}

int DecRef() noexcept {
int ret = std::atomic_fetch_sub_explicit(&ref_, 1, std::memory_order_acq_rel) - 1;
if (!ret)
delete this;
return ret;
}

int RefCount() const noexcept {
return ref_.load(std::memory_order_relaxed);
}

virtual ~RefCountedObject() = default;
private:
std::atomic<int> ref_{1};
};

template <typename T>
class RefCountedPtr {
public:
constexpr RefCountedPtr() noexcept = default;

explicit RefCountedPtr(T *ptr, bool inc_ref = false) noexcept : ptr_(ptr) {
if (inc_ref && ptr_)
ptr_->IncRef();
}

~RefCountedPtr() {
reset();
}

template <typename U, std::enable_if_t<std::is_convertible_v<U *, T *>, int> = 0>
RefCountedPtr(const RefCountedPtr<U> &other) noexcept : ptr_(other.ptr_) {
if (ptr_)
ptr_->IncRef();
}

template <typename U, std::enable_if_t<std::is_convertible_v<U *, T *>, int> = 0>
RefCountedPtr(RefCountedPtr<U> &&other) noexcept : ptr_(other.ptr_) {
other.ptr_ = nullptr;
}

template <typename U>
std::enable_if_t<std::is_convertible_v<U *, T *>, RefCountedPtr> &
operator=(const RefCountedPtr<U> &other) noexcept {
if (ptr_ == other.ptr_)
return *this;
if (other.ptr_)
other.ptr_->IncRef();
ptr_->DecRef();
ptr_ = other.ptr_;
return *this;
}

template <typename U>
std::enable_if_t<std::is_convertible_v<U *, T *>, RefCountedPtr> &
operator=(RefCountedPtr &&other) noexcept {
if (&other == this)
return *this;
std::swap(ptr_, other.ptr_);
other.reset();
}

void reset() noexcept {
if (ptr_)
ptr_->DecRef();
ptr_ = nullptr;
}

[[nodiscard]] T *release() noexcept {
T *p = ptr_;
ptr_ = nullptr;
return p;
}

constexpr T *operator->() const & noexcept { return ptr_; }

constexpr T &operator*() const & noexcept { return *ptr_; }

constexpr T *get() const & noexcept { return ptr_; }

private:
template <typename U>
friend class RefCountedPtr;
T *ptr_ = nullptr;
};

} // namespace dali::c_api

#endif // DALI_C_API_2_REF_COUNTING_H_

41 changes: 41 additions & 0 deletions dali/c_api_2/validation.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cuda_runtime_api.h>
#include <stdexcept>
#include "dali/c_api_2/validation.h"

namespace dali::c_api {

void ValidateDeviceId(int device_id, bool allow_cpu_only) {
if (device_id == CPU_ONLY_DEVICE_ID && allow_cpu_only)
return;

static int dev_count = []() {
int ndevs = 0;
CUDA_CALL(cudaGetDeviceCount(&ndevs));
return ndevs;
}();

if (dev_count < 1)
throw std::runtime_error("No CUDA device found.");

if (device_id < 0 || device_id >= dev_count) {
throw std::out_of_range(make_string(
"The device id ", device_id, " is invalid."
" Valid device ids are [0..", dev_count-1, "]."));
}
}

} // namespace dali::c_api
112 changes: 112 additions & 0 deletions dali/c_api_2/validation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef DALI_C_API_2_VALIDATION_H_
#define DALI_C_API_2_VALIDATION_H_

#include <stdexcept>
#include <optional>
#define DALI_ALLOW_NEW_C_API
#include "dali/dali.h"
#include "dali/core/format.h"
#include "dali/core/span.h"
#include "dali/core/tensor_shape_print.h"
#include "dali/pipeline/data/types.h"

namespace dali::c_api {

inline void Validate(daliDataType_t dtype) {
if (!TypeTable::TryGetTypeInfo(dtype))
throw std::invalid_argument(make_string("Invalid data type: ", dtype));
}

inline void Validate(const TensorLayout &layout, int ndim, bool allow_empty = true) {
if (layout.empty() && allow_empty)
return;
if (layout.ndim() != ndim)
throw std::invalid_argument(make_string(
"The layout '", layout, "' cannot describe ", ndim, "-dimensional data."));
}

template <typename ShapeLike>
void ValidateSampleShape(
int sample_index,
ShapeLike &&sample_shape,
std::optional<int> expected_ndim = std::nullopt) {
int ndim = std::size(sample_shape);
if (expected_ndim.has_value() && ndim != *expected_ndim)
throw std::invalid_argument(make_string(
"Unexpected number of dimensions (", ndim, ") in sample ", sample_index,
". Expected ", *expected_ndim, "."));

for (int j = 0; j < ndim; j++)
if (sample_shape[j] < 0)
throw std::invalid_argument(make_string(
"Negative extent encountered in the shape of sample ", sample_index, ". Offending shape: ",
TensorShape<-1>(sample_shape)));
}

inline void ValidateNumSamples(int num_samples) {
if (num_samples < 0)
throw std::invalid_argument("The number of samples must not be negative.");
}

inline void ValidateNDim(int ndim) {
if (ndim < 0)
throw std::invalid_argument("The number of dimensions must not be negative.");
}


inline void ValidateShape(
int ndim,
const int64_t *shape) {
ValidateNDim(ndim);
if (ndim > 0 && !shape)
throw std::invalid_argument("The `shape` must not be NULL when ndim > 0.");

for (int j = 0; j < ndim; j++)
if (shape[j] < 0)
throw std::invalid_argument(make_string(
"The tensor shape must not contain negative extents. Got: ",
TensorShape<-1>(make_cspan(shape, ndim))));
}

inline void ValidateShape(int num_samples, int ndim, const int64_t *shapes) {
ValidateNumSamples(num_samples);
ValidateNDim(ndim);
if (!shapes && num_samples > 0 && ndim > 0)
throw std::invalid_argument("The `shapes` are required for non-scalar (ndim>=0) samples.");

if (ndim > 0) {
for (int i = 0; i < num_samples; i++)
ValidateSampleShape(i, make_cspan(&shapes[i*ndim], ndim));
}
}

inline void Validate(daliStorageDevice_t device_type) {
if (device_type != DALI_STORAGE_CPU && device_type != DALI_STORAGE_GPU)
throw std::invalid_argument(make_string("Invalid storage device type: ", device_type));
}

void ValidateDeviceId(int device_id, bool allow_cpu_only);

inline void Validate(const daliBufferPlacement_t &placement) {
Validate(placement.device_type);
if (placement.device_type == DALI_STORAGE_GPU || placement.pinned)
ValidateDeviceId(placement.device_id, placement.pinned);
}

} // namespace dali::c_api

#endif // DALI_C_API_2_VALIDATION_H_

0 comments on commit 2ac3150

Please sign in to comment.