Skip to content

Commit

Permalink
Build ctc graph from symbols in batch mode (#776)
Browse files Browse the repository at this point in the history
* contructing ctc graphs from symbols

* add more documents and add python unit test

* add rangged tensor unit test, fix code style

* add symbol checking & handle the final symbol

* add standard option to choose ctc topolopy

* Apply suggestions from code review

Co-authored-by: Fangjun Kuang <[email protected]>

* apply suggestions from code review

Co-authored-by: pkufool <[email protected]>
Co-authored-by: Fangjun Kuang <[email protected]>
  • Loading branch information
3 people authored Jul 8, 2021
1 parent d0bfa7e commit fd59f07
Show file tree
Hide file tree
Showing 8 changed files with 424 additions and 3 deletions.
155 changes: 155 additions & 0 deletions k2/csrc/fsa_algo.cu
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,161 @@ FsaVec LinearFsas(const Ragged<int32_t> &symbols) {
arcs);
}


FsaVec CtcGraphs(const Ragged<int32_t> &symbols, bool standard /*= true*/,
Array1<int32_t> *arc_map /*= nullptr*/) {
NVTX_RANGE(K2_FUNC);
K2_CHECK_EQ(symbols.NumAxes(), 2);
ContextPtr &c = symbols.Context();

int32_t num_fsas = symbols.Dim0();
Array1<int32_t> num_states_for(c, num_fsas + 1);
int32_t *num_states_for_data = num_states_for.Data();
const int32_t *symbol_row_split1_data = symbols.RowSplits(1).Data();
// symbols indexed with [fsa][symbol]
// for each fsa we need `symbol_num * 2 + 1 + 1` states, `symbol_num * 2 + 1`
// means that we need a blank state on each side of a symbol state, `+ 1` is
// for final state in k2
K2_EVAL(
c, num_fsas, lambda_set_num_states, (int32_t fsa_idx0)->void {
int32_t symbol_idx0x = symbol_row_split1_data[fsa_idx0],
symbol_idx0x_next = symbol_row_split1_data[fsa_idx0 + 1],
symbol_num = symbol_idx0x_next - symbol_idx0x;
num_states_for_data[fsa_idx0] = symbol_num * 2 + 2;
});

ExclusiveSum(num_states_for, &num_states_for);
Array1<int32_t> &fsa_to_states_row_splits = num_states_for;
RaggedShape fsa_to_states =
RaggedShape2(&fsa_to_states_row_splits, nullptr, -1);

int32_t num_states = fsa_to_states.NumElements();
Array1<int32_t> num_arcs_for(c, num_states + 1);
int32_t *num_arcs_for_data = num_arcs_for.Data();
const int32_t *fts_row_splits1_data = fsa_to_states.RowSplits(1).Data(),
*fts_row_ids1_data = fsa_to_states.RowIds(1).Data(),
*symbol_data = symbols.values.Data();
// set the arcs number for each state
K2_EVAL(
c, num_states, lambda_set_num_arcs, (int32_t state_idx01)->void {
int32_t fsa_idx0 = fts_row_ids1_data[state_idx01],
// we minus fsa_idx0 here, because we are adding one more state,
// the final state for each fsa
sym_state_idx01 = state_idx01 / 2 - fsa_idx0,
remainder = state_idx01 % 2,
current_num_arcs = 2; // normally there are two arcs, self-loop
// and arc pointing to the next state
// blank state always has two arcs
if (remainder) { // symbol state
int32_t sym_final_state =
symbol_row_split1_data[fsa_idx0 + 1];
// There are no arcs for final states
if (sym_state_idx01 == sym_final_state) {
current_num_arcs = 0;
} else if (!standard) {
current_num_arcs = 3;
} else {
int32_t current_symbol = symbol_data[sym_state_idx01],
// we set the next symbol of the last symbol to -1, so
// the following if clause will always be true, which means
// we will have 3 arcs for last symbol state
next_symbol = (sym_state_idx01 + 1) == sym_final_state ?
-1 : symbol_data[sym_state_idx01 + 1];
// symbols must be not equal to -1, which is specially used in k2
K2_CHECK_NE(current_symbol, -1);
// if current_symbol equals next_symbol, we need a blank state
// between them, so there are two arcs for this state
// otherwise, this state will point to blank state and next symbol
// state, so we need three arcs here.
// Note: for the simpilfied topology (standard equals false), there
// are always 3 arcs leaving symbol states.
if (current_symbol != next_symbol)
current_num_arcs = 3;
}
}
num_arcs_for_data[state_idx01] = current_num_arcs;
});

ExclusiveSum(num_arcs_for, &num_arcs_for);
Array1<int32_t> &states_to_arcs_row_splits = num_arcs_for;
RaggedShape states_to_arcs =
RaggedShape2(&states_to_arcs_row_splits, nullptr, -1);

// ctc_shape with a index of [fsa][state][arc]
RaggedShape ctc_shape = ComposeRaggedShapes(fsa_to_states, states_to_arcs);
int32_t num_arcs = ctc_shape.NumElements();
Array1<Arc> arcs(c, num_arcs);
Arc *arcs_data = arcs.Data();
const int32_t *ctc_row_splits1_data = ctc_shape.RowSplits(1).Data(),
*ctc_row_ids1_data = ctc_shape.RowIds(1).Data(),
*ctc_row_splits2_data = ctc_shape.RowSplits(2).Data(),
*ctc_row_ids2_data = ctc_shape.RowIds(2).Data();
int32_t *arc_map_data = nullptr;
if (arc_map != nullptr) {
*arc_map = Array1<int32_t>(c, num_arcs);
arc_map_data = arc_map->Data();
}

K2_EVAL(
c, num_arcs, lambda_set_arcs, (int32_t arc_idx012)->void {
int32_t state_idx01 = ctc_row_ids2_data[arc_idx012],
fsa_idx0 = ctc_row_ids1_data[state_idx01],
state_idx0x = ctc_row_splits1_data[fsa_idx0],
state_idx1 = state_idx01 - state_idx0x,
arc_idx01x = ctc_row_splits2_data[state_idx01],
arc_idx2 = arc_idx012 - arc_idx01x,
sym_state_idx01 = state_idx01 / 2 - fsa_idx0,
remainder = state_idx01 % 2,
sym_final_state = symbol_row_split1_data[fsa_idx0 + 1];
bool final_state = sym_final_state == sym_state_idx01;
int32_t current_symbol = final_state ?
-1 : symbol_data[sym_state_idx01];
Arc arc;
arc.score = 0;
arc.src_state = state_idx1;
int32_t arc_map_value = -1;
if (remainder) {
if (final_state) return;
int32_t next_symbol = (sym_state_idx01 + 1) == sym_final_state ?
-1 : symbol_data[sym_state_idx01 + 1];
// for standard topology, the symbol state can not point to next
// symbol state if the next symbol is identical to current symbol.
if (current_symbol == next_symbol && standard) {
K2_CHECK_LT(arc_idx2, 2);
arc.label = arc_idx2 == 0 ? 0 : current_symbol;
arc.dest_state = arc_idx2 == 0 ? state_idx1 + 1 : state_idx1;
} else {
switch (arc_idx2) {
case 0: // the arc pointing to blank state
arc.label = 0;
arc.dest_state = state_idx1 + 1;
break;
case 1: // the self loop arc
arc.label = current_symbol;
arc.dest_state = state_idx1;
break;
case 2: // the arc pointing to the next symbol state
arc.label = next_symbol;
arc_map_value = sym_state_idx01 + 1 == sym_final_state ?
-1 : sym_state_idx01 + 1;
arc.dest_state = state_idx1 + 2;
break;
default:
K2_LOG(FATAL) << "Arc index must be less than 3";
}
}
} else {
K2_CHECK_LT(arc_idx2, 2);
arc.label = arc_idx2 == 0 ? 0 : current_symbol;
arc.dest_state = arc_idx2 == 0 ? state_idx1 : state_idx1 + 1;
arc_map_value = (arc_idx2 == 0 || final_state) ? -1 : sym_state_idx01;
}
arcs_data[arc_idx012] = arc;
if (arc_map) arc_map_data[arc_idx012] = arc_map_value;
});
return Ragged<Arc>(ctc_shape, arcs);
}

