From b9f09a35a8d4af8bb1e60b29c54944c82377adf0 Mon Sep 17 00:00:00 2001 From: pkufool Date: Wed, 7 Jul 2021 13:42:30 +0800 Subject: [PATCH 1/7] contructing ctc graphs from symbols --- k2/csrc/fsa_algo.cu | 128 +++++++++++++++++++++++++++++++ k2/csrc/fsa_algo.h | 17 +++- k2/csrc/fsa_algo_test.cu | 22 ++++++ k2/python/csrc/torch/fsa_algo.cu | 43 +++++++++++ k2/python/k2/__init__.py | 1 + k2/python/k2/fsa_algo.py | 50 ++++++++++++ 6 files changed, 258 insertions(+), 3 deletions(-) diff --git a/k2/csrc/fsa_algo.cu b/k2/csrc/fsa_algo.cu index cdaa19cbf..68e7b9b16 100644 --- a/k2/csrc/fsa_algo.cu +++ b/k2/csrc/fsa_algo.cu @@ -430,6 +430,134 @@ FsaVec LinearFsas(const Ragged &symbols) { arcs); } + +FsaVec CtcGraphs(const Ragged &symbols, + Array1 *arc_map /*= nullptr*/) { + NVTX_RANGE(K2_FUNC); + K2_CHECK_EQ(symbols.NumAxes(), 2); + ContextPtr &c = symbols.Context(); + + int32_t num_fsas = symbols.Dim0(); + Array1 num_states_for(c, num_fsas + 1); + int32_t *num_states_for_data = num_states_for.Data(); + const int32_t *symbol_row_split1_data = symbols.RowSplits(1).Data(); + K2_EVAL( + c, num_fsas, lambda_set_num_states, (int32_t fsa_idx0)->void { + int32_t state_idx0x = symbol_row_split1_data[fsa_idx0], + state_idx0x_next = symbol_row_split1_data[fsa_idx0 + 1], + state_num = state_idx0x_next - state_idx0x; + num_states_for_data[fsa_idx0] = state_num * 2 + 2; + }); + + ExclusiveSum(num_states_for, &num_states_for); + Array1 &fsa_to_states_row_splits = num_states_for; + RaggedShape fsa_to_states = + RaggedShape2(&fsa_to_states_row_splits, nullptr, -1); + + int32_t num_states = fsa_to_states.NumElements(); + Array1 num_arcs_for(c, num_states + 1); + int32_t *num_arcs_for_data = num_arcs_for.Data(); + const int32_t *fts_row_splits1_data = fsa_to_states.RowSplits(1).Data(), + *fts_row_ids1_data = fsa_to_states.RowIds(1).Data(), + *symbol_data = symbols.values.Data(); + K2_EVAL( + c, num_states, lambda_set_num_arcs, (int32_t state_idx01)->void { + int32_t fsa_idx0 = fts_row_ids1_data[state_idx01], + sym_state_idx01 = state_idx01 / 2 - fsa_idx0, + remainder = state_idx01 % 2, + current_num_arcs = 2; + if (remainder) { + int32_t sym_final_state = + symbol_row_split1_data[fsa_idx0 + 1]; + if (sym_state_idx01 == sym_final_state) { + current_num_arcs = 0; + } else { + int32_t current_symbol = symbol_data[sym_state_idx01], + next_symbol = symbol_data[sym_state_idx01 + 1]; + if (current_symbol != next_symbol) + current_num_arcs = 3; + } + } + num_arcs_for_data[state_idx01] = current_num_arcs; + }); + + ExclusiveSum(num_arcs_for, &num_arcs_for); + Array1 &states_to_arcs_row_splits = num_arcs_for; + RaggedShape states_to_arcs = + RaggedShape2(&states_to_arcs_row_splits, nullptr, -1); + + RaggedShape ctc_shape = ComposeRaggedShapes(fsa_to_states, states_to_arcs); + int32_t num_arcs = ctc_shape.NumElements(); + Array1 arcs(c, num_arcs); + Arc *arcs_data = arcs.Data(); + const int32_t *ctc_row_splits1_data = ctc_shape.RowSplits(1).Data(), + *ctc_row_ids1_data = ctc_shape.RowIds(1).Data(), + *ctc_row_splits2_data = ctc_shape.RowSplits(2).Data(), + *ctc_row_ids2_data = ctc_shape.RowIds(2).Data(); + int32_t *arc_map_data = nullptr; + if (arc_map != nullptr) { + *arc_map = Array1(c, num_arcs); + arc_map_data = arc_map->Data(); + } + + K2_EVAL( + c, num_arcs, lambda_set_arcs, (int32_t arc_idx012)->void { + int32_t state_idx01 = ctc_row_ids2_data[arc_idx012], + fsa_idx0 = ctc_row_ids1_data[state_idx01], + state_idx0x = ctc_row_splits1_data[fsa_idx0], + state_idx1 = state_idx01 - state_idx0x, + arc_idx01x = ctc_row_splits2_data[state_idx01], + arc_idx2 = arc_idx012 - arc_idx01x, + sym_state_idx01 = state_idx01 / 2 - fsa_idx0, + remainder = state_idx01 % 2, + sym_final_state = symbol_row_split1_data[fsa_idx0 + 1]; + bool final_state = sym_final_state == sym_state_idx01; + int32_t current_symbol = final_state ? + -1 : symbol_data[sym_state_idx01]; + Arc arc; + arc.score = 0; + arc.src_state = state_idx1; + int32_t arc_map_value = -1; + if (remainder) { + if (final_state) return; + int32_t next_symbol = (sym_state_idx01 + 1) == sym_final_state ? + -1 : symbol_data[sym_state_idx01 + 1]; + if (current_symbol == next_symbol) { + K2_CHECK_LT(arc_idx2, 2); + arc.label = arc_idx2 == 0 ? 0 : current_symbol; + arc.dest_state = arc_idx2 == 0 ? state_idx1 + 1 : state_idx1; + } else { + switch (arc_idx2) { + case 0: + arc.label = 0; + arc.dest_state = state_idx1 + 1; + break; + case 1: + arc.label = current_symbol; + arc.dest_state = state_idx1; + break; + case 2: + arc.label = next_symbol; + arc_map_value = sym_state_idx01 + 1 == sym_final_state ? + -1 : sym_state_idx01 + 1; + arc.dest_state = state_idx1 + 2; + break; + default: + K2_LOG(FATAL) << "Arc index must be less than 3"; + } + } + } else { + K2_CHECK_LT(arc_idx2, 2); + arc.label = arc_idx2 == 0 ? 0 : current_symbol; + arc.dest_state = arc_idx2 == 0 ? state_idx1 : state_idx1 + 1; + arc_map_value = (arc_idx2 == 0 || final_state) ? -1 : sym_state_idx01; + } + arcs_data[arc_idx012] = arc; + if (arc_map) arc_map_data[arc_idx012] = arc_map_value; + }); + return Ragged(ctc_shape, arcs); +} + void ArcSort(Fsa *fsa) { if (fsa->NumAxes() < 2) return; // it is empty SortSublists(fsa); diff --git a/k2/csrc/fsa_algo.h b/k2/csrc/fsa_algo.h index 743dd8b99..5116bdd2f 100644 --- a/k2/csrc/fsa_algo.h +++ b/k2/csrc/fsa_algo.h @@ -424,9 +424,6 @@ void RemoveEpsilonAndAddSelfLoops(FsaOrVec &src, int32_t properties, FsaOrVec *dest, Ragged *arc_derivs = nullptr); - - - /* Determinize the input Fsas, it works for both Fsa and FsaVec. @param [in] src Source Fsa or FsaVec. @@ -478,6 +475,20 @@ Fsa LinearFsa(const Array1 &symbols); */ FsaVec LinearFsas(const Ragged &symbols); +/* + Create an FsaVec containing ctc graph FSAs, given a list of sequences of + symbols + + @param [in] symbols Input symbol sequences (must not contain + kFinalSymbol == -1). Its num_axes is 2. + @param [out] It map the olabel of the arc to the symbols(idx01), -1 for + epsilon olabel. + + @return Returns an FsaVec with `ans.Dim0() == symbols.Dim0()`. + */ +FsaVec CtcGraphs(const Ragged &symbols, + Array1 *arc_map = nullptr); + /* Compute the forward shortest path in the tropical semiring. @param [in] fsas Input FsaVec (must have 3 axes). Must be diff --git a/k2/csrc/fsa_algo_test.cu b/k2/csrc/fsa_algo_test.cu index 31758ad68..a71d77575 100644 --- a/k2/csrc/fsa_algo_test.cu +++ b/k2/csrc/fsa_algo_test.cu @@ -1053,4 +1053,26 @@ TEST(FsaAlgo, TestRemoveEplsionSelfRandomFsas) { } } +TEST(FsaAlgo, TestCtcGraph) { + for (const ContextPtr &c : {GetCpuContext(), GetCudaContext()}) { + Ragged symbols(c, "[ [ 1 2 2 3 ] [ 1 2 3 ] ]"); + Array1 arc_map; + FsaVec graph = CtcGraphs(symbols, &arc_map); + FsaVec graph_ref(c, "[ [ [ 0 0 0 0 0 1 1 0 ] [ 1 2 0 0 1 1 1 0 1 3 2 0 ] " + " [ 2 2 0 0 2 3 2 0 ] [ 3 4 0 0 3 3 2 0 ] " + " [ 4 4 0 0 4 5 2 0 ] [ 5 6 0 0 5 5 2 0 5 7 3 0 ] " + " [ 6 6 0 0 6 7 3 0 ] [ 7 8 0 0 7 7 3 0 7 9 -1 0 ] " + " [ 8 8 0 0 8 9 -1 0 ] [ ] ] " + " [ [ 0 0 0 0 0 1 1 0 ] [ 1 2 0 0 1 1 1 0 1 3 2 0 ] " + " [ 2 2 0 0 2 3 2 0 ] [ 3 4 0 0 3 3 2 0 3 5 3 0 ] " + " [ 4 4 0 0 4 5 3 0 ] [ 5 6 0 0 5 5 3 0 5 7 -1 0 ] " + " [ 6 6 0 0 6 7 -1 0 ] [ ] ] ]"); + Array1 arc_map_ref(c, "[ -1 0 -1 -1 1 -1 1 -1 -1 -1 2 -1 -1 3 " + " -1 3 -1 -1 -1 -1 -1 -1 4 -1 -1 5 -1 5 " + " -1 -1 6 -1 6 -1 -1 -1 -1 -1 ]"); + K2_CHECK(Equal(graph, graph_ref)); + K2_CHECK(Equal(arc_map, arc_map_ref)); + } +} + } // namespace k2 diff --git a/k2/python/csrc/torch/fsa_algo.cu b/k2/python/csrc/torch/fsa_algo.cu index 8f8d16537..9ce61a67c 100644 --- a/k2/python/csrc/torch/fsa_algo.cu +++ b/k2/python/csrc/torch/fsa_algo.cu @@ -661,6 +661,48 @@ static void PybindFixFinalLabels(py::module &m) { )"); } +static void PybindCtcGraph(py::module &m) { + m.def( + "ctc_graph", + [](const std::vector> &symbols, + bool need_arc_map = true, int32_t gpu_id = -1) + -> std::pair> { + ContextPtr context; + if (gpu_id < 0) + context = GetCpuContext(); + else + context = GetCudaContext(gpu_id); + + DeviceGuard guard(context); + Ragged ragged = CreateRagged2(symbols).To(context); + Array1 arc_map; + FsaVec graph = CtcGraphs(ragged, need_arc_map ? &arc_map : nullptr); + torch::optional tensor; + if (need_arc_map) tensor = ToTorch(arc_map); + return std::make_pair(graph, tensor); + }, + py::arg("symbols"), py::arg("need_arc_map") = true, + py::arg("gpu_id") = -1, + R"( + If gpu_id is -1, the returned FsaVec is on CPU. + If gpu_id >= 0, the returned FsaVec is on the specified GPU. + )"); + + m.def( + "ctc_graph", + [](const Ragged &symbols, bool need_arc_map = true, + int32_t /*unused_gpu_id*/) + -> std::pair> { + DeviceGuard guard(symbols.Context()); + Array1 arc_map; + FsaVec graph = CtcGraphs(symbols, need_arc_map ? &arc_map : nullptr); + torch::optional tensor; + if (need_arc_map) tensor = ToTorch(arc_map); + return std::make_pair(graph, tensor); + }, + py::arg("labels"), py::arg("need_arc_map") = true, py::arg("gpu_id")); +} + } // namespace k2 void PybindFsaAlgo(py::module &m) { @@ -682,4 +724,5 @@ void PybindFsaAlgo(py::module &m) { k2::PybindRemoveEpsilonSelfLoops(m); k2::PybindExpandArcs(m); k2::PybindFixFinalLabels(m); + k2::PybindCtcGraph(m); } diff --git a/k2/python/k2/__init__.py b/k2/python/k2/__init__.py index d7daf6bb8..adfca772d 100644 --- a/k2/python/k2/__init__.py +++ b/k2/python/k2/__init__.py @@ -35,6 +35,7 @@ from .fsa_algo import remove_epsilon_self_loops from .fsa_algo import shortest_path from .fsa_algo import top_sort +from .fsa_algo import ctc_graph from .fsa_properties import to_str as properties_to_str from .ops import cat from .ops import compose_arc_maps diff --git a/k2/python/k2/fsa_algo.py b/k2/python/k2/fsa_algo.py index c84b89b9e..e0dda46b9 100644 --- a/k2/python/k2/fsa_algo.py +++ b/k2/python/k2/fsa_algo.py @@ -962,3 +962,53 @@ def expand_ragged_attributes( return dest, arc_map else: return dest + + +def ctc_graph(symbols: Union[List[List[int]], k2.RaggedInt], + device: Optional[Union[torch.device, str]] = None) -> Fsa: + '''Construct ctc graphs from symbols. + + Note: + The scores of arcs in the returned FSA are all 0. + + Args: + symbols: + It can be one of the following types: + + - A list of list-of-integers, e..g, `[ [1, 2], [1, 2, 3] ]` + - An instance of :class:`k2.RaggedInt`. Must have `num_axes() == 2`. + device: + Optional. It can be either a string (e.g., 'cpu', + 'cuda:0') or a torch.device. + If it is None, then the returned FSA is on CPU. It has to be None + if `symbols` is an instance of :class:`k2.RaggedInt`. + + Returns: + + - If `symbols` is a list of list-of-integers, return an FsaVec + - If `symbols` is an instance of :class:`k2.RaggedInt`, return an FsaVec + ''' + symbol_values = None + if isinstance(symbols, k2.RaggedInt): + assert device is None + assert symbols.num_axes() == 2 + symbol_values = symbols.values() + else: + symbol_values = [it for symbol in symbols for it in symbol] + + if device is not None: + device = torch.device(device) + if device.type == 'cpu': + gpu_id = -1 + else: + assert device.type == 'cuda' + gpu_id = getattr(device, 'index', 0) + else: + gpu_id = -1 + need_arc_map = True + ragged_arc, arc_map = _k2.ctc_graph(symbols, need_arc_map, gpu_id) + fsa = Fsa(ragged_arc) + fsa.aux_labels = torch.tensor([symbol_values[arc_map[i]] \ + if arc_map[i] != -1 else 0 for i in range(len(arc_map))],\ + dtype=torch.int32) + return fsa From 574d5f4d3c2b92aa703b5f66ba1d6acb4a44d4cd Mon Sep 17 00:00:00 2001 From: pkufool Date: Wed, 7 Jul 2021 21:00:23 +0800 Subject: [PATCH 2/7] add more documents and add python unit test --- k2/csrc/fsa_algo.cu | 33 +++++++++++----- k2/csrc/fsa_algo.h | 5 ++- k2/python/tests/CMakeLists.txt | 1 + k2/python/tests/ctc_graph_test.py | 62 +++++++++++++++++++++++++++++++ 4 files changed, 90 insertions(+), 11 deletions(-) create mode 100644 k2/python/tests/ctc_graph_test.py diff --git a/k2/csrc/fsa_algo.cu b/k2/csrc/fsa_algo.cu index 68e7b9b16..540339c5d 100644 --- a/k2/csrc/fsa_algo.cu +++ b/k2/csrc/fsa_algo.cu @@ -441,12 +441,16 @@ FsaVec CtcGraphs(const Ragged &symbols, Array1 num_states_for(c, num_fsas + 1); int32_t *num_states_for_data = num_states_for.Data(); const int32_t *symbol_row_split1_data = symbols.RowSplits(1).Data(); + // symbols indexed with [fsa][symbol] + // for each fsa we need `symbol_num * 2 + 1 + 1` states, `symbol_num * 2 + 1` + // means that we need a blank state on each side of a symbol state, `+ 1` is + // for final state in k2 K2_EVAL( c, num_fsas, lambda_set_num_states, (int32_t fsa_idx0)->void { - int32_t state_idx0x = symbol_row_split1_data[fsa_idx0], - state_idx0x_next = symbol_row_split1_data[fsa_idx0 + 1], - state_num = state_idx0x_next - state_idx0x; - num_states_for_data[fsa_idx0] = state_num * 2 + 2; + int32_t symbol_idx0x = symbol_row_split1_data[fsa_idx0], + symbol_idx0x_next = symbol_row_split1_data[fsa_idx0 + 1], + symbol_num = symbol_idx0x_next - symbol_idx0x; + num_states_for_data[fsa_idx0] = symbol_num * 2 + 2; }); ExclusiveSum(num_states_for, &num_states_for); @@ -460,20 +464,30 @@ FsaVec CtcGraphs(const Ragged &symbols, const int32_t *fts_row_splits1_data = fsa_to_states.RowSplits(1).Data(), *fts_row_ids1_data = fsa_to_states.RowIds(1).Data(), *symbol_data = symbols.values.Data(); + // set the arcs number for each state K2_EVAL( c, num_states, lambda_set_num_arcs, (int32_t state_idx01)->void { int32_t fsa_idx0 = fts_row_ids1_data[state_idx01], + // we minus fsa_idx0 here, because we adding one more state, the + // final state for each fsa sym_state_idx01 = state_idx01 / 2 - fsa_idx0, remainder = state_idx01 % 2, - current_num_arcs = 2; - if (remainder) { + current_num_arcs = 2; // normally there are two arcs, self-loop + // and arc points to the next state + // blank state always has two arcs + if (remainder) { // symbol state int32_t sym_final_state = symbol_row_split1_data[fsa_idx0 + 1]; + // There is no arcs for final states if (sym_state_idx01 == sym_final_state) { current_num_arcs = 0; } else { int32_t current_symbol = symbol_data[sym_state_idx01], next_symbol = symbol_data[sym_state_idx01 + 1]; + // if current_symbol equals next_symbol, we need a blank state + // between them, so there are two arcs for this state + // otherwise, this state will point to blank state and next symbol + // state, so we need three arcs here. if (current_symbol != next_symbol) current_num_arcs = 3; } @@ -486,6 +500,7 @@ FsaVec CtcGraphs(const Ragged &symbols, RaggedShape states_to_arcs = RaggedShape2(&states_to_arcs_row_splits, nullptr, -1); + // ctc_shape with a index of [fsa][state][arc] RaggedShape ctc_shape = ComposeRaggedShapes(fsa_to_states, states_to_arcs); int32_t num_arcs = ctc_shape.NumElements(); Array1 arcs(c, num_arcs); @@ -528,15 +543,15 @@ FsaVec CtcGraphs(const Ragged &symbols, arc.dest_state = arc_idx2 == 0 ? state_idx1 + 1 : state_idx1; } else { switch (arc_idx2) { - case 0: + case 0: // the arc points to blank state arc.label = 0; arc.dest_state = state_idx1 + 1; break; - case 1: + case 1: // the self loop arc arc.label = current_symbol; arc.dest_state = state_idx1; break; - case 2: + case 2: // the arc points to next symbol state arc.label = next_symbol; arc_map_value = sym_state_idx01 + 1 == sym_final_state ? -1 : sym_state_idx01 + 1; diff --git a/k2/csrc/fsa_algo.h b/k2/csrc/fsa_algo.h index 5116bdd2f..000c60754 100644 --- a/k2/csrc/fsa_algo.h +++ b/k2/csrc/fsa_algo.h @@ -481,8 +481,9 @@ FsaVec LinearFsas(const Ragged &symbols); @param [in] symbols Input symbol sequences (must not contain kFinalSymbol == -1). Its num_axes is 2. - @param [out] It map the olabel of the arc to the symbols(idx01), -1 for - epsilon olabel. + @param [out] It maps the arcs of output fsa to the symbols(idx01), the + olabel of the `arc[i]` would be `symbols[arc_map[i]]`, + and -1 for epsilon olabel. @return Returns an FsaVec with `ans.Dim0() == symbols.Dim0()`. */ diff --git a/k2/python/tests/CMakeLists.txt b/k2/python/tests/CMakeLists.txt index 8c6e09740..8ade6fbd1 100644 --- a/k2/python/tests/CMakeLists.txt +++ b/k2/python/tests/CMakeLists.txt @@ -26,6 +26,7 @@ set(py_test_files connect_test.py create_sparse_test.py ctc_gradients_test.py + ctc_graph_test.py dense_fsa_vec_test.py determinize_test.py expand_ragged_attributes_test.py diff --git a/k2/python/tests/ctc_graph_test.py b/k2/python/tests/ctc_graph_test.py new file mode 100644 index 000000000..0ea7d85d1 --- /dev/null +++ b/k2/python/tests/ctc_graph_test.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (authors: WeiKang) +# +# See ../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# To run this single test, use +# +# ctest --verbose -R connect_test_py + +import unittest + +import k2 +import torch + + +class TestCtcGraph(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.devices = [torch.device('cpu')] + if torch.cuda.is_available() and k2.with_cuda: + cls.devices.append(torch.device('cuda', 0)) + if torch.cuda.device_count() > 1: + torch.cuda.set_device(1) + cls.devices.append(torch.device('cuda', 1)) + + def test(self): + for device in self.devices: + fsa_vec = k2.ctc_graph([[1,2,2],[1,2,3]]) + expected_str0 = '\n'.join(['0 0 0 0 0', '0 1 1 1 0', '1 2 0 0 0', + '1 1 1 0 0', '1 3 2 2 0', '2 2 0 0 0', + '2 3 2 2 0', '3 4 0 0 0', '3 3 2 0 0', + '4 4 0 0 0', '4 5 2 2 0', '5 6 0 0 0', + '5 5 2 0 0', '5 7 -1 0 0', '6 6 0 0 0', + '6 7 -1 0 0', '7']) + expected_str1 = '\n'.join(['0 0 0 0 0', '0 1 1 1 0', '1 2 0 0 0', + '1 1 1 0 0', '1 3 2 2 0', '2 2 0 0 0', + '2 3 2 2 0', '3 4 0 0 0', '3 3 2 0 0', + '3 5 3 3 0', '4 4 0 0 0', '4 5 3 3 0', + '5 6 0 0 0', '5 5 3 0 0', '5 7 -1 0 0', + '6 6 0 0 0', '6 7 -1 0 0', '7']) + actual_str0 = k2.to_str_simple(fsa_vec[0].to('cpu')) + actual_str1 = k2.to_str_simple(fsa_vec[1].to('cpu')) + assert actual_str0.strip() == expected_str0 + assert actual_str1.strip() == expected_str1 + + +if __name__ == '__main__': + unittest.main() From 6387d2b7de1c96d31bd6f9164517b5d5e3d66ce9 Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 8 Jul 2021 06:58:47 +0800 Subject: [PATCH 3/7] add rangged tensor unit test, fix code style --- k2/csrc/fsa_algo.cu | 4 ++-- k2/python/k2/__init__.py | 2 +- k2/python/k2/fsa_algo.py | 33 ++++++++++++++++--------------- k2/python/tests/ctc_graph_test.py | 16 ++++++++++++--- 4 files changed, 33 insertions(+), 22 deletions(-) diff --git a/k2/csrc/fsa_algo.cu b/k2/csrc/fsa_algo.cu index 540339c5d..dd81bde04 100644 --- a/k2/csrc/fsa_algo.cu +++ b/k2/csrc/fsa_algo.cu @@ -468,8 +468,8 @@ FsaVec CtcGraphs(const Ragged &symbols, K2_EVAL( c, num_states, lambda_set_num_arcs, (int32_t state_idx01)->void { int32_t fsa_idx0 = fts_row_ids1_data[state_idx01], - // we minus fsa_idx0 here, because we adding one more state, the - // final state for each fsa + // we minus fsa_idx0 here, because we are adding one more state, + // the final state for each fsa sym_state_idx01 = state_idx01 / 2 - fsa_idx0, remainder = state_idx01 % 2, current_num_arcs = 2; // normally there are two arcs, self-loop diff --git a/k2/python/k2/__init__.py b/k2/python/k2/__init__.py index adfca772d..b04e2ff06 100644 --- a/k2/python/k2/__init__.py +++ b/k2/python/k2/__init__.py @@ -21,6 +21,7 @@ from .fsa_algo import closure from .fsa_algo import compose from .fsa_algo import connect +from .fsa_algo import ctc_graph from .fsa_algo import determinize from .fsa_algo import expand_ragged_attributes from .fsa_algo import intersect @@ -35,7 +36,6 @@ from .fsa_algo import remove_epsilon_self_loops from .fsa_algo import shortest_path from .fsa_algo import top_sort -from .fsa_algo import ctc_graph from .fsa_properties import to_str as properties_to_str from .ops import cat from .ops import compose_arc_maps diff --git a/k2/python/k2/fsa_algo.py b/k2/python/k2/fsa_algo.py index e0dda46b9..85d96eb0e 100644 --- a/k2/python/k2/fsa_algo.py +++ b/k2/python/k2/fsa_algo.py @@ -981,21 +981,13 @@ def ctc_graph(symbols: Union[List[List[int]], k2.RaggedInt], Optional. It can be either a string (e.g., 'cpu', 'cuda:0') or a torch.device. If it is None, then the returned FSA is on CPU. It has to be None - if `symbols` is an instance of :class:`k2.RaggedInt`. + if `symbols` is an instance of :class:`k2.RaggedInt`, the returned + FSA will on the same device as `k2.RaggedInt`. Returns: - - - If `symbols` is a list of list-of-integers, return an FsaVec - - If `symbols` is an instance of :class:`k2.RaggedInt`, return an FsaVec + An FsaVec contains the returned ctc graphs, with `Dim0()` the same as + `len(symbols)`(List[List[int]]) or `Dim0()`(k2.RaggedInt) ''' - symbol_values = None - if isinstance(symbols, k2.RaggedInt): - assert device is None - assert symbols.num_axes() == 2 - symbol_values = symbols.values() - else: - symbol_values = [it for symbol in symbols for it in symbol] - if device is not None: device = torch.device(device) if device.type == 'cpu': @@ -1005,10 +997,19 @@ def ctc_graph(symbols: Union[List[List[int]], k2.RaggedInt], gpu_id = getattr(device, 'index', 0) else: gpu_id = -1 + + symbol_values = None + if isinstance(symbols, k2.RaggedInt): + assert device is None + assert symbols.num_axes() == 2 + symbol_values = symbols.values() + else: + symbol_values = torch.tensor( + [it for symbol in symbols for it in symbol], dtype=torch.int32, + device=device) + need_arc_map = True ragged_arc, arc_map = _k2.ctc_graph(symbols, need_arc_map, gpu_id) - fsa = Fsa(ragged_arc) - fsa.aux_labels = torch.tensor([symbol_values[arc_map[i]] \ - if arc_map[i] != -1 else 0 for i in range(len(arc_map))],\ - dtype=torch.int32) + aux_labels = k2.index(symbol_values, arc_map) + fsa = Fsa(ragged_arc, aux_labels=aux_labels) return fsa diff --git a/k2/python/tests/ctc_graph_test.py b/k2/python/tests/ctc_graph_test.py index 0ea7d85d1..086700473 100644 --- a/k2/python/tests/ctc_graph_test.py +++ b/k2/python/tests/ctc_graph_test.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -# Copyright 2021 Xiaomi Corporation (authors: WeiKang) +# Copyright 2021 Xiaomi Corporation (authors: Wei Kang) # # See ../../../LICENSE for clarification regarding multiple authors # @@ -18,7 +18,7 @@ # To run this single test, use # -# ctest --verbose -R connect_test_py +# ctest --verbose -R ctc_graph_test_py import unittest @@ -39,7 +39,13 @@ def setUpClass(cls): def test(self): for device in self.devices: - fsa_vec = k2.ctc_graph([[1,2,2],[1,2,3]]) + s = ''' + [ [1 2 2] [1 2 3] ] + ''' + ragged_int = k2.RaggedInt(s).to(device) + fsa_vec_ragged = k2.ctc_graph(ragged_int) + + fsa_vec = k2.ctc_graph([[1, 2, 2], [1, 2, 3]], device) expected_str0 = '\n'.join(['0 0 0 0 0', '0 1 1 1 0', '1 2 0 0 0', '1 1 1 0 0', '1 3 2 2 0', '2 2 0 0 0', '2 3 2 2 0', '3 4 0 0 0', '3 3 2 0 0', @@ -52,10 +58,14 @@ def test(self): '3 5 3 3 0', '4 4 0 0 0', '4 5 3 3 0', '5 6 0 0 0', '5 5 3 0 0', '5 7 -1 0 0', '6 6 0 0 0', '6 7 -1 0 0', '7']) + actual_str_ragged0 = k2.to_str_simple(fsa_vec_ragged[0].to('cpu')) + actual_str_ragged1 = k2.to_str_simple(fsa_vec_ragged[1].to('cpu')) actual_str0 = k2.to_str_simple(fsa_vec[0].to('cpu')) actual_str1 = k2.to_str_simple(fsa_vec[1].to('cpu')) assert actual_str0.strip() == expected_str0 assert actual_str1.strip() == expected_str1 + assert actual_str_ragged0.strip() == expected_str0 + assert actual_str_ragged1.strip() == expected_str1 if __name__ == '__main__': From 8ee147a211a42a606437197bdbda1cb9f3846189 Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 8 Jul 2021 07:20:56 +0800 Subject: [PATCH 4/7] add symbol checking & handle the final symbol --- k2/csrc/fsa_algo.cu | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/k2/csrc/fsa_algo.cu b/k2/csrc/fsa_algo.cu index dd81bde04..d95edc1c2 100644 --- a/k2/csrc/fsa_algo.cu +++ b/k2/csrc/fsa_algo.cu @@ -483,7 +483,12 @@ FsaVec CtcGraphs(const Ragged &symbols, current_num_arcs = 0; } else { int32_t current_symbol = symbol_data[sym_state_idx01], - next_symbol = symbol_data[sym_state_idx01 + 1]; + // we set the next symbol of the last symbol to -1, so + // the following if clause will always be true + next_symbol = (sym_state_idx01 + 1) == sym_final_state ? + -1 : symbol_data[sym_state_idx01 + 1]; + // symbols must be not equal to -1, which is specially used in k2 + K2_CHECK_NE(current_symbol, -1); // if current_symbol equals next_symbol, we need a blank state // between them, so there are two arcs for this state // otherwise, this state will point to blank state and next symbol From 182c260693e2f9a66a2ae94c1bfeb7e32215c165 Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 8 Jul 2021 19:43:23 +0800 Subject: [PATCH 5/7] add standard option to choose ctc topolopy --- k2/csrc/fsa_algo.cu | 15 ++++++++++----- k2/csrc/fsa_algo.h | 5 ++++- k2/csrc/fsa_algo_test.cu | 24 ++++++++++++++++++++++- k2/python/csrc/torch/fsa_algo.cu | 19 ++++++++++-------- k2/python/k2/fsa_algo.py | 8 +++++++- k2/python/tests/ctc_graph_test.py | 32 ++++++++++++++++++++++++++++++- 6 files changed, 86 insertions(+), 17 deletions(-) diff --git a/k2/csrc/fsa_algo.cu b/k2/csrc/fsa_algo.cu index d95edc1c2..aaa46f2d8 100644 --- a/k2/csrc/fsa_algo.cu +++ b/k2/csrc/fsa_algo.cu @@ -431,7 +431,7 @@ FsaVec LinearFsas(const Ragged &symbols) { } -FsaVec CtcGraphs(const Ragged &symbols, +FsaVec CtcGraphs(const Ragged &symbols, bool standard /*= true*/, Array1 *arc_map /*= nullptr*/) { NVTX_RANGE(K2_FUNC); K2_CHECK_EQ(symbols.NumAxes(), 2); @@ -473,7 +473,7 @@ FsaVec CtcGraphs(const Ragged &symbols, sym_state_idx01 = state_idx01 / 2 - fsa_idx0, remainder = state_idx01 % 2, current_num_arcs = 2; // normally there are two arcs, self-loop - // and arc points to the next state + // and arc pointing to the next state // blank state always has two arcs if (remainder) { // symbol state int32_t sym_final_state = @@ -484,7 +484,8 @@ FsaVec CtcGraphs(const Ragged &symbols, } else { int32_t current_symbol = symbol_data[sym_state_idx01], // we set the next symbol of the last symbol to -1, so - // the following if clause will always be true + // the following if clause will always be true, which means + // we will have 3 arcs for last symbol state next_symbol = (sym_state_idx01 + 1) == sym_final_state ? -1 : symbol_data[sym_state_idx01 + 1]; // symbols must be not equal to -1, which is specially used in k2 @@ -493,7 +494,9 @@ FsaVec CtcGraphs(const Ragged &symbols, // between them, so there are two arcs for this state // otherwise, this state will point to blank state and next symbol // state, so we need three arcs here. - if (current_symbol != next_symbol) + // Note: for the simpilfied topology (standard equals false), there + // are always 3 arcs leaving symbol states. + if (current_symbol != next_symbol || !standard) current_num_arcs = 3; } } @@ -542,7 +545,9 @@ FsaVec CtcGraphs(const Ragged &symbols, if (final_state) return; int32_t next_symbol = (sym_state_idx01 + 1) == sym_final_state ? -1 : symbol_data[sym_state_idx01 + 1]; - if (current_symbol == next_symbol) { + // for standard topology, the symbol state can not point to next + // symbol state if the next symbol is identical to current symbol. + if (current_symbol == next_symbol && standard) { K2_CHECK_LT(arc_idx2, 2); arc.label = arc_idx2 == 0 ? 0 : current_symbol; arc.dest_state = arc_idx2 == 0 ? state_idx1 + 1 : state_idx1; diff --git a/k2/csrc/fsa_algo.h b/k2/csrc/fsa_algo.h index 000c60754..5fa80c1e3 100644 --- a/k2/csrc/fsa_algo.h +++ b/k2/csrc/fsa_algo.h @@ -481,13 +481,16 @@ FsaVec LinearFsas(const Ragged &symbols); @param [in] symbols Input symbol sequences (must not contain kFinalSymbol == -1). Its num_axes is 2. + @param [in] standard Option to specify the type of CTC topology: "standard" + or "simplified", where the "standard" one makes the + blank mandatory between a pair of identical symbols. @param [out] It maps the arcs of output fsa to the symbols(idx01), the olabel of the `arc[i]` would be `symbols[arc_map[i]]`, and -1 for epsilon olabel. @return Returns an FsaVec with `ans.Dim0() == symbols.Dim0()`. */ -FsaVec CtcGraphs(const Ragged &symbols, +FsaVec CtcGraphs(const Ragged &symbols, bool standard = true, Array1 *arc_map = nullptr); /* Compute the forward shortest path in the tropical semiring. diff --git a/k2/csrc/fsa_algo_test.cu b/k2/csrc/fsa_algo_test.cu index a71d77575..e33f8c997 100644 --- a/k2/csrc/fsa_algo_test.cu +++ b/k2/csrc/fsa_algo_test.cu @@ -1057,7 +1057,7 @@ TEST(FsaAlgo, TestCtcGraph) { for (const ContextPtr &c : {GetCpuContext(), GetCudaContext()}) { Ragged symbols(c, "[ [ 1 2 2 3 ] [ 1 2 3 ] ]"); Array1 arc_map; - FsaVec graph = CtcGraphs(symbols, &arc_map); + FsaVec graph = CtcGraphs(symbols, true, &arc_map); FsaVec graph_ref(c, "[ [ [ 0 0 0 0 0 1 1 0 ] [ 1 2 0 0 1 1 1 0 1 3 2 0 ] " " [ 2 2 0 0 2 3 2 0 ] [ 3 4 0 0 3 3 2 0 ] " " [ 4 4 0 0 4 5 2 0 ] [ 5 6 0 0 5 5 2 0 5 7 3 0 ] " @@ -1075,4 +1075,26 @@ TEST(FsaAlgo, TestCtcGraph) { } } +TEST(FsaAlgo, TestCtcGraphSimplified) { + for (const ContextPtr &c : {GetCpuContext(), GetCudaContext()}) { + Ragged symbols(c, "[ [ 1 2 2 3 ] [ 1 2 3 ] ]"); + Array1 arc_map; + FsaVec graph = CtcGraphs(symbols, false, &arc_map); + FsaVec graph_ref(c, "[ [ [ 0 0 0 0 0 1 1 0 ] [ 1 2 0 0 1 1 1 0 1 3 2 0 ] " + " [ 2 2 0 0 2 3 2 0 ] [ 3 4 0 0 3 3 2 0 3 5 2 0] " + " [ 4 4 0 0 4 5 2 0 ] [ 5 6 0 0 5 5 2 0 5 7 3 0 ] " + " [ 6 6 0 0 6 7 3 0 ] [ 7 8 0 0 7 7 3 0 7 9 -1 0 ] " + " [ 8 8 0 0 8 9 -1 0 ] [ ] ] " + " [ [ 0 0 0 0 0 1 1 0 ] [ 1 2 0 0 1 1 1 0 1 3 2 0 ] " + " [ 2 2 0 0 2 3 2 0 ] [ 3 4 0 0 3 3 2 0 3 5 3 0 ] " + " [ 4 4 0 0 4 5 3 0 ] [ 5 6 0 0 5 5 3 0 5 7 -1 0 ] " + " [ 6 6 0 0 6 7 -1 0 ] [ ] ] ]"); + Array1 arc_map_ref(c, "[ -1 0 -1 -1 1 -1 1 -1 -1 2 -1 2 -1 " + " -1 3 -1 3 -1 -1 -1 -1 -1 -1 4 -1 -1 5 " + " -1 5 -1 -1 6 -1 6 -1 -1 -1 -1 -1 ]"); + K2_CHECK(Equal(graph, graph_ref)); + K2_CHECK(Equal(arc_map, arc_map_ref)); + } +} + } // namespace k2 diff --git a/k2/python/csrc/torch/fsa_algo.cu b/k2/python/csrc/torch/fsa_algo.cu index 9ce61a67c..c05d5317b 100644 --- a/k2/python/csrc/torch/fsa_algo.cu +++ b/k2/python/csrc/torch/fsa_algo.cu @@ -665,7 +665,7 @@ static void PybindCtcGraph(py::module &m) { m.def( "ctc_graph", [](const std::vector> &symbols, - bool need_arc_map = true, int32_t gpu_id = -1) + int32_t gpu_id = -1, bool standard = true, bool need_arc_map = true) -> std::pair> { ContextPtr context; if (gpu_id < 0) @@ -676,13 +676,14 @@ static void PybindCtcGraph(py::module &m) { DeviceGuard guard(context); Ragged ragged = CreateRagged2(symbols).To(context); Array1 arc_map; - FsaVec graph = CtcGraphs(ragged, need_arc_map ? &arc_map : nullptr); + FsaVec graph = CtcGraphs(ragged, standard, + need_arc_map ? &arc_map : nullptr); torch::optional tensor; if (need_arc_map) tensor = ToTorch(arc_map); return std::make_pair(graph, tensor); }, - py::arg("symbols"), py::arg("need_arc_map") = true, - py::arg("gpu_id") = -1, + py::arg("symbols"), py::arg("gpu_id") = -1, py::arg("standard") = true, + py::arg("need_arc_map") = true, R"( If gpu_id is -1, the returned FsaVec is on CPU. If gpu_id >= 0, the returned FsaVec is on the specified GPU. @@ -690,17 +691,19 @@ static void PybindCtcGraph(py::module &m) { m.def( "ctc_graph", - [](const Ragged &symbols, bool need_arc_map = true, - int32_t /*unused_gpu_id*/) + [](const Ragged &symbols, int32_t gpu_id, /*unused_gpu_id*/ + bool standard = true, bool need_arc_map = true) -> std::pair> { DeviceGuard guard(symbols.Context()); Array1 arc_map; - FsaVec graph = CtcGraphs(symbols, need_arc_map ? &arc_map : nullptr); + FsaVec graph = CtcGraphs(symbols, standard, + need_arc_map ? &arc_map : nullptr); torch::optional tensor; if (need_arc_map) tensor = ToTorch(arc_map); return std::make_pair(graph, tensor); }, - py::arg("labels"), py::arg("need_arc_map") = true, py::arg("gpu_id")); + py::arg("symbols"), py::arg("gpu_id"), py::arg("standard") = true, + py::arg("need_arc_map") = true); } } // namespace k2 diff --git a/k2/python/k2/fsa_algo.py b/k2/python/k2/fsa_algo.py index 85d96eb0e..94a330bb8 100644 --- a/k2/python/k2/fsa_algo.py +++ b/k2/python/k2/fsa_algo.py @@ -965,6 +965,7 @@ def expand_ragged_attributes( def ctc_graph(symbols: Union[List[List[int]], k2.RaggedInt], + standard: bool = True, device: Optional[Union[torch.device, str]] = None) -> Fsa: '''Construct ctc graphs from symbols. @@ -977,6 +978,10 @@ def ctc_graph(symbols: Union[List[List[int]], k2.RaggedInt], - A list of list-of-integers, e..g, `[ [1, 2], [1, 2, 3] ]` - An instance of :class:`k2.RaggedInt`. Must have `num_axes() == 2`. + standard: + Option to specify the type of CTC topology: "standard" or "simplified", + where the "standard" one makes the blank mandatory between a pair of + identical symbols. Default True. device: Optional. It can be either a string (e.g., 'cpu', 'cuda:0') or a torch.device. @@ -1009,7 +1014,8 @@ def ctc_graph(symbols: Union[List[List[int]], k2.RaggedInt], device=device) need_arc_map = True - ragged_arc, arc_map = _k2.ctc_graph(symbols, need_arc_map, gpu_id) + ragged_arc, arc_map = _k2.ctc_graph(symbols, gpu_id, + standard, need_arc_map) aux_labels = k2.index(symbol_values, arc_map) fsa = Fsa(ragged_arc, aux_labels=aux_labels) return fsa diff --git a/k2/python/tests/ctc_graph_test.py b/k2/python/tests/ctc_graph_test.py index 086700473..e6b910762 100644 --- a/k2/python/tests/ctc_graph_test.py +++ b/k2/python/tests/ctc_graph_test.py @@ -45,7 +45,7 @@ def test(self): ragged_int = k2.RaggedInt(s).to(device) fsa_vec_ragged = k2.ctc_graph(ragged_int) - fsa_vec = k2.ctc_graph([[1, 2, 2], [1, 2, 3]], device) + fsa_vec = k2.ctc_graph([[1, 2, 2], [1, 2, 3]], True, device) expected_str0 = '\n'.join(['0 0 0 0 0', '0 1 1 1 0', '1 2 0 0 0', '1 1 1 0 0', '1 3 2 2 0', '2 2 0 0 0', '2 3 2 2 0', '3 4 0 0 0', '3 3 2 0 0', @@ -67,6 +67,36 @@ def test(self): assert actual_str_ragged0.strip() == expected_str0 assert actual_str_ragged1.strip() == expected_str1 + def test_simplified(self): + for device in self.devices: + s = ''' + [ [1 2 2] [1 2 3] ] + ''' + ragged_int = k2.RaggedInt(s).to(device) + fsa_vec_ragged = k2.ctc_graph(ragged_int, False) + + fsa_vec = k2.ctc_graph([[1, 2, 2], [1, 2, 3]], False, device) + expected_str0 = '\n'.join(['0 0 0 0 0', '0 1 1 1 0', '1 2 0 0 0', + '1 1 1 0 0', '1 3 2 2 0', '2 2 0 0 0', + '2 3 2 2 0', '3 4 0 0 0', '3 3 2 0 0', + '3 5 2 2 0', '4 4 0 0 0', '4 5 2 2 0', + '5 6 0 0 0', '5 5 2 0 0', '5 7 -1 0 0', + '6 6 0 0 0', '6 7 -1 0 0', '7']) + expected_str1 = '\n'.join(['0 0 0 0 0', '0 1 1 1 0', '1 2 0 0 0', + '1 1 1 0 0', '1 3 2 2 0', '2 2 0 0 0', + '2 3 2 2 0', '3 4 0 0 0', '3 3 2 0 0', + '3 5 3 3 0', '4 4 0 0 0', '4 5 3 3 0', + '5 6 0 0 0', '5 5 3 0 0', '5 7 -1 0 0', + '6 6 0 0 0', '6 7 -1 0 0', '7']) + actual_str_ragged0 = k2.to_str_simple(fsa_vec_ragged[0].to('cpu')) + actual_str_ragged1 = k2.to_str_simple(fsa_vec_ragged[1].to('cpu')) + actual_str0 = k2.to_str_simple(fsa_vec[0].to('cpu')) + actual_str1 = k2.to_str_simple(fsa_vec[1].to('cpu')) + assert actual_str0.strip() == expected_str0 + assert actual_str1.strip() == expected_str1 + assert actual_str_ragged0.strip() == expected_str0 + assert actual_str_ragged1.strip() == expected_str1 + if __name__ == '__main__': unittest.main() From a83ae90383000bf280950d93bd50bd50af2d6354 Mon Sep 17 00:00:00 2001 From: pkufool Date: Fri, 9 Jul 2021 06:24:22 +0800 Subject: [PATCH 6/7] Apply suggestions from code review Co-authored-by: Fangjun Kuang --- k2/csrc/fsa_algo.cu | 6 +++--- k2/python/k2/fsa_algo.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/k2/csrc/fsa_algo.cu b/k2/csrc/fsa_algo.cu index aaa46f2d8..f0cdc9a3c 100644 --- a/k2/csrc/fsa_algo.cu +++ b/k2/csrc/fsa_algo.cu @@ -478,7 +478,7 @@ FsaVec CtcGraphs(const Ragged &symbols, bool standard /*= true*/, if (remainder) { // symbol state int32_t sym_final_state = symbol_row_split1_data[fsa_idx0 + 1]; - // There is no arcs for final states + // There are no arcs for final states if (sym_state_idx01 == sym_final_state) { current_num_arcs = 0; } else { @@ -553,7 +553,7 @@ FsaVec CtcGraphs(const Ragged &symbols, bool standard /*= true*/, arc.dest_state = arc_idx2 == 0 ? state_idx1 + 1 : state_idx1; } else { switch (arc_idx2) { - case 0: // the arc points to blank state + case 0: // the arc pointing to blank state arc.label = 0; arc.dest_state = state_idx1 + 1; break; @@ -561,7 +561,7 @@ FsaVec CtcGraphs(const Ragged &symbols, bool standard /*= true*/, arc.label = current_symbol; arc.dest_state = state_idx1; break; - case 2: // the arc points to next symbol state + case 2: // the arc pointing to the next symbol state arc.label = next_symbol; arc_map_value = sym_state_idx01 + 1 == sym_final_state ? -1 : sym_state_idx01 + 1; diff --git a/k2/python/k2/fsa_algo.py b/k2/python/k2/fsa_algo.py index 94a330bb8..6258556c7 100644 --- a/k2/python/k2/fsa_algo.py +++ b/k2/python/k2/fsa_algo.py @@ -990,7 +990,7 @@ def ctc_graph(symbols: Union[List[List[int]], k2.RaggedInt], FSA will on the same device as `k2.RaggedInt`. Returns: - An FsaVec contains the returned ctc graphs, with `Dim0()` the same as + An FsaVec containing the returned ctc graphs, with `Dim0()` the same as `len(symbols)`(List[List[int]]) or `Dim0()`(k2.RaggedInt) ''' if device is not None: From a1a7902d9924c949d4f39698e187fa1173074cf7 Mon Sep 17 00:00:00 2001 From: pkufool Date: Fri, 9 Jul 2021 06:35:35 +0800 Subject: [PATCH 7/7] apply suggestions from code review --- k2/csrc/fsa_algo.cu | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/k2/csrc/fsa_algo.cu b/k2/csrc/fsa_algo.cu index f0cdc9a3c..8e08eef43 100644 --- a/k2/csrc/fsa_algo.cu +++ b/k2/csrc/fsa_algo.cu @@ -481,6 +481,8 @@ FsaVec CtcGraphs(const Ragged &symbols, bool standard /*= true*/, // There are no arcs for final states if (sym_state_idx01 == sym_final_state) { current_num_arcs = 0; + } else if (!standard) { + current_num_arcs = 3; } else { int32_t current_symbol = symbol_data[sym_state_idx01], // we set the next symbol of the last symbol to -1, so @@ -496,7 +498,7 @@ FsaVec CtcGraphs(const Ragged &symbols, bool standard /*= true*/, // state, so we need three arcs here. // Note: for the simpilfied topology (standard equals false), there // are always 3 arcs leaving symbol states. - if (current_symbol != next_symbol || !standard) + if (current_symbol != next_symbol) current_num_arcs = 3; } }