Skip to content

Commit

Permalink
Ignore strides CHECK in FromTorch when input tensor is empty (#782)
Browse files Browse the repository at this point in the history
* Ignore strides CHECK in FromTorch when input tensor is empty

* Fix bug in Transpose when the input is empty fsavec

* Fix bug in Transpose when the input is empty fsavec

* add test case to cover empty FsaVec

Co-authored-by: pkufool <[email protected]>
  • Loading branch information
pkufool and pkufool authored Jul 22, 2021
1 parent 7b26fb6 commit fc57541
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 3 deletions.
5 changes: 4 additions & 1 deletion k2/csrc/ragged_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -999,7 +999,10 @@ RaggedShape Transpose(RaggedShape &src, Array1<int32_t> *value_indexes) {
K2_CHECK_GT(src.NumAxes(), 2);
ContextPtr c = src.Context();
int32_t src_dim0 = src.Dim0(), src_tot_size1 = src.TotSize(1);
if (src_dim0 <= 0) return src;
if (src_dim0 <= 0) {
if (value_indexes) *value_indexes = Array1<int32_t>(c, 0);
return src;
}
int32_t src_dim1 = src_tot_size1 / src_dim0;
K2_CHECK_EQ(src_tot_size1 % src_dim0, 0)
<< "Transpose(): all dims on axis 0 must be the same.\n"
Expand Down
56 changes: 56 additions & 0 deletions k2/csrc/ragged_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,62 @@ template <typename T>
void TestTransposeRagged() {
ContextPtr cpu = GetCpuContext(); // will be used to copy data
for (auto &context : {GetCpuContext(), GetCudaContext()}) {
// empty case, fsavec with a empty fsa
{
const std::vector<int32_t> row_splits1_vec = {0, 0};
const std::vector<int32_t> row_splits2_vec = {0};
Array1<int32_t> row_splits1(context, row_splits1_vec);
Array1<int32_t> row_splits2(context, row_splits2_vec);
RaggedShape src_shape =
RaggedShape3(&row_splits1, nullptr, -1, &row_splits2, nullptr, -1);
ASSERT_EQ(src_shape.Dim0(), 1);
ASSERT_EQ(src_shape.TotSize(1), 0);

Array1<T> values_array(context, 0);
ASSERT_EQ(values_array.Dim(), src_shape.NumElements());

Ragged<T> ragged(src_shape, values_array);
Ragged<T> ans = Transpose(ragged);
RaggedShape shape = ans.shape;
// Check shape
ASSERT_EQ(shape.Dim0(), 0);
ASSERT_EQ(shape.TotSize(1), 0);
CheckArrayData(shape.RowSplits(1), std::vector<int32_t>({0}));
CheckArrayData(shape.RowSplits(2), std::vector<int32_t>({0}));
K2_CHECK_EQ(shape.RowIds(1).Dim(), 0);
K2_CHECK_EQ(shape.RowIds(2).Dim(), 0);
// Check values
K2_CHECK_EQ(ans.values.Dim(), 0);
}

// empty case, fsavec without any fsa
{
const std::vector<int32_t> row_splits1_vec = {0};
const std::vector<int32_t> row_splits2_vec = {0};
Array1<int32_t> row_splits1(context, row_splits1_vec);
Array1<int32_t> row_splits2(context, row_splits2_vec);
RaggedShape src_shape =
RaggedShape3(&row_splits1, nullptr, -1, &row_splits2, nullptr, -1);
ASSERT_EQ(src_shape.Dim0(), 0);
ASSERT_EQ(src_shape.TotSize(1), 0);

Array1<T> values_array(context, 0);
ASSERT_EQ(values_array.Dim(), src_shape.NumElements());

Ragged<T> ragged(src_shape, values_array);
Ragged<T> ans = Transpose(ragged);
RaggedShape shape = ans.shape;
// Check shape
ASSERT_EQ(shape.Dim0(), 0);
ASSERT_EQ(shape.TotSize(1), 0);
CheckArrayData(shape.RowSplits(1), std::vector<int32_t>({0}));
CheckArrayData(shape.RowSplits(2), std::vector<int32_t>({0}));
K2_CHECK_EQ(shape.RowIds(1).Dim(), 0);
K2_CHECK_EQ(shape.RowIds(2).Dim(), 0);
// Check values
K2_CHECK_EQ(ans.values.Dim(), 0);
}

{
const std::vector<int32_t> row_splits1_vec = {0, 2, 4, 6};
const std::vector<int32_t> row_splits2_vec = {0, 3, 4, 7, 8, 10, 12};
Expand Down
7 changes: 5 additions & 2 deletions k2/python/csrc/torch/torch_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,11 @@ Array1<T> FromTorch(torch::Tensor &tensor) {
K2_CHECK_EQ(tensor.scalar_type(), ToScalarType<T>::value)
<< "Expected scalar type: " << ToScalarType<T>::value
<< ". Given: " << tensor.scalar_type();
K2_CHECK_EQ(tensor.strides()[0], 1)
<< "Expected stride: 1. Given: " << tensor.strides()[0];
// Some empty tensor may have stride not equal to 1, e.g., tensor returned by
// clone() method, it is valid here, so we won't check its strieds.
if (tensor.numel())
K2_CHECK_EQ(tensor.strides()[0], 1)
<< "Expected stride: 1. Given: " << tensor.strides()[0];

auto region = NewRegion(tensor);
Array1<T> ans(tensor.numel(), region, 0);
Expand Down

0 comments on commit fc57541

Please sign in to comment.