Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Build ctc graph from symbols in batch mode #776

Merged
merged 7 commits into from
Jul 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions k2/csrc/fsa_algo.cu
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,161 @@ FsaVec LinearFsas(const Ragged<int32_t> &symbols) {
arcs);
}


/*
  Implementation of CtcGraphs(); the public contract is documented at the
  declaration in fsa_algo.h.

  State layout of each output FSA, for a sequence of n symbols (2n + 2 states):
    - state 2i     (0 <= i <= n) is a blank state;
    - state 2i + 1 (0 <= i <  n) is the state of the i-th symbol;
    - state 2n + 1 is the final state, which has no leaving arcs.

  The [fsa][state][arc] shape is built with two counting passes (states per
  FSA, then arcs per state); a third pass fills in the arcs and, if requested,
  `arc_map`.
 */
FsaVec CtcGraphs(const Ragged<int32_t> &symbols, bool standard /*= true*/,
                 Array1<int32_t> *arc_map /*= nullptr*/) {
  NVTX_RANGE(K2_FUNC);
  K2_CHECK_EQ(symbols.NumAxes(), 2);
  ContextPtr &c = symbols.Context();

  int32_t num_fsas = symbols.Dim0();
  Array1<int32_t> num_states_for(c, num_fsas + 1);
  int32_t *num_states_for_data = num_states_for.Data();
  const int32_t *symbol_row_split1_data = symbols.RowSplits(1).Data();
  // symbols indexed with [fsa][symbol]
  // for each fsa we need `symbol_num * 2 + 1 + 1` states, `symbol_num * 2 + 1`
  // means that we need a blank state on each side of a symbol state, `+ 1` is
  // for final state in k2
  K2_EVAL(
      c, num_fsas, lambda_set_num_states, (int32_t fsa_idx0)->void {
        int32_t symbol_idx0x = symbol_row_split1_data[fsa_idx0],
                symbol_idx0x_next = symbol_row_split1_data[fsa_idx0 + 1],
                symbol_num = symbol_idx0x_next - symbol_idx0x;
        num_states_for_data[fsa_idx0] = symbol_num * 2 + 2;
      });

  ExclusiveSum(num_states_for, &num_states_for);
  // After the in-place ExclusiveSum the array no longer holds per-FSA counts
  // but row_splits; the alias only exists to make that explicit where the
  // ragged shape is constructed.
  Array1<int32_t> &fsa_to_states_row_splits = num_states_for;
  RaggedShape fsa_to_states =
      RaggedShape2(&fsa_to_states_row_splits, nullptr, -1);

  int32_t num_states = fsa_to_states.NumElements();
  Array1<int32_t> num_arcs_for(c, num_states + 1);
  int32_t *num_arcs_for_data = num_arcs_for.Data();
  const int32_t *fts_row_ids1_data = fsa_to_states.RowIds(1).Data(),
                *symbol_data = symbols.values.Data();
  // set the arcs number for each state
  K2_EVAL(
      c, num_states, lambda_set_num_arcs, (int32_t state_idx01)->void {
        int32_t fsa_idx0 = fts_row_ids1_data[state_idx01],
                // every FSA has an even number of states (2n + 2), i.e. one
                // state pair more than it has symbols, so subtracting
                // fsa_idx0 turns `state_idx01 / 2` into an idx01 into
                // `symbols`
                sym_state_idx01 = state_idx01 / 2 - fsa_idx0,
                remainder = state_idx01 % 2,
                // blank states always have two arcs: the self-loop and the
                // arc to the next state
                current_num_arcs = 2;
        if (remainder) {  // symbol state or final state
          int32_t sym_final_state = symbol_row_split1_data[fsa_idx0 + 1];
          if (sym_state_idx01 == sym_final_state) {
            // There are no arcs for final states
            current_num_arcs = 0;
          } else {
            int32_t current_symbol = symbol_data[sym_state_idx01];
            // symbols must be not equal to -1, which is specially used in
            // k2.  This kernel visits every symbol exactly once, so the check
            // is done here, and it is done before branching on `standard` so
            // that both topologies are validated.
            K2_CHECK_NE(current_symbol, -1);
            if (!standard) {
              // for the simplified topology there are always 3 arcs leaving
              // a symbol state, no need to look at the next symbol
              current_num_arcs = 3;
            } else {
              // we set the next symbol of the last symbol to -1, so the
              // comparison below is always true for it, which means we will
              // have 3 arcs for the last symbol state
              int32_t next_symbol =
                  (sym_state_idx01 + 1) == sym_final_state
                      ? -1
                      : symbol_data[sym_state_idx01 + 1];
              // if current_symbol equals next_symbol, we need a blank state
              // between them, so there are two arcs for this state;
              // otherwise, this state will point to blank state and next
              // symbol state, so we need three arcs here.
              if (current_symbol != next_symbol) current_num_arcs = 3;
            }
          }
        }
        num_arcs_for_data[state_idx01] = current_num_arcs;
      });

  ExclusiveSum(num_arcs_for, &num_arcs_for);
  // Same as above: after ExclusiveSum this array is the row_splits.
  Array1<int32_t> &states_to_arcs_row_splits = num_arcs_for;
  RaggedShape states_to_arcs =
      RaggedShape2(&states_to_arcs_row_splits, nullptr, -1);

  // ctc_shape with a index of [fsa][state][arc]
  RaggedShape ctc_shape = ComposeRaggedShapes(fsa_to_states, states_to_arcs);
  int32_t num_arcs = ctc_shape.NumElements();
  Array1<Arc> arcs(c, num_arcs);
  Arc *arcs_data = arcs.Data();
  const int32_t *ctc_row_splits1_data = ctc_shape.RowSplits(1).Data(),
                *ctc_row_ids1_data = ctc_shape.RowIds(1).Data(),
                *ctc_row_splits2_data = ctc_shape.RowSplits(2).Data(),
                *ctc_row_ids2_data = ctc_shape.RowIds(2).Data();
  // Stays nullptr when the caller did not ask for an arc map.
  int32_t *arc_map_data = nullptr;
  if (arc_map != nullptr) {
    *arc_map = Array1<int32_t>(c, num_arcs);
    arc_map_data = arc_map->Data();
  }

  K2_EVAL(
      c, num_arcs, lambda_set_arcs, (int32_t arc_idx012)->void {
        int32_t state_idx01 = ctc_row_ids2_data[arc_idx012],
                fsa_idx0 = ctc_row_ids1_data[state_idx01],
                state_idx0x = ctc_row_splits1_data[fsa_idx0],
                state_idx1 = state_idx01 - state_idx0x,
                arc_idx01x = ctc_row_splits2_data[state_idx01],
                arc_idx2 = arc_idx012 - arc_idx01x,
                sym_state_idx01 = state_idx01 / 2 - fsa_idx0,
                remainder = state_idx01 % 2,
                sym_final_state = symbol_row_split1_data[fsa_idx0 + 1];
        // true for the last blank state and for the final state of an FSA
        bool final_state = sym_final_state == sym_state_idx01;
        int32_t current_symbol = final_state ?
            -1 : symbol_data[sym_state_idx01];
        Arc arc;
        arc.score = 0;
        arc.src_state = state_idx1;
        // -1 means "this arc does not correspond to any input symbol"
        int32_t arc_map_value = -1;
        if (remainder) {  // symbol state
          if (final_state) return;
          int32_t next_symbol = (sym_state_idx01 + 1) == sym_final_state ?
              -1 : symbol_data[sym_state_idx01 + 1];
          // for standard topology, the symbol state can not point to next
          // symbol state if the next symbol is identical to current symbol.
          if (current_symbol == next_symbol && standard) {
            K2_CHECK_LT(arc_idx2, 2);
            arc.label = arc_idx2 == 0 ? 0 : current_symbol;
            arc.dest_state = arc_idx2 == 0 ? state_idx1 + 1 : state_idx1;
          } else {
            switch (arc_idx2) {
              case 0:  // the arc pointing to blank state
                arc.label = 0;
                arc.dest_state = state_idx1 + 1;
                break;
              case 1:  // the self loop arc
                arc.label = current_symbol;
                arc.dest_state = state_idx1;
                break;
              case 2:  // the arc pointing to the next symbol state
                arc.label = next_symbol;
                arc_map_value = sym_state_idx01 + 1 == sym_final_state ?
                    -1 : sym_state_idx01 + 1;
                arc.dest_state = state_idx1 + 2;
                break;
              default:
                K2_LOG(FATAL) << "Arc index must be less than 3";
            }
          }
        } else {  // blank state
          K2_CHECK_LT(arc_idx2, 2);
          // arc 0 is the blank self-loop, arc 1 enters the following symbol
          // state (or the final state, with label -1, for the last blank)
          arc.label = arc_idx2 == 0 ? 0 : current_symbol;
          arc.dest_state = arc_idx2 == 0 ? state_idx1 : state_idx1 + 1;
          arc_map_value = (arc_idx2 == 0 || final_state) ? -1 : sym_state_idx01;
        }
        arcs_data[arc_idx012] = arc;
        if (arc_map_data != nullptr) arc_map_data[arc_idx012] = arc_map_value;
      });
  return Ragged<Arc>(ctc_shape, arcs);
}

