Skip to content

Commit

Permalink
bugfixes to unified tokenizer support for Canary
Browse files Browse the repository at this point in the history
Signed-off-by: Piotr Żelasko <[email protected]>
  • Loading branch information
pzelasko committed Jan 10, 2025
1 parent bff9690 commit cab0f8b
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 15 deletions.
6 changes: 5 additions & 1 deletion nemo/collections/asr/data/audio_to_text_lhotse_prompted.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,12 @@ def __init__(
super().__init__()
self.tokenizer = tokenizer
self.load_audio = AudioSamples(fault_tolerant=True)
self.padding_value = self.tokenizer.pad
self.prompt = prompt
pad_id = self.tokenizer.pad_id
if pad_id == -1:
pad_id = self.tokenizer.token_to_id("<pad>")
assert pad_id > -1, "Invalid tokenizer: both tokenizer.pad_id and tokenizer.token_to_id('<pad>') returned -1."
self.padding_value = pad_id

def __getitem__(self, cuts: CutSet) -> PromptedAudioToTextMiniBatch:
audio, audio_lens, cuts = self.load_audio(cuts)
Expand Down
12 changes: 10 additions & 2 deletions nemo/collections/asr/models/aed_multitask_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
)

# Define autoregressive CE loss
pad_id = self.tokenizer.pad_id
if pad_id == -1:
pad_id = self.tokenizer.token_to_id("<pad>")
assert pad_id > -1, "Invalid tokenizer: both tokenizer.pad_id and tokenizer.token_to_id('<pad>') returned -1."
with open_dict(self.cfg.loss):
self.cfg.loss.pad_id = self.tokenizer.pad_id
self.cfg.loss.pad_id = pad_id

self.loss = EncDecMultiTaskModel.from_config_dict(self.cfg.loss)

Expand Down Expand Up @@ -387,8 +391,12 @@ def change_vocabulary(
self.cfg.decoding = decoding_cfg

# Setup loss
pad_id = self.tokenizer.pad_id
if pad_id == -1:
pad_id = self.tokenizer.token_to_id("<pad>")
assert pad_id > -1, "Invalid tokenizer: both tokenizer.pad_id and tokenizer.token_to_id('<pad>') returned -1."
with open_dict(self.cfg.loss):
self.cfg.loss.pad_id = self.tokenizer.pad_id
self.cfg.loss.pad_id = pad_id

del self.loss
self.loss = EncDecMultiTaskModel.from_config_dict(self.cfg.loss)
Expand Down
26 changes: 14 additions & 12 deletions nemo/collections/common/data/lhotse/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,30 +709,32 @@ def make_structured_with_schema_warnings(config: DictConfig | dict) -> DictConfi
if not isinstance(config, DictConfig):
config = DictConfig(config)

# Remove unsupported keys and warn about them.
supported_keys = set(OmegaConf.to_container(default).keys())
received_keys = set(OmegaConf.to_container(config).keys())
unsupported_keys = received_keys - supported_keys
unsupported_keys.discard("use_lhotse")
if unsupported_keys:
logging.warning(
f"The following configuration keys are ignored by Lhotse dataloader: {','.join(unsupported_keys)}",
)
config = OmegaConf.masked_copy(config, list(supported_keys))

config = OmegaConf.merge(default, config)

if config.get("tarred_random_access", False):
logging.warning(
"Option 'tarred_random_access' is deprecated and replaced with 'skip_missing_manifest_entries'.",
)
config.skip_missing_manifest_entries = True

if config.skip_missing_manifest_entries:
logging.warning(
"Note: skip_missing_manifest_entries is set to True. "
"If any of your manifests and tar files are mismatched, the entire tar file will be skipped without warning. "
"It's your responsibility to ensure data integrity with this setting."
)

# Remove unsupported keys and warn about them.
supported_keys = set(OmegaConf.to_container(default).keys())
received_keys = set(OmegaConf.to_container(config).keys())
unsupported_keys = received_keys - supported_keys
if unsupported_keys:
logging.warning(
f"The following configuration keys are ignored by Lhotse dataloader: {','.join(unsupported_keys)}",
)
config = OmegaConf.masked_copy(config, list(supported_keys))

return OmegaConf.merge(default, config)
return config


def tokenize(example, tokenizer):
Expand Down

0 comments on commit cab0f8b

Please sign in to comment.