void ArcSort(Fsa *fsa) {
if (fsa->NumAxes() < 2) return; // it is empty
SortSublists<Arc>(fsa);
Expand Down
21 changes: 18 additions & 3 deletions k2/csrc/fsa_algo.h
Original file line number Diff line number Diff line change
Expand Up @@ -424,9 +424,6 @@ void RemoveEpsilonAndAddSelfLoops(FsaOrVec &src, int32_t properties,
FsaOrVec *dest,
Ragged<int32_t> *arc_derivs = nullptr);




/*
Determinize the input Fsas, it works for both Fsa and FsaVec.
@param [in] src Source Fsa or FsaVec.
Expand Down Expand Up @@ -478,6 +475,24 @@ Fsa LinearFsa(const Array1<int32_t> &symbols);
*/
FsaVec LinearFsas(const Ragged<int32_t> &symbols);

/*
Create an FsaVec containing ctc graph FSAs, given a list of sequences of
symbols
@param [in] symbols Input symbol sequences (must not contain
kFinalSymbol == -1). Its num_axes is 2.
@param [in] standard Option to specify the type of CTC topology: "standard"
or "simplified", where the "standard" one makes the
blank mandatory between a pair of identical symbols.
@param [out] It maps the arcs of output fsa to the symbols(idx01), the
olabel of the `arc[i]` would be `symbols[arc_map[i]]`,
and -1 for epsilon olabel.
@return Returns an FsaVec with `ans.Dim0() == symbols.Dim0()`.
*/
FsaVec CtcGraphs(const Ragged<int32_t> &symbols, bool standard = true,
Array1<int32_t> *arc_map = nullptr);

/* Compute the forward shortest path in the tropical semiring.
@param [in] fsas Input FsaVec (must have 3 axes). Must be
Expand Down
44 changes: 44 additions & 0 deletions k2/csrc/fsa_algo_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1084,4 +1084,48 @@ TEST(FsaAlgo, TestRemoveEplsionSelfRandomFsas) {
}
}

TEST(FsaAlgo, TestCtcGraph) {
for (const ContextPtr &c : {GetCpuContext(), GetCudaContext()}) {
Ragged<int32_t> symbols(c, "[ [ 1 2 2 3 ] [ 1 2 3 ] ]");
Array1<int32_t> arc_map;
FsaVec graph = CtcGraphs(symbols, true, &arc_map);
FsaVec graph_ref(c, "[ [ [ 0 0 0 0 0 1 1 0 ] [ 1 2 0 0 1 1 1 0 1 3 2 0 ] "
" [ 2 2 0 0 2 3 2 0 ] [ 3 4 0 0 3 3 2 0 ] "
" [ 4 4 0 0 4 5 2 0 ] [ 5 6 0 0 5 5 2 0 5 7 3 0 ] "
" [ 6 6 0 0 6 7 3 0 ] [ 7 8 0 0 7 7 3 0 7 9 -1 0 ] "
" [ 8 8 0 0 8 9 -1 0 ] [ ] ] "
" [ [ 0 0 0 0 0 1 1 0 ] [ 1 2 0 0 1 1 1 0 1 3 2 0 ] "
" [ 2 2 0 0 2 3 2 0 ] [ 3 4 0 0 3 3 2 0 3 5 3 0 ] "
" [ 4 4 0 0 4 5 3 0 ] [ 5 6 0 0 5 5 3 0 5 7 -1 0 ] "
" [ 6 6 0 0 6 7 -1 0 ] [ ] ] ]");
Array1<int32_t> arc_map_ref(c, "[ -1 0 -1 -1 1 -1 1 -1 -1 -1 2 -1 -1 3 "
" -1 3 -1 -1 -1 -1 -1 -1 4 -1 -1 5 -1 5 "
" -1 -1 6 -1 6 -1 -1 -1 -1 -1 ]");
K2_CHECK(Equal(graph, graph_ref));
K2_CHECK(Equal(arc_map, arc_map_ref));
}
}

TEST(FsaAlgo, TestCtcGraphSimplified) {
for (const ContextPtr &c : {GetCpuContext(), GetCudaContext()}) {
Ragged<int32_t> symbols(c, "[ [ 1 2 2 3 ] [ 1 2 3 ] ]");
Array1<int32_t> arc_map;
FsaVec graph = CtcGraphs(symbols, false, &arc_map);
FsaVec graph_ref(c, "[ [ [ 0 0 0 0 0 1 1 0 ] [ 1 2 0 0 1 1 1 0 1 3 2 0 ] "
" [ 2 2 0 0 2 3 2 0 ] [ 3 4 0 0 3 3 2 0 3 5 2 0] "
" [ 4 4 0 0 4 5 2 0 ] [ 5 6 0 0 5 5 2 0 5 7 3 0 ] "
" [ 6 6 0 0 6 7 3 0 ] [ 7 8 0 0 7 7 3 0 7 9 -1 0 ] "
" [ 8 8 0 0 8 9 -1 0 ] [ ] ] "
" [ [ 0 0 0 0 0 1 1 0 ] [ 1 2 0 0 1 1 1 0 1 3 2 0 ] "
" [ 2 2 0 0 2 3 2 0 ] [ 3 4 0 0 3 3 2 0 3 5 3 0 ] "
" [ 4 4 0 0 4 5 3 0 ] [ 5 6 0 0 5 5 3 0 5 7 -1 0 ] "
" [ 6 6 0 0 6 7 -1 0 ] [ ] ] ]");
Array1<int32_t> arc_map_ref(c, "[ -1 0 -1 -1 1 -1 1 -1 -1 2 -1 2 -1 "
" -1 3 -1 3 -1 -1 -1 -1 -1 -1 4 -1 -1 5 "
" -1 5 -1 -1 6 -1 6 -1 -1 -1 -1 -1 ]");
K2_CHECK(Equal(graph, graph_ref));
K2_CHECK(Equal(arc_map, arc_map_ref));
}
}

} // namespace k2
46 changes: 46 additions & 0 deletions k2/python/csrc/torch/fsa_algo.cu
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,51 @@ static void PybindFixFinalLabels(py::module &m) {
)");
}

static void PybindCtcGraph(py::module &m) {
m.def(
"ctc_graph",
[](const std::vector<std::vector<int32_t>> &symbols,
int32_t gpu_id = -1, bool standard = true, bool need_arc_map = true)
-> std::pair<FsaVec, torch::optional<torch::Tensor>> {
ContextPtr context;
if (gpu_id < 0)
context = GetCpuContext();
else
context = GetCudaContext(gpu_id);

DeviceGuard guard(context);
Ragged<int32_t> ragged = CreateRagged2<int32_t>(symbols).To(context);
Array1<int32_t> arc_map;
FsaVec graph = CtcGraphs(ragged, standard,
need_arc_map ? &arc_map : nullptr);
torch::optional<torch::Tensor> tensor;
if (need_arc_map) tensor = ToTorch(arc_map);
return std::make_pair(graph, tensor);
},
py::arg("symbols"), py::arg("gpu_id") = -1, py::arg("standard") = true,
py::arg("need_arc_map") = true,
R"(
If gpu_id is -1, the returned FsaVec is on CPU.
If gpu_id >= 0, the returned FsaVec is on the specified GPU.
)");

m.def(
"ctc_graph",
[](const Ragged<int32_t> &symbols, int32_t gpu_id, /*unused_gpu_id*/
bool standard = true, bool need_arc_map = true)
-> std::pair<FsaVec, torch::optional<torch::Tensor>> {
DeviceGuard guard(symbols.Context());
Array1<int32_t> arc_map;
FsaVec graph = CtcGraphs(symbols, standard,
need_arc_map ? &arc_map : nullptr);
torch::optional<torch::Tensor> tensor;
if (need_arc_map) tensor = ToTorch(arc_map);
return std::make_pair(graph, tensor);
},
py::arg("symbols"), py::arg("gpu_id"), py::arg("standard") = true,
py::arg("need_arc_map") = true);
}

} // namespace k2