void ArcSort(Fsa *fsa) {
if (fsa->NumAxes() < 2) return; // it is empty
SortSublists<Arc>(fsa);
Expand Down
21 changes: 18 additions & 3 deletions k2/csrc/fsa_algo.h
Original file line number Diff line number Diff line change
Expand Up @@ -424,9 +424,6 @@ void RemoveEpsilonAndAddSelfLoops(FsaOrVec &src, int32_t properties,
FsaOrVec *dest,
Ragged<int32_t> *arc_derivs = nullptr);




/*
Determinize the input Fsas, it works for both Fsa and FsaVec.
@param [in] src Source Fsa or FsaVec.
Expand Down Expand Up @@ -478,6 +475,24 @@ Fsa LinearFsa(const Array1<int32_t> &symbols);
*/
FsaVec LinearFsas(const Ragged<int32_t> &symbols);

/*
  Create an FsaVec containing CTC graph FSAs, given a list of sequences of
  symbols. One FSA is created per input sequence.

    @param [in] symbols   Input symbol sequences (must not contain
                          kFinalSymbol == -1). Its num_axes is 2.
                          NOTE(review): during code review the author stated
                          that a check against -1 lives in the kernel that
                          counts the arcs of each state; confirm that it is
                          reached for both topologies.
    @param [in] standard  Option to specify the type of CTC topology: "standard"
                          or "simplified", where the "standard" one makes the
                          blank mandatory between a pair of identical symbols.
    @param [out] arc_map  If not nullptr, it is assigned a new array with one
                          entry per arc of the returned FsaVec. It maps the
                          arcs of the output fsa to the symbols (idx01), the
                          olabel of the `arc[i]` would be
                          `symbols[arc_map[i]]`, and -1 for epsilon olabel.

    @return  Returns an FsaVec with `ans.Dim0() == symbols.Dim0()`.
 */
FsaVec CtcGraphs(const Ragged<int32_t> &symbols, bool standard = true,
Array1<int32_t> *arc_map = nullptr);

/* Compute the forward shortest path in the tropical semiring.

@param [in] fsas Input FsaVec (must have 3 axes). Must be
Expand Down
44 changes: 44 additions & 0 deletions k2/csrc/fsa_algo_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1053,4 +1053,48 @@ TEST(FsaAlgo, TestRemoveEplsionSelfRandomFsas) {
}
}

// Standard CTC topology: builds the graphs of two symbol sequences on every
// available context and compares both the arcs and the arc map with
// hand-written references.  In the first sequence the repeated symbol `2`
// means state 3 has only two leaving arcs (to the blank and the self-loop).
TEST(FsaAlgo, TestCtcGraph) {
  const ContextPtr contexts[] = {GetCpuContext(), GetCudaContext()};
  for (const ContextPtr &context : contexts) {
    // Reference arcs, written as `src_state dest_state label score`.
    FsaVec expected_graph(
        context,
        "[ [ [ 0 0 0 0 0 1 1 0 ] [ 1 2 0 0 1 1 1 0 1 3 2 0 ] "
        " [ 2 2 0 0 2 3 2 0 ] [ 3 4 0 0 3 3 2 0 ] "
        " [ 4 4 0 0 4 5 2 0 ] [ 5 6 0 0 5 5 2 0 5 7 3 0 ] "
        " [ 6 6 0 0 6 7 3 0 ] [ 7 8 0 0 7 7 3 0 7 9 -1 0 ] "
        " [ 8 8 0 0 8 9 -1 0 ] [ ] ] "
        " [ [ 0 0 0 0 0 1 1 0 ] [ 1 2 0 0 1 1 1 0 1 3 2 0 ] "
        " [ 2 2 0 0 2 3 2 0 ] [ 3 4 0 0 3 3 2 0 3 5 3 0 ] "
        " [ 4 4 0 0 4 5 3 0 ] [ 5 6 0 0 5 5 3 0 5 7 -1 0 ] "
        " [ 6 6 0 0 6 7 -1 0 ] [ ] ] ]");
    // Reference arc map: one entry per arc, -1 where no symbol is consumed.
    Array1<int32_t> expected_arc_map(
        context,
        "[ -1 0 -1 -1 1 -1 1 -1 -1 -1 2 -1 -1 3 "
        " -1 3 -1 -1 -1 -1 -1 -1 4 -1 -1 5 -1 5 "
        " -1 -1 6 -1 6 -1 -1 -1 -1 -1 ]");

    Ragged<int32_t> symbols(context, "[ [ 1 2 2 3 ] [ 1 2 3 ] ]");
    Array1<int32_t> arc_map;
    FsaVec graph = CtcGraphs(symbols, /*standard=*/true, &arc_map);

    K2_CHECK(Equal(graph, expected_graph));
    K2_CHECK(Equal(arc_map, expected_arc_map));
  }
}

