-
Notifications
You must be signed in to change notification settings - Fork 218
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Build ctc graph from symbols in batch mode #776
Changes from all commits
b9f09a3
574d5f4
6387d2b
8ee147a
182c260
a83ae90
a1a7902
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -430,6 +430,161 @@ FsaVec LinearFsas(const Ragged<int32_t> &symbols) { | |||||||||||||
arcs); | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
FsaVec CtcGraphs(const Ragged<int32_t> &symbols, bool standard /*= true*/, | ||||||||||||||
Array1<int32_t> *arc_map /*= nullptr*/) { | ||||||||||||||
NVTX_RANGE(K2_FUNC); | ||||||||||||||
K2_CHECK_EQ(symbols.NumAxes(), 2); | ||||||||||||||
ContextPtr &c = symbols.Context(); | ||||||||||||||
|
||||||||||||||
int32_t num_fsas = symbols.Dim0(); | ||||||||||||||
Array1<int32_t> num_states_for(c, num_fsas + 1); | ||||||||||||||
int32_t *num_states_for_data = num_states_for.Data(); | ||||||||||||||
const int32_t *symbol_row_split1_data = symbols.RowSplits(1).Data(); | ||||||||||||||
// symbols indexed with [fsa][symbol] | ||||||||||||||
// for each fsa we need `symbol_num * 2 + 1 + 1` states, `symbol_num * 2 + 1` | ||||||||||||||
// means that we need a blank state on each side of a symbol state, `+ 1` is | ||||||||||||||
// for final state in k2 | ||||||||||||||
K2_EVAL( | ||||||||||||||
c, num_fsas, lambda_set_num_states, (int32_t fsa_idx0)->void { | ||||||||||||||
int32_t symbol_idx0x = symbol_row_split1_data[fsa_idx0], | ||||||||||||||
symbol_idx0x_next = symbol_row_split1_data[fsa_idx0 + 1], | ||||||||||||||
symbol_num = symbol_idx0x_next - symbol_idx0x; | ||||||||||||||
num_states_for_data[fsa_idx0] = symbol_num * 2 + 2; | ||||||||||||||
}); | ||||||||||||||
|
||||||||||||||
ExclusiveSum(num_states_for, &num_states_for); | ||||||||||||||
Array1<int32_t> &fsa_to_states_row_splits = num_states_for; | ||||||||||||||
RaggedShape fsa_to_states = | ||||||||||||||
RaggedShape2(&fsa_to_states_row_splits, nullptr, -1); | ||||||||||||||
|
||||||||||||||
int32_t num_states = fsa_to_states.NumElements(); | ||||||||||||||
Array1<int32_t> num_arcs_for(c, num_states + 1); | ||||||||||||||
int32_t *num_arcs_for_data = num_arcs_for.Data(); | ||||||||||||||
const int32_t *fts_row_splits1_data = fsa_to_states.RowSplits(1).Data(), | ||||||||||||||
*fts_row_ids1_data = fsa_to_states.RowIds(1).Data(), | ||||||||||||||
*symbol_data = symbols.values.Data(); | ||||||||||||||
// set the arcs number for each state | ||||||||||||||
K2_EVAL( | ||||||||||||||
c, num_states, lambda_set_num_arcs, (int32_t state_idx01)->void { | ||||||||||||||
int32_t fsa_idx0 = fts_row_ids1_data[state_idx01], | ||||||||||||||
// we minus fsa_idx0 here, because we are adding one more state, | ||||||||||||||
// the final state for each fsa | ||||||||||||||
sym_state_idx01 = state_idx01 / 2 - fsa_idx0, | ||||||||||||||
remainder = state_idx01 % 2, | ||||||||||||||
current_num_arcs = 2; // normally there are two arcs, self-loop | ||||||||||||||
// and arc pointing to the next state | ||||||||||||||
// blank state always has two arcs | ||||||||||||||
if (remainder) { // symbol state | ||||||||||||||
int32_t sym_final_state = | ||||||||||||||
symbol_row_split1_data[fsa_idx0 + 1]; | ||||||||||||||
// There are no arcs for final states | ||||||||||||||
if (sym_state_idx01 == sym_final_state) { | ||||||||||||||
current_num_arcs = 0; | ||||||||||||||
} else if (!standard) { | ||||||||||||||
current_num_arcs = 3; | ||||||||||||||
} else { | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
For non-standard topo, |
||||||||||||||
int32_t current_symbol = symbol_data[sym_state_idx01], | ||||||||||||||
// we set the next symbol of the last symbol to -1, so | ||||||||||||||
// the following if clause will always be true, which means | ||||||||||||||
// we will have 3 arcs for last symbol state | ||||||||||||||
next_symbol = (sym_state_idx01 + 1) == sym_final_state ? | ||||||||||||||
-1 : symbol_data[sym_state_idx01 + 1]; | ||||||||||||||
// symbols must be not equal to -1, which is specially used in k2 | ||||||||||||||
K2_CHECK_NE(current_symbol, -1); | ||||||||||||||
// if current_symbol equals next_symbol, we need a blank state | ||||||||||||||
// between them, so there are two arcs for this state | ||||||||||||||
// otherwise, this state will point to blank state and next symbol | ||||||||||||||
// state, so we need three arcs here. | ||||||||||||||
// Note: for the simpilfied topology (standard equals false), there | ||||||||||||||
// are always 3 arcs leaving symbol states. | ||||||||||||||
if (current_symbol != next_symbol) | ||||||||||||||
current_num_arcs = 3; | ||||||||||||||
} | ||||||||||||||
} | ||||||||||||||
num_arcs_for_data[state_idx01] = current_num_arcs; | ||||||||||||||
}); | ||||||||||||||
|
||||||||||||||
ExclusiveSum(num_arcs_for, &num_arcs_for); | ||||||||||||||
Array1<int32_t> &states_to_arcs_row_splits = num_arcs_for; | ||||||||||||||
RaggedShape states_to_arcs = | ||||||||||||||
RaggedShape2(&states_to_arcs_row_splits, nullptr, -1); | ||||||||||||||
|
||||||||||||||
// ctc_shape with a index of [fsa][state][arc] | ||||||||||||||
RaggedShape ctc_shape = ComposeRaggedShapes(fsa_to_states, states_to_arcs); | ||||||||||||||
int32_t num_arcs = ctc_shape.NumElements(); | ||||||||||||||
Array1<Arc> arcs(c, num_arcs); | ||||||||||||||
Arc *arcs_data = arcs.Data(); | ||||||||||||||
const int32_t *ctc_row_splits1_data = ctc_shape.RowSplits(1).Data(), | ||||||||||||||
*ctc_row_ids1_data = ctc_shape.RowIds(1).Data(), | ||||||||||||||
*ctc_row_splits2_data = ctc_shape.RowSplits(2).Data(), | ||||||||||||||
*ctc_row_ids2_data = ctc_shape.RowIds(2).Data(); | ||||||||||||||
int32_t *arc_map_data = nullptr; | ||||||||||||||
if (arc_map != nullptr) { | ||||||||||||||
*arc_map = Array1<int32_t>(c, num_arcs); | ||||||||||||||
arc_map_data = arc_map->Data(); | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
K2_EVAL( | ||||||||||||||
c, num_arcs, lambda_set_arcs, (int32_t arc_idx012)->void { | ||||||||||||||
int32_t state_idx01 = ctc_row_ids2_data[arc_idx012], | ||||||||||||||
fsa_idx0 = ctc_row_ids1_data[state_idx01], | ||||||||||||||
state_idx0x = ctc_row_splits1_data[fsa_idx0], | ||||||||||||||
state_idx1 = state_idx01 - state_idx0x, | ||||||||||||||
arc_idx01x = ctc_row_splits2_data[state_idx01], | ||||||||||||||
arc_idx2 = arc_idx012 - arc_idx01x, | ||||||||||||||
sym_state_idx01 = state_idx01 / 2 - fsa_idx0, | ||||||||||||||
remainder = state_idx01 % 2, | ||||||||||||||
sym_final_state = symbol_row_split1_data[fsa_idx0 + 1]; | ||||||||||||||
bool final_state = sym_final_state == sym_state_idx01; | ||||||||||||||
int32_t current_symbol = final_state ? | ||||||||||||||
-1 : symbol_data[sym_state_idx01]; | ||||||||||||||
Arc arc; | ||||||||||||||
arc.score = 0; | ||||||||||||||
arc.src_state = state_idx1; | ||||||||||||||
int32_t arc_map_value = -1; | ||||||||||||||
if (remainder) { | ||||||||||||||
if (final_state) return; | ||||||||||||||
int32_t next_symbol = (sym_state_idx01 + 1) == sym_final_state ? | ||||||||||||||
-1 : symbol_data[sym_state_idx01 + 1]; | ||||||||||||||
// for standard topology, the symbol state can not point to next | ||||||||||||||
// symbol state if the next symbol is identical to current symbol. | ||||||||||||||
if (current_symbol == next_symbol && standard) { | ||||||||||||||
K2_CHECK_LT(arc_idx2, 2); | ||||||||||||||
arc.label = arc_idx2 == 0 ? 0 : current_symbol; | ||||||||||||||
arc.dest_state = arc_idx2 == 0 ? state_idx1 + 1 : state_idx1; | ||||||||||||||
} else { | ||||||||||||||
switch (arc_idx2) { | ||||||||||||||
case 0: // the arc pointing to blank state | ||||||||||||||
arc.label = 0; | ||||||||||||||
arc.dest_state = state_idx1 + 1; | ||||||||||||||
break; | ||||||||||||||
case 1: // the self loop arc | ||||||||||||||
arc.label = current_symbol; | ||||||||||||||
arc.dest_state = state_idx1; | ||||||||||||||
break; | ||||||||||||||
case 2: // the arc pointing to the next symbol state | ||||||||||||||
arc.label = next_symbol; | ||||||||||||||
arc_map_value = sym_state_idx01 + 1 == sym_final_state ? | ||||||||||||||
-1 : sym_state_idx01 + 1; | ||||||||||||||
arc.dest_state = state_idx1 + 2; | ||||||||||||||
break; | ||||||||||||||
default: | ||||||||||||||
K2_LOG(FATAL) << "Arc index must be less than 3"; | ||||||||||||||
} | ||||||||||||||
} | ||||||||||||||
} else { | ||||||||||||||
K2_CHECK_LT(arc_idx2, 2); | ||||||||||||||
arc.label = arc_idx2 == 0 ? 0 : current_symbol; | ||||||||||||||
arc.dest_state = arc_idx2 == 0 ? state_idx1 : state_idx1 + 1; | ||||||||||||||
arc_map_value = (arc_idx2 == 0 || final_state) ? -1 : sym_state_idx01; | ||||||||||||||
} | ||||||||||||||
arcs_data[arc_idx012] = arc; | ||||||||||||||
if (arc_map) arc_map_data[arc_idx012] = arc_map_value; | ||||||||||||||
}); | ||||||||||||||
return Ragged<Arc>(ctc_shape, arcs); | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
void ArcSort(Fsa *fsa) { | ||||||||||||||
if (fsa->NumAxes() < 2) return; // it is empty | ||||||||||||||
SortSublists<Arc>(fsa); | ||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -424,9 +424,6 @@ void RemoveEpsilonAndAddSelfLoops(FsaOrVec &src, int32_t properties, | |
FsaOrVec *dest, | ||
Ragged<int32_t> *arc_derivs = nullptr); | ||
|
||
|
||
|
||
|
||
/* | ||
Determinize the input Fsas, it works for both Fsa and FsaVec. | ||
@param [in] src Source Fsa or FsaVec. | ||
|
@@ -478,6 +475,24 @@ Fsa LinearFsa(const Array1<int32_t> &symbols); | |
*/ | ||
FsaVec LinearFsas(const Ragged<int32_t> &symbols); | ||
|
||
/* | ||
Create an FsaVec containing ctc graph FSAs, given a list of sequences of | ||
symbols | ||
|
||
@param [in] symbols Input symbol sequences (must not contain | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to add a check inside the kernel that none of the input symbols is -1? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I add a checking in the kernel |
||
kFinalSymbol == -1). Its num_axes is 2. | ||
@param [in] standard Option to specify the type of CTC topology: "standard" | ||
or "simplified", where the "standard" one makes the | ||
blank mandatory between a pair of identical symbols. | ||
@param [out] It maps the arcs of output fsa to the symbols(idx01), the | ||
olabel of the `arc[i]` would be `symbols[arc_map[i]]`, | ||
and -1 for epsilon olabel. | ||
|
||
@return Returns an FsaVec with `ans.Dim0() == symbols.Dim0()`. | ||
*/ | ||
FsaVec CtcGraphs(const Ragged<int32_t> &symbols, bool standard = true, | ||
Array1<int32_t> *arc_map = nullptr); | ||
|
||
/* Compute the forward shortest path in the tropical semiring. | ||
|
||
@param [in] fsas Input FsaVec (must have 3 axes). Must be | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason to introduce another name for
num_states_for
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After doing ExclusiveSum,
num_states_for
is actually therow_splits
, using different names here just for easy understanding, we'll usefsa_to_states_row_splits
to construct ragged_shape below.