void PybindFsaAlgo(py::module &m) {
Expand All @@ -682,4 +727,5 @@ void PybindFsaAlgo(py::module &m) {
k2::PybindRemoveEpsilonSelfLoops(m);
k2::PybindExpandArcs(m);
k2::PybindFixFinalLabels(m);
k2::PybindCtcGraph(m);
}
1 change: 1 addition & 0 deletions k2/python/k2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .fsa_algo import closure
from .fsa_algo import compose
from .fsa_algo import connect
from .fsa_algo import ctc_graph
from .fsa_algo import determinize
from .fsa_algo import expand_ragged_attributes
from .fsa_algo import intersect
Expand Down
57 changes: 57 additions & 0 deletions k2/python/k2/fsa_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,3 +962,60 @@ def expand_ragged_attributes(
return dest, arc_map
else:
return dest


def ctc_graph(symbols: Union[List[List[int]], k2.RaggedInt],
standard: bool = True,
device: Optional[Union[torch.device, str]] = None) -> Fsa:
'''Construct ctc graphs from symbols.
Note:
The scores of arcs in the returned FSA are all 0.
Args:
symbols:
It can be one of the following types:
- A list of list-of-integers, e..g, `[ [1, 2], [1, 2, 3] ]`
- An instance of :class:`k2.RaggedInt`. Must have `num_axes() == 2`.
standard:
Option to specify the type of CTC topology: "standard" or "simplified",
where the "standard" one makes the blank mandatory between a pair of
identical symbols. Default True.
device:
Optional. It can be either a string (e.g., 'cpu',
'cuda:0') or a torch.device.
If it is None, then the returned FSA is on CPU. It has to be None
if `symbols` is an instance of :class:`k2.RaggedInt`, the returned
FSA will on the same device as `k2.RaggedInt`.
Returns:
An FsaVec containing the returned ctc graphs, with `Dim0()` the same as
`len(symbols)`(List[List[int]]) or `Dim0()`(k2.RaggedInt)
'''
if device is not None:
device = torch.device(device)
if device.type == 'cpu':
gpu_id = -1
else:
assert device.type == 'cuda'
gpu_id = getattr(device, 'index', 0)
else:
gpu_id = -1

symbol_values = None
if isinstance(symbols, k2.RaggedInt):
assert device is None
assert symbols.num_axes() == 2
symbol_values = symbols.values()
else:
symbol_values = torch.tensor(
[it for symbol in symbols for it in symbol], dtype=torch.int32,
device=device)

need_arc_map = True
ragged_arc, arc_map = _k2.ctc_graph(symbols, gpu_id,
standard, need_arc_map)
aux_labels = k2.index(symbol_values, arc_map)
fsa = Fsa(ragged_arc, aux_labels=aux_labels)
return fsa
1 change: 1 addition & 0 deletions k2/python/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ set(py_test_files
connect_test.py
create_sparse_test.py
ctc_gradients_test.py
ctc_graph_test.py
dense_fsa_vec_test.py
determinize_test.py
expand_ragged_attributes_test.py
Expand Down
Loading

0 comments on commit fd59f07

Please sign in to comment.