// Simplified CTC topology: same input as the standard test, but here every
// symbol state has three leaving arcs, including state 3 of the first
// sequence (`3 5 2 0`), which sits between the two identical symbols.
TEST(FsaAlgo, TestCtcGraphSimplified) {
  const ContextPtr contexts[] = {GetCpuContext(), GetCudaContext()};
  for (const ContextPtr &context : contexts) {
    // Reference arcs, written as `src_state dest_state label score`.
    FsaVec expected_graph(
        context,
        "[ [ [ 0 0 0 0 0 1 1 0 ] [ 1 2 0 0 1 1 1 0 1 3 2 0 ] "
        " [ 2 2 0 0 2 3 2 0 ] [ 3 4 0 0 3 3 2 0 3 5 2 0] "
        " [ 4 4 0 0 4 5 2 0 ] [ 5 6 0 0 5 5 2 0 5 7 3 0 ] "
        " [ 6 6 0 0 6 7 3 0 ] [ 7 8 0 0 7 7 3 0 7 9 -1 0 ] "
        " [ 8 8 0 0 8 9 -1 0 ] [ ] ] "
        " [ [ 0 0 0 0 0 1 1 0 ] [ 1 2 0 0 1 1 1 0 1 3 2 0 ] "
        " [ 2 2 0 0 2 3 2 0 ] [ 3 4 0 0 3 3 2 0 3 5 3 0 ] "
        " [ 4 4 0 0 4 5 3 0 ] [ 5 6 0 0 5 5 3 0 5 7 -1 0 ] "
        " [ 6 6 0 0 6 7 -1 0 ] [ ] ] ]");
    // Reference arc map: one entry per arc, -1 where no symbol is consumed.
    Array1<int32_t> expected_arc_map(
        context,
        "[ -1 0 -1 -1 1 -1 1 -1 -1 2 -1 2 -1 "
        " -1 3 -1 3 -1 -1 -1 -1 -1 -1 4 -1 -1 5 "
        " -1 5 -1 -1 6 -1 6 -1 -1 -1 -1 -1 ]");

    Ragged<int32_t> symbols(context, "[ [ 1 2 2 3 ] [ 1 2 3 ] ]");
    Array1<int32_t> arc_map;
    FsaVec graph = CtcGraphs(symbols, /*standard=*/false, &arc_map);

    K2_CHECK(Equal(graph, expected_graph));
    K2_CHECK(Equal(arc_map, expected_arc_map));
  }
}

} // namespace k2
46 changes: 46 additions & 0 deletions k2/python/csrc/torch/fsa_algo.cu
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,51 @@ static void PybindFixFinalLabels(py::module &m) {
)");
}

// Registers the two `ctc_graph` overloads on `m`: one taking a list of lists
// of symbols (which is first copied to the context selected by `gpu_id`) and
// one taking an existing Ragged<int32_t> (whose own context is used, so its
// `gpu_id` argument is ignored).  Both return the graph together with an
// optional arc map tensor.
static void PybindCtcGraph(py::module &m) {
  using GraphAndArcMap = std::pair<FsaVec, torch::optional<torch::Tensor>>;

  // Common tail of both overloads: build the graphs and wrap the arc map in a
  // tensor only when the caller asked for it.
  auto build_graph = [](const Ragged<int32_t> &symbols, bool standard,
                        bool need_arc_map) -> GraphAndArcMap {
    Array1<int32_t> arc_map;
    Array1<int32_t> *arc_map_ptr = need_arc_map ? &arc_map : nullptr;
    FsaVec graph = CtcGraphs(symbols, standard, arc_map_ptr);
    torch::optional<torch::Tensor> tensor;
    if (need_arc_map) tensor = ToTorch(arc_map);
    return std::make_pair(graph, tensor);
  };

  // The Python-visible defaults are the `py::arg(...) = ...` values below.
  m.def(
      "ctc_graph",
      [build_graph](const std::vector<std::vector<int32_t>> &symbols,
                    int32_t gpu_id, bool standard,
                    bool need_arc_map) -> GraphAndArcMap {
        ContextPtr context =
            gpu_id < 0 ? GetCpuContext() : GetCudaContext(gpu_id);

        DeviceGuard guard(context);
        Ragged<int32_t> ragged = CreateRagged2<int32_t>(symbols).To(context);
        return build_graph(ragged, standard, need_arc_map);
      },
      py::arg("symbols"), py::arg("gpu_id") = -1, py::arg("standard") = true,
      py::arg("need_arc_map") = true,
      R"(
If gpu_id is -1, the returned FsaVec is on CPU.
If gpu_id >= 0, the returned FsaVec is on the specified GPU.
)");

  m.def(
      "ctc_graph",
      [build_graph](const Ragged<int32_t> &symbols,
                    int32_t /*unused_gpu_id*/, bool standard,
                    bool need_arc_map) -> GraphAndArcMap {
        DeviceGuard guard(symbols.Context());
        return build_graph(symbols, standard, need_arc_map);
      },
      py::arg("symbols"), py::arg("gpu_id"), py::arg("standard") = true,
      py::arg("need_arc_map") = true);
}

} // namespace k2

void PybindFsaAlgo(py::module &m) {
Expand All @@ -682,4 +727,5 @@ void PybindFsaAlgo(py::module &m) {
k2::PybindRemoveEpsilonSelfLoops(m);
k2::PybindExpandArcs(m);
k2::PybindFixFinalLabels(m);
k2::PybindCtcGraph(m);
}
1 change: 1 addition & 0 deletions k2/python/k2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .fsa_algo import closure
from .fsa_algo import compose
from .fsa_algo import connect
from .fsa_algo import ctc_graph
from .fsa_algo import determinize
from .fsa_algo import expand_ragged_attributes
from .fsa_algo import intersect
Expand Down
57 changes: 57 additions & 0 deletions k2/python/k2/fsa_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,3 +962,60 @@ def expand_ragged_attributes(
return dest, arc_map
else:
return dest


def ctc_graph(symbols: Union[List[List[int]], k2.RaggedInt],
              standard: bool = True,
              device: Optional[Union[torch.device, str]] = None) -> Fsa:
    '''Construct ctc graphs from symbols.

    Note:
      The scores of arcs in the returned FSA are all 0.

    Args:
      symbols:
        It can be one of the following types:

            - A list of list-of-integers, e.g., `[ [1, 2], [1, 2, 3] ]`
            - An instance of :class:`k2.RaggedInt`. Must have `num_axes() == 2`.
      standard:
        Option to specify the type of CTC topology: "standard" or "simplified",
        where the "standard" one makes the blank mandatory between a pair of
        identical symbols. Default True.
      device:
        Optional. It can be either a string (e.g., 'cpu',
        'cuda:0') or a torch.device.
        If it is None, then the returned FSA is on CPU. It has to be None
        if `symbols` is an instance of :class:`k2.RaggedInt`; in that case
        the returned FSA will be on the same device as `symbols`.
        A CUDA device without an index (e.g., 'cuda') means the current
        CUDA device.

    Returns:
      An FsaVec containing the returned ctc graphs, with `Dim0()` the same as
      `len(symbols)`(List[List[int]]) or `Dim0()`(k2.RaggedInt). Its
      `aux_labels` attribute holds, for every arc, the entry of `symbols`
      selected by the arc map returned from `_k2.ctc_graph`.
    '''
    if device is not None:
        device = torch.device(device)
        if device.type == 'cpu':
            gpu_id = -1
        else:
            assert device.type == 'cuda'
            # `torch.device('cuda').index` exists but is None, so a getattr()
            # default would never be used; resolve it explicitly to the
            # current device, which is also where `torch.tensor(...,
            # device=device)` below places the symbols.
            gpu_id = device.index
            if gpu_id is None:
                gpu_id = torch.cuda.current_device()
    else:
        gpu_id = -1

    if isinstance(symbols, k2.RaggedInt):
        # The ragged tensor already lives on a device; `gpu_id` is ignored by
        # the binding for this input type.
        assert device is None
        assert symbols.num_axes() == 2
        symbol_values = symbols.values()
    else:
        # Flatten the list of lists so it can be indexed with the arc map.
        symbol_values = torch.tensor(
            [it for symbol in symbols for it in symbol],
            dtype=torch.int32,
            device=device)

    # The arc map is always needed here, to compute the aux_labels.
    ragged_arc, arc_map = _k2.ctc_graph(symbols, gpu_id, standard, True)
    aux_labels = k2.index(symbol_values, arc_map)
    return Fsa(ragged_arc, aux_labels=aux_labels)
1 change: 1 addition & 0 deletions k2/python/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ set(py_test_files
connect_test.py
create_sparse_test.py
ctc_gradients_test.py
ctc_graph_test.py
dense_fsa_vec_test.py
determinize_test.py
expand_ragged_attributes_test.py
Expand Down
Loading