diff --git a/environment.yml b/environment.yml index dbda5d43..cdfb9b1a 100644 --- a/environment.yml +++ b/environment.yml @@ -8,8 +8,6 @@ dependencies: - librosa - tqdm - requests - - colorama - - ansiwrap - pyyaml - dataclassy - kaldi=*=*cpu* @@ -42,6 +40,8 @@ dependencies: - matplotlib - seaborn - pip + - rich + - rich-click - pip: - build - twine diff --git a/montreal_forced_aligner/__main__.py b/montreal_forced_aligner/__main__.py index e6e7a6fb..01328e9d 100644 --- a/montreal_forced_aligner/__main__.py +++ b/montreal_forced_aligner/__main__.py @@ -1,3 +1,6 @@ +from rich.traceback import install + from montreal_forced_aligner.command_line.mfa import mfa_cli +install(show_locals=True) mfa_cli() diff --git a/montreal_forced_aligner/abc.py b/montreal_forced_aligner/abc.py index 094a3a26..b2b2dd56 100644 --- a/montreal_forced_aligner/abc.py +++ b/montreal_forced_aligner/abc.py @@ -699,7 +699,7 @@ def check_previous_run(self) -> bool: return True conf = load_configuration(self.worker_config_path) clean = self._validate_previous_configuration(conf) - if not clean: + if not GLOBAL_CONFIG.current_profile.clean and not clean: logger.warning( "The previous run had a different configuration than the current, which may cause issues." " Please see the log for details or use --clean flag if issues are encountered." diff --git a/montreal_forced_aligner/acoustic_modeling/base.py b/montreal_forced_aligner/acoustic_modeling/base.py index a15041bb..94ef9223 100644 --- a/montreal_forced_aligner/acoustic_modeling/base.py +++ b/montreal_forced_aligner/acoustic_modeling/base.py @@ -13,8 +13,8 @@ from typing import TYPE_CHECKING, List import sqlalchemy.engine -import tqdm from sqlalchemy.orm import Session +from tqdm.rich import tqdm from montreal_forced_aligner.abc import MfaWorker, ModelExporterMixin, TrainerMixin from montreal_forced_aligner.alignment import AlignMixin @@ -327,7 +327,7 @@ def acc_stats(self) -> None: """ logger.info("Accumulating statistics...") arguments = self.acc_stats_arguments() - with tqdm.tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: if GLOBAL_CONFIG.use_mp: error_dict = {} return_queue = mp.Queue() diff --git a/montreal_forced_aligner/acoustic_modeling/lda.py b/montreal_forced_aligner/acoustic_modeling/lda.py index 21f0be15..168c250c 100644 --- a/montreal_forced_aligner/acoustic_modeling/lda.py +++ b/montreal_forced_aligner/acoustic_modeling/lda.py @@ -12,7 +12,7 @@ from queue import Empty from typing import TYPE_CHECKING, Dict, List -import tqdm +from tqdm.rich import tqdm from montreal_forced_aligner.abc import KaldiFunction from montreal_forced_aligner.acoustic_modeling.triphone import TriphoneTrainer @@ -412,7 +412,7 @@ def lda_acc_stats(self) -> None: if os.path.exists(worker_lda_path): os.remove(worker_lda_path) arguments = self.lda_acc_stats_arguments() - with tqdm.tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: if GLOBAL_CONFIG.use_mp: error_dict = {} return_queue = mp.Queue() @@ -509,7 +509,7 @@ def calc_lda_mllt(self) -> None: """ logger.info("Re-calculating LDA...") arguments = self.calc_lda_mllt_arguments() - with tqdm.tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: if GLOBAL_CONFIG.use_mp: error_dict = {} return_queue = 
mp.Queue() diff --git a/montreal_forced_aligner/acoustic_modeling/monophone.py b/montreal_forced_aligner/acoustic_modeling/monophone.py index 44185a22..0d0dbbba 100644 --- a/montreal_forced_aligner/acoustic_modeling/monophone.py +++ b/montreal_forced_aligner/acoustic_modeling/monophone.py @@ -10,8 +10,8 @@ from pathlib import Path from queue import Empty -import tqdm from sqlalchemy.orm import Session, joinedload, subqueryload +from tqdm.rich import tqdm from montreal_forced_aligner.abc import KaldiFunction from montreal_forced_aligner.acoustic_modeling.base import AcousticModelTrainingMixin @@ -240,7 +240,7 @@ def mono_align_equal(self) -> None: logger.info("Generating initial alignments...") arguments = self.mono_align_equal_arguments() - with tqdm.tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: if GLOBAL_CONFIG.use_mp: error_dict = {} return_queue = mp.Queue() diff --git a/montreal_forced_aligner/acoustic_modeling/pronunciation_probabilities.py b/montreal_forced_aligner/acoustic_modeling/pronunciation_probabilities.py index 845ff702..3093001a 100644 --- a/montreal_forced_aligner/acoustic_modeling/pronunciation_probabilities.py +++ b/montreal_forced_aligner/acoustic_modeling/pronunciation_probabilities.py @@ -8,8 +8,8 @@ import typing from pathlib import Path -import tqdm from sqlalchemy.orm import joinedload +from tqdm.rich import tqdm from montreal_forced_aligner.acoustic_modeling.base import AcousticModelTrainingMixin from montreal_forced_aligner.alignment.multiprocessing import ( @@ -186,7 +186,7 @@ def train_g2p_lexicon(self) -> None: ) for x in self.worker.dictionary_lookup.values() } - with tqdm.tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: for dict_id, utt_id, phones in run_kaldi_function( GeneratePronunciationsFunction, arguments, pbar.update ): diff --git a/montreal_forced_aligner/acoustic_modeling/sat.py b/montreal_forced_aligner/acoustic_modeling/sat.py index 8706d8a4..a9e03b6b 100644 --- a/montreal_forced_aligner/acoustic_modeling/sat.py +++ b/montreal_forced_aligner/acoustic_modeling/sat.py @@ -13,7 +13,7 @@ from queue import Empty from typing import Dict, List -import tqdm +from tqdm.rich import tqdm from montreal_forced_aligner.acoustic_modeling.triphone import TriphoneTrainer from montreal_forced_aligner.config import GLOBAL_CONFIG @@ -341,7 +341,7 @@ def create_align_model(self) -> None: begin = time.time() arguments = self.acc_stats_two_feats_arguments() - with tqdm.tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: if GLOBAL_CONFIG.use_mp: error_dict = {} return_queue = mp.Queue() diff --git a/montreal_forced_aligner/acoustic_modeling/trainer.py b/montreal_forced_aligner/acoustic_modeling/trainer.py index e00c6585..9d0940f5 100644 --- a/montreal_forced_aligner/acoustic_modeling/trainer.py +++ b/montreal_forced_aligner/acoustic_modeling/trainer.py @@ -15,8 +15,8 @@ from queue import Empty from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple -import tqdm from sqlalchemy.orm import Session, joinedload, subqueryload +from tqdm.rich import tqdm from montreal_forced_aligner.abc import KaldiFunction, ModelExporterMixin, TopLevelMfaWorker from montreal_forced_aligner.config import GLOBAL_CONFIG @@ -583,7 +583,7 @@ def 
compute_phone_pdf_counts(self) -> None: log_directory = self.working_log_directory os.makedirs(log_directory, exist_ok=True) arguments = self.transition_acc_arguments() - with tqdm.tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: if GLOBAL_CONFIG.use_mp: error_dict = {} return_queue = mp.Queue() diff --git a/montreal_forced_aligner/acoustic_modeling/triphone.py b/montreal_forced_aligner/acoustic_modeling/triphone.py index 420798c3..eff89373 100644 --- a/montreal_forced_aligner/acoustic_modeling/triphone.py +++ b/montreal_forced_aligner/acoustic_modeling/triphone.py @@ -11,7 +11,7 @@ from queue import Empty from typing import TYPE_CHECKING, Dict, List -import tqdm +from tqdm.rich import tqdm from montreal_forced_aligner.acoustic_modeling.base import AcousticModelTrainingMixin from montreal_forced_aligner.config import GLOBAL_CONFIG @@ -294,7 +294,7 @@ def convert_alignments(self) -> None: """ logger.info("Converting alignments...") arguments = self.convert_alignments_arguments() - with tqdm.tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: if GLOBAL_CONFIG.use_mp: error_dict = {} return_queue = mp.Queue() diff --git a/montreal_forced_aligner/alignment/adapting.py b/montreal_forced_aligner/alignment/adapting.py index 312b8389..2e6bc094 100644 --- a/montreal_forced_aligner/alignment/adapting.py +++ b/montreal_forced_aligner/alignment/adapting.py @@ -11,7 +11,7 @@ from queue import Empty from typing import TYPE_CHECKING, List -import tqdm +from tqdm.rich import tqdm from montreal_forced_aligner.abc import AdapterMixin from montreal_forced_aligner.alignment.multiprocessing import AccStatsArguments, AccStatsFunction @@ -124,7 +124,7 @@ def acc_stats(self, alignment: bool = False) -> None: initial_mdl_path = self.working_directory.joinpath("unadapted.mdl") final_mdl_path = self.working_directory.joinpath("final.mdl") logger.info("Accumulating statistics...") - with tqdm.tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: if GLOBAL_CONFIG.use_mp: error_dict = {} return_queue = mp.Queue() diff --git a/montreal_forced_aligner/alignment/base.py b/montreal_forced_aligner/alignment/base.py index 683538c2..6a795799 100644 --- a/montreal_forced_aligner/alignment/base.py +++ b/montreal_forced_aligner/alignment/base.py @@ -16,8 +16,8 @@ from typing import Dict, List, Optional import sqlalchemy -import tqdm from sqlalchemy.orm import joinedload, subqueryload +from tqdm.rich import tqdm from montreal_forced_aligner.abc import FileExporterMixin from montreal_forced_aligner.alignment.mixins import AlignMixin @@ -327,7 +327,7 @@ def compute_pronunciation_probabilities(self): } logger.info("Generating pronunciations...") arguments = self.generate_pronunciations_arguments() - with tqdm.tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: if GLOBAL_CONFIG.use_mp: error_dict = {} return_queue = mp.Queue() @@ -696,7 +696,7 @@ def collect_alignments(self) -> None: if max_word_interval_id is None: max_word_interval_id = 0 - with tqdm.tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: 
logger.info(f"Collecting phone and word alignments from {workflow.name} lattices...") arguments = self.alignment_extraction_arguments() @@ -867,7 +867,7 @@ def fine_tune_alignments(self) -> None: """ logger.info("Fine tuning alignments...") begin = time.time() - with self.session() as session, tqdm.tqdm( + with self.session() as session, tqdm( total=self.num_utterances, disable=GLOBAL_CONFIG.quiet ) as pbar: update_mappings = [] @@ -1027,7 +1027,7 @@ def export_textgrids( begin = time.time() error_dict = {} - with tqdm.tqdm(total=self.num_files, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_files, disable=GLOBAL_CONFIG.quiet) as pbar: with self.session() as session: files = ( session.query( diff --git a/montreal_forced_aligner/alignment/mixins.py b/montreal_forced_aligner/alignment/mixins.py index 44a265a5..1a4063db 100644 --- a/montreal_forced_aligner/alignment/mixins.py +++ b/montreal_forced_aligner/alignment/mixins.py @@ -12,7 +12,7 @@ from queue import Empty from typing import TYPE_CHECKING, Dict, List -import tqdm +from tqdm.rich import tqdm from montreal_forced_aligner.alignment.multiprocessing import ( AlignArguments, @@ -298,7 +298,7 @@ def compile_train_graphs(self) -> None: logger.info("Compiling training graphs...") error_sum = 0 arguments = self.compile_train_graphs_arguments() - with tqdm.tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: if GLOBAL_CONFIG.use_mp: error_dict = {} return_queue = mp.Queue() @@ -351,7 +351,7 @@ def get_phone_confidences(self): begin = time.time() with self.session() as session: - with tqdm.tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: arguments = self.phone_confidence_arguments() interval_update_mappings = [] if GLOBAL_CONFIG.use_mp: @@ -416,7 +416,7 @@ def align_utterances(self, training=False) -> None: """ begin = time.time() logger.info("Generating alignments...") - with tqdm.tqdm( + with tqdm( total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet ) as pbar, self.session() as session: if not training: diff --git a/montreal_forced_aligner/command_line/adapt.py b/montreal_forced_aligner/command_line/adapt.py index 8e51f551..2fd51197 100644 --- a/montreal_forced_aligner/command_line/adapt.py +++ b/montreal_forced_aligner/command_line/adapt.py @@ -4,7 +4,7 @@ import os from pathlib import Path -import click +import rich_click as click from montreal_forced_aligner.alignment import AdaptingAligner from montreal_forced_aligner.command_line.utils import ( diff --git a/montreal_forced_aligner/command_line/align.py b/montreal_forced_aligner/command_line/align.py index 5edc1551..c1fda2d1 100644 --- a/montreal_forced_aligner/command_line/align.py +++ b/montreal_forced_aligner/command_line/align.py @@ -4,7 +4,7 @@ import os from pathlib import Path -import click +import rich_click as click import yaml from montreal_forced_aligner.alignment import PretrainedAligner diff --git a/montreal_forced_aligner/command_line/anchor.py b/montreal_forced_aligner/command_line/anchor.py index 5bb25ae7..7d7a7403 100644 --- a/montreal_forced_aligner/command_line/anchor.py +++ b/montreal_forced_aligner/command_line/anchor.py @@ -1,12 +1,15 @@ """Command line functions for launching anchor annotation""" from __future__ import annotations +import logging import sys -import click +import rich_click as click __all__ = 
["anchor_cli"] +logger = logging.getLogger("mfa") + @click.command(name="anchor", short_help="Launch Anchor") @click.help_option("-h", "--help") @@ -17,8 +20,7 @@ def anchor_cli(*args, **kwargs) -> None: # pragma: no cover try: from anchor.command_line import main except ImportError: - raise - print( + logger.error( "Anchor annotator utility is not installed, please install it via pip install anchor-annotator." ) sys.exit(1) diff --git a/montreal_forced_aligner/command_line/configure.py b/montreal_forced_aligner/command_line/configure.py index d5e398e2..5cbd3690 100644 --- a/montreal_forced_aligner/command_line/configure.py +++ b/montreal_forced_aligner/command_line/configure.py @@ -1,6 +1,6 @@ import os -import click +import rich_click as click from montreal_forced_aligner.config import GLOBAL_CONFIG, MFA_PROFILE_VARIABLE diff --git a/montreal_forced_aligner/command_line/create_segments.py b/montreal_forced_aligner/command_line/create_segments.py index a47198ad..f32b0889 100644 --- a/montreal_forced_aligner/command_line/create_segments.py +++ b/montreal_forced_aligner/command_line/create_segments.py @@ -4,7 +4,7 @@ import os from pathlib import Path -import click +import rich_click as click from montreal_forced_aligner.command_line.utils import ( check_databases, diff --git a/montreal_forced_aligner/command_line/diarize_speakers.py b/montreal_forced_aligner/command_line/diarize_speakers.py index 676b1e44..bd632210 100644 --- a/montreal_forced_aligner/command_line/diarize_speakers.py +++ b/montreal_forced_aligner/command_line/diarize_speakers.py @@ -4,7 +4,7 @@ import os from pathlib import Path -import click +import rich_click as click from montreal_forced_aligner.command_line.utils import ( check_databases, diff --git a/montreal_forced_aligner/command_line/g2p.py b/montreal_forced_aligner/command_line/g2p.py index 985385bb..9162e4c5 100644 --- a/montreal_forced_aligner/command_line/g2p.py +++ b/montreal_forced_aligner/command_line/g2p.py @@ -4,7 +4,7 @@ import os from pathlib import Path -import click +import rich_click as click from montreal_forced_aligner.command_line.utils import ( check_databases, diff --git a/montreal_forced_aligner/command_line/history.py b/montreal_forced_aligner/command_line/history.py index 0141474c..9b0326b2 100644 --- a/montreal_forced_aligner/command_line/history.py +++ b/montreal_forced_aligner/command_line/history.py @@ -1,11 +1,14 @@ +import logging import time -import click +import rich_click as click from montreal_forced_aligner.config import GLOBAL_CONFIG, load_command_history __all__ = ["history_cli"] +logger = logging.getLogger("mfa") + @click.command( "history", @@ -26,14 +29,14 @@ def history_cli(depth: int, verbose: bool) -> None: """ history = load_command_history()[-depth:] if verbose: - print("command\tDate\tExecution time\tVersion\tExit code\tException") + logger.info("command\tDate\tExecution time\tVersion\tExit code\tException") for h in history: execution_time = time.strftime("%H:%M:%S", time.gmtime(h["execution_time"])) d = h["date"].isoformat() - print( + logger.info( f"{h['command']}\t{d}\t{execution_time}\t{h.get('version', 'unknown')}\t{h['exit_code']}\t{h['exception']}" ) pass else: for h in history: - print(h["command"]) + logger.info(h["command"]) diff --git a/montreal_forced_aligner/command_line/mfa.py b/montreal_forced_aligner/command_line/mfa.py index 0e27f66b..e0af89a9 100644 --- a/montreal_forced_aligner/command_line/mfa.py +++ b/montreal_forced_aligner/command_line/mfa.py @@ -8,7 +8,7 @@ import warnings from datetime import 
datetime -import click +import rich_click as click from montreal_forced_aligner.command_line.adapt import adapt_model_cli from montreal_forced_aligner.command_line.align import align_corpus_cli @@ -107,17 +107,14 @@ def mfa_cli(ctx: click.Context) -> None: GLOBAL_CONFIG.load() from montreal_forced_aligner.helper import configure_logger - if not GLOBAL_CONFIG.current_profile.debug: - warnings.simplefilter("ignore") + warnings.simplefilter("ignore") configure_logger("mfa") check_third_party() if ctx.invoked_subcommand != "anchor": hooks = ExitHooks() hooks.hook() atexit.register(hooks.history_save_handler) - from colorama import init - init() mp.freeze_support() diff --git a/montreal_forced_aligner/command_line/model.py b/montreal_forced_aligner/command_line/model.py index df6f5ac6..246fbf11 100644 --- a/montreal_forced_aligner/command_line/model.py +++ b/montreal_forced_aligner/command_line/model.py @@ -7,7 +7,7 @@ import typing from pathlib import Path -import click +import rich_click as click from montreal_forced_aligner.command_line.utils import ( check_databases, @@ -42,7 +42,7 @@ @click.help_option("-h", "--help") def model_cli() -> None: """ - Inspect, download, and save pretrained MFA models + Inspect, download, and save pretrained MFA models and dictionaries """ pass @@ -100,8 +100,8 @@ def inspect_model_cli(model_type: str, model: str) -> None: from montreal_forced_aligner.config import GLOBAL_CONFIG, get_temporary_directory GLOBAL_CONFIG.current_profile.clean = True - GLOBAL_CONFIG.current_profile.temporary_directory = os.path.join( - get_temporary_directory(), "model_inspect" + GLOBAL_CONFIG.current_profile.temporary_directory = get_temporary_directory().joinpath( + "model_inspect" ) shutil.rmtree(GLOBAL_CONFIG.current_profile.temporary_directory, ignore_errors=True) if model_type and model_type not in MODEL_TYPES: @@ -130,8 +130,8 @@ def inspect_model_cli(model_type: str, model: str) -> None: if path is None: raise PretrainedModelNotFoundError(model) model = path - working_dir = os.path.join(get_temporary_directory(), "models", "inspect") - ext = os.path.splitext(model)[1] + working_dir = get_temporary_directory().joinpath("models", "inspect") + ext = model.suffix if model_type: if model_type == MODEL_TYPES["dictionary"]: m = MODEL_TYPES[model_type](model, working_dir, phone_set_type=PhoneSetType.AUTO) @@ -218,13 +218,13 @@ def save_model_cli(path: Path, model_type: str, name: str, overwrite: bool) -> N Type of model """ logger = logging.getLogger("mfa") - model_name = os.path.splitext(os.path.basename(path))[0] + model_name = path.stem model_class = MODEL_TYPES[model_type] if name: out_path = model_class.get_pretrained_path(name, enforce_existence=False) else: out_path = model_class.get_pretrained_path(model_name, enforce_existence=False) - if not overwrite and os.path.exists(out_path): + if not overwrite and out_path.exists(): raise ModelSaveError(out_path) shutil.copyfile(path, out_path) logger.info( diff --git a/montreal_forced_aligner/command_line/tokenize.py b/montreal_forced_aligner/command_line/tokenize.py index d46b0c1d..57c31b75 100644 --- a/montreal_forced_aligner/command_line/tokenize.py +++ b/montreal_forced_aligner/command_line/tokenize.py @@ -4,7 +4,7 @@ import os from pathlib import Path -import click +import rich_click as click from montreal_forced_aligner.command_line.utils import ( check_databases, diff --git a/montreal_forced_aligner/command_line/train_acoustic_model.py b/montreal_forced_aligner/command_line/train_acoustic_model.py index 6a8b0832..e6dcc1a8 
100644 --- a/montreal_forced_aligner/command_line/train_acoustic_model.py +++ b/montreal_forced_aligner/command_line/train_acoustic_model.py @@ -4,7 +4,7 @@ import os from pathlib import Path -import click +import rich_click as click from montreal_forced_aligner.acoustic_modeling import TrainableAligner from montreal_forced_aligner.command_line.utils import ( diff --git a/montreal_forced_aligner/command_line/train_dictionary.py b/montreal_forced_aligner/command_line/train_dictionary.py index a1c0b0cd..a2f23fd2 100644 --- a/montreal_forced_aligner/command_line/train_dictionary.py +++ b/montreal_forced_aligner/command_line/train_dictionary.py @@ -4,7 +4,7 @@ import os from pathlib import Path -import click +import rich_click as click from montreal_forced_aligner.alignment.pretrained import DictionaryTrainer from montreal_forced_aligner.command_line.utils import ( diff --git a/montreal_forced_aligner/command_line/train_g2p.py b/montreal_forced_aligner/command_line/train_g2p.py index 1bffec6c..7fd83fba 100644 --- a/montreal_forced_aligner/command_line/train_g2p.py +++ b/montreal_forced_aligner/command_line/train_g2p.py @@ -4,7 +4,7 @@ import os from pathlib import Path -import click +import rich_click as click from montreal_forced_aligner.command_line.utils import ( check_databases, diff --git a/montreal_forced_aligner/command_line/train_ivector_extractor.py b/montreal_forced_aligner/command_line/train_ivector_extractor.py index 2980abb7..f94f73f7 100644 --- a/montreal_forced_aligner/command_line/train_ivector_extractor.py +++ b/montreal_forced_aligner/command_line/train_ivector_extractor.py @@ -4,7 +4,7 @@ import os from pathlib import Path -import click +import rich_click as click from montreal_forced_aligner.command_line.utils import ( check_databases, diff --git a/montreal_forced_aligner/command_line/train_lm.py b/montreal_forced_aligner/command_line/train_lm.py index 9959440b..ca61b112 100644 --- a/montreal_forced_aligner/command_line/train_lm.py +++ b/montreal_forced_aligner/command_line/train_lm.py @@ -4,7 +4,7 @@ import os from pathlib import Path -import click +import rich_click as click from montreal_forced_aligner.command_line.utils import ( check_databases, diff --git a/montreal_forced_aligner/command_line/train_tokenizer.py b/montreal_forced_aligner/command_line/train_tokenizer.py index 2cf241ed..76c6acfc 100644 --- a/montreal_forced_aligner/command_line/train_tokenizer.py +++ b/montreal_forced_aligner/command_line/train_tokenizer.py @@ -4,7 +4,7 @@ import os from pathlib import Path -import click +import rich_click as click from montreal_forced_aligner.command_line.utils import ( check_databases, diff --git a/montreal_forced_aligner/command_line/transcribe.py b/montreal_forced_aligner/command_line/transcribe.py index 0742e3d4..e6fb0a8b 100644 --- a/montreal_forced_aligner/command_line/transcribe.py +++ b/montreal_forced_aligner/command_line/transcribe.py @@ -4,7 +4,7 @@ import os from pathlib import Path -import click +import rich_click as click from montreal_forced_aligner.command_line.utils import ( check_databases, diff --git a/montreal_forced_aligner/command_line/utils.py b/montreal_forced_aligner/command_line/utils.py index 4100a6c2..17eaac26 100644 --- a/montreal_forced_aligner/command_line/utils.py +++ b/montreal_forced_aligner/command_line/utils.py @@ -8,7 +8,7 @@ import typing from pathlib import Path -import click +import rich_click as click import sqlalchemy import yaml diff --git a/montreal_forced_aligner/command_line/validate.py 
b/montreal_forced_aligner/command_line/validate.py index 9adf7e9c..3361c5e4 100644 --- a/montreal_forced_aligner/command_line/validate.py +++ b/montreal_forced_aligner/command_line/validate.py @@ -4,7 +4,7 @@ import os from pathlib import Path -import click +import rich_click as click from montreal_forced_aligner.command_line.utils import ( check_databases, diff --git a/montreal_forced_aligner/config.py b/montreal_forced_aligner/config.py index e6cf1ee2..c12e194d 100644 --- a/montreal_forced_aligner/config.py +++ b/montreal_forced_aligner/config.py @@ -11,9 +11,9 @@ import typing from typing import Any, Dict, List, Union -import click import dataclassy import joblib +import rich_click as click import yaml from dataclassy import dataclass @@ -52,8 +52,8 @@ def get_temporary_directory() -> pathlib.Path: :class:`~montreal_forced_aligner.exceptions.RootDirectoryError` """ TEMP_DIR = pathlib.Path( - os.environ.get(MFA_ROOT_ENVIRONMENT_VARIABLE, os.path.expanduser("~/Documents/MFA")) - ) + os.environ.get(MFA_ROOT_ENVIRONMENT_VARIABLE, "~/Documents/MFA") + ).expanduser() try: TEMP_DIR.mkdir(parents=True, exist_ok=True) except OSError: diff --git a/montreal_forced_aligner/corpus/acoustic_corpus.py b/montreal_forced_aligner/corpus/acoustic_corpus.py index 72bbf18e..263fc468 100644 --- a/montreal_forced_aligner/corpus/acoustic_corpus.py +++ b/montreal_forced_aligner/corpus/acoustic_corpus.py @@ -14,7 +14,7 @@ from typing import List, Optional import sqlalchemy -import tqdm +from tqdm.rich import tqdm from montreal_forced_aligner.abc import MfaWorker from montreal_forced_aligner.config import GLOBAL_CONFIG @@ -228,7 +228,7 @@ def load_reference_alignments(self, reference_directory: Path) -> None: indices = [] jobs = [] reference_intervals = [] - with tqdm.tqdm( + with tqdm( total=self.num_files, disable=GLOBAL_CONFIG.quiet ) as pbar, self.session() as session: phone_mapping = {} @@ -385,7 +385,7 @@ def generate_final_features(self) -> None: log_directory = self.split_directory.joinpath("log") os.makedirs(log_directory, exist_ok=True) arguments = self.final_feature_arguments() - with tqdm.tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: for _ in run_kaldi_function(FinalFeatureFunction, arguments, pbar.update): pass with self.session() as session: @@ -488,7 +488,7 @@ def create_corpus_split(self) -> None: else: logger.info("Creating corpus split for feature generation...") os.makedirs(self.split_directory.joinpath("log"), exist_ok=True) - with self.session() as session, tqdm.tqdm( + with self.session() as session, tqdm( total=self.num_utterances + self.num_files, disable=GLOBAL_CONFIG.quiet ) as pbar: jobs = session.query(Job) @@ -636,7 +636,7 @@ def compute_speaker_pitch_ranges(self): os.makedirs(log_directory, exist_ok=True) arguments = self.pitch_range_arguments() update_mapping = [] - with tqdm.tqdm(total=self.num_speakers, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_speakers, disable=GLOBAL_CONFIG.quiet) as pbar: for speaker_id, min_f0, max_f0 in run_kaldi_function( PitchRangeFunction, arguments, pbar.update ): @@ -665,7 +665,7 @@ def mfcc(self) -> None: log_directory = self.split_directory.joinpath("log") os.makedirs(log_directory, exist_ok=True) arguments = self.mfcc_arguments() - with tqdm.tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: for _ in run_kaldi_function(MfccFunction, 
arguments, pbar.update): pass logger.debug(f"Generating MFCCs took {time.time() - begin:.3f} seconds") @@ -736,7 +736,7 @@ def calc_fmllr(self, iteration: Optional[int] = None) -> None: logger.info("Calculating fMLLR for speaker adaptation...") arguments = self.calc_fmllr_arguments(iteration=iteration) - with tqdm.tqdm(total=self.num_speakers, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_speakers, disable=GLOBAL_CONFIG.quiet) as pbar: if GLOBAL_CONFIG.use_mp: error_dict = {} return_queue = mp.Queue() @@ -797,7 +797,7 @@ def compute_vad(self) -> None: logger.info("Computing VAD...") arguments = self.compute_vad_arguments() - with tqdm.tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: if GLOBAL_CONFIG.use_mp: error_dict = {} return_queue = mp.Queue() @@ -967,9 +967,7 @@ def _load_corpus_from_source_mp(self) -> None: p.start() last_poll = time.time() - 30 try: - with self.session() as session, tqdm.tqdm( - total=100, disable=GLOBAL_CONFIG.quiet - ) as pbar: + with self.session() as session, tqdm(total=100, disable=GLOBAL_CONFIG.quiet) as pbar: import_data = DatabaseImportData() while True: try: diff --git a/montreal_forced_aligner/corpus/base.py b/montreal_forced_aligner/corpus/base.py index e4281d94..e778b2bb 100644 --- a/montreal_forced_aligner/corpus/base.py +++ b/montreal_forced_aligner/corpus/base.py @@ -9,8 +9,8 @@ from pathlib import Path import sqlalchemy.engine -import tqdm from sqlalchemy.orm import Session, joinedload, selectinload, subqueryload +from tqdm.rich import tqdm from montreal_forced_aligner.abc import DatabaseMixin, MfaWorker from montreal_forced_aligner.config import GLOBAL_CONFIG @@ -302,7 +302,7 @@ def _write_spk2utt(self) -> None: def create_corpus_split(self) -> None: """Create split directory and output information from Jobs""" os.makedirs(self.split_directory.joinpath("log"), exist_ok=True) - with self.session() as session, tqdm.tqdm( + with self.session() as session, tqdm( total=self.num_utterances, disable=GLOBAL_CONFIG.quiet ) as pbar: jobs = session.query(Job) @@ -662,7 +662,7 @@ def normalize_text(self) -> None: update_mapping = [] word_key = self.get_next_primary_key(Word) pronunciation_key = self.get_next_primary_key(Pronunciation) - with tqdm.tqdm( + with tqdm( total=self.num_utterances, disable=GLOBAL_CONFIG.quiet ) as pbar, self.session() as session: dictionaries: typing.Dict[int, Dictionary] = { @@ -1169,7 +1169,7 @@ def create_subset(self, subset: int) -> None: session.commit() logger.debug(f"Setting subset flags took {time.time()-begin} seconds") - with self.session() as session, tqdm.tqdm( + with self.session() as session, tqdm( total=subset, disable=GLOBAL_CONFIG.quiet ) as pbar: jobs = ( diff --git a/montreal_forced_aligner/corpus/ivector_corpus.py b/montreal_forced_aligner/corpus/ivector_corpus.py index 125bed20..fe68be88 100644 --- a/montreal_forced_aligner/corpus/ivector_corpus.py +++ b/montreal_forced_aligner/corpus/ivector_corpus.py @@ -10,7 +10,7 @@ import numpy as np import sqlalchemy -import tqdm +from tqdm.rich import tqdm from montreal_forced_aligner.config import GLOBAL_CONFIG, IVECTOR_DIMENSION from montreal_forced_aligner.corpus.acoustic_corpus import AcousticCorpusMixin @@ -187,7 +187,7 @@ def compute_plda(self) -> None: if self.stopped.stop_check(): logger.debug("PLDA computation stopped early.") return - with tqdm.tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar, mfa_open( + with 
tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar, mfa_open( log_path, "w" ) as log_file: @@ -262,7 +262,7 @@ def extract_ivectors(self) -> None: return logger.info("Extracting ivectors...") arguments = self.extract_ivectors_arguments() - with tqdm.tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: for _ in run_kaldi_function(ExtractIvectorsFunction, arguments, pbar.update): pass self.collect_utterance_ivectors() @@ -314,7 +314,7 @@ def collect_utterance_ivectors(self) -> None: for line in f: scp_line = line.strip().split(maxsplit=1) ivector_arks[int(scp_line[0].split("-")[-1])] = scp_line[-1] - with self.session() as session, tqdm.tqdm( + with self.session() as session, tqdm( total=self.num_utterances, disable=GLOBAL_CONFIG.quiet ) as pbar: update_mapping = {} @@ -382,7 +382,7 @@ def collect_speaker_ivectors(self) -> None: num_utts_path = self.working_directory.joinpath("current_num_utts.ark") if not os.path.exists(speaker_ivector_ark_path): self.compute_speaker_ivectors() - with self.session() as session, tqdm.tqdm( + with self.session() as session, tqdm( total=self.num_speakers, disable=GLOBAL_CONFIG.quiet ) as pbar: utterance_counts = {} diff --git a/montreal_forced_aligner/corpus/text_corpus.py b/montreal_forced_aligner/corpus/text_corpus.py index e51c91f1..8fc8599b 100644 --- a/montreal_forced_aligner/corpus/text_corpus.py +++ b/montreal_forced_aligner/corpus/text_corpus.py @@ -9,7 +9,7 @@ from pathlib import Path from queue import Empty -import tqdm +from tqdm.rich import tqdm from montreal_forced_aligner.abc import MfaWorker, TemporaryDirectoryMixin from montreal_forced_aligner.config import GLOBAL_CONFIG @@ -65,9 +65,7 @@ def _load_corpus_from_source_mp(self) -> None: import_data = DatabaseImportData() try: file_count = 0 - with tqdm.tqdm( - total=1, disable=GLOBAL_CONFIG.quiet - ) as pbar, self.session() as session: + with tqdm(total=1, disable=GLOBAL_CONFIG.quiet) as pbar, self.session() as session: for root, _, files in os.walk(self.corpus_directory, followlinks=True): exts = find_exts(files) relative_path = ( diff --git a/montreal_forced_aligner/diarization/speaker_diarizer.py b/montreal_forced_aligner/diarization/speaker_diarizer.py index 4003a9da..8bd5ab2a 100644 --- a/montreal_forced_aligner/diarization/speaker_diarizer.py +++ b/montreal_forced_aligner/diarization/speaker_diarizer.py @@ -20,10 +20,10 @@ import numpy as np import sqlalchemy -import tqdm import yaml from sklearn import decomposition, metrics from sqlalchemy.orm import joinedload, selectinload +from tqdm.rich import tqdm from montreal_forced_aligner.abc import FileExporterMixin, TopLevelMfaWorker from montreal_forced_aligner.alignment.multiprocessing import construct_output_path @@ -323,7 +323,7 @@ def classify_speakers(self): self.setup() logger.info("Classifying utterances...") - with self.session() as session, tqdm.tqdm( + with self.session() as session, tqdm( total=self.num_utterances, disable=GLOBAL_CONFIG.quiet ) as pbar, mfa_open( self.working_directory.joinpath("speaker_classification_results.csv"), "w" @@ -633,7 +633,7 @@ def visualize_clusters(self, ivectors, cluster_labels=None): def export_xvectors(self): logger.info("Exporting SpeechBrain embeddings...") os.makedirs(self.split_directory, exist_ok=True) - with tqdm.tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: arguments = [ 
ExportIvectorsArguments( j.id, @@ -703,7 +703,7 @@ def initialize_mfa_clustering(self): logger.info("Generating initial speaker labels...") utt2spk = {k: v for k, v in session.query(Utterance.id, Utterance.speaker_id)} - with tqdm.tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: for utt_id, classified_speaker, score in run_kaldi_function( func, arguments, pbar.update ): @@ -753,7 +753,7 @@ def initialize_mfa_clustering(self): def export_speaker_ivectors(self): logger.info("Exporting current speaker ivectors...") - with self.session() as session, tqdm.tqdm( + with self.session() as session, tqdm( total=self.num_speakers, disable=GLOBAL_CONFIG.quiet ) as pbar, mfa_open(self.num_utts_path, "w") as f: if self.use_xvector: @@ -806,7 +806,7 @@ def classify_iteration(self, iteration=None) -> None: self.max_iterations, )[iteration] logger.debug(f"Score threshold: {score_threshold}") - with self.session() as session, tqdm.tqdm( + with self.session() as session, tqdm( total=self.num_utterances, disable=GLOBAL_CONFIG.quiet ) as pbar: @@ -876,9 +876,7 @@ def breakup_large_clusters(self): logger.info("Breaking up large speakers...") logger.debug(f"Unknown speaker is {unknown_speaker_id}") next_speaker_id = self.get_next_primary_key(Speaker) - with tqdm.tqdm( - total=len(above_threshold_speakers), disable=GLOBAL_CONFIG.quiet - ) as pbar: + with tqdm(total=len(above_threshold_speakers), disable=GLOBAL_CONFIG.quiet) as pbar: utterance_mapping = [] new_speakers = {} for s_id in above_threshold_speakers: @@ -1262,7 +1260,7 @@ def calculate_eer(self) -> typing.Tuple[float, float]: limit_per_speaker = 5 limit_within_speaker = 30 begin = time.time() - with tqdm.tqdm(total=self.num_speakers, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_speakers, disable=GLOBAL_CONFIG.quiet) as pbar: arguments = [ ComputeEerArguments( j.id, @@ -1309,7 +1307,7 @@ def load_embeddings(self) -> None: logger.info("Embeddings already loaded.") return logger.info("Loading SpeechBrain embeddings...") - with tqdm.tqdm( + with tqdm( total=self.num_utterances, disable=GLOBAL_CONFIG.quiet ) as pbar, self.session() as session: begin = time.time() @@ -1361,7 +1359,7 @@ def load_embeddings(self) -> None: def refresh_plda_vectors(self): logger.info("Refreshing PLDA vectors...") self.plda = PldaModel.load(self.plda_path) - with self.session() as session, tqdm.tqdm( + with self.session() as session, tqdm( total=self.num_utterances, disable=GLOBAL_CONFIG.quiet ) as pbar: if self.use_xvector: @@ -1391,7 +1389,7 @@ def refresh_plda_vectors(self): def refresh_speaker_vectors(self) -> None: """Refresh speaker vectors following clustering or classification""" logger.info("Refreshing speaker vectors...") - with self.session() as session, tqdm.tqdm( + with self.session() as session, tqdm( total=self.num_speakers, disable=GLOBAL_CONFIG.quiet ) as pbar: if self.use_xvector: @@ -1431,7 +1429,7 @@ def compute_speaker_embeddings(self) -> None: if not self.has_xvectors(): self.load_embeddings() logger.info("Computing SpeechBrain speaker embeddings...") - with tqdm.tqdm( + with tqdm( total=self.num_speakers, disable=GLOBAL_CONFIG.quiet ) as pbar, self.session() as session: update_mapping = [] @@ -1500,7 +1498,7 @@ def export_files(self, output_directory: str) -> None: joinedload(File.sound_file, innerjoin=True).load_only(SoundFile.duration), joinedload(File.text_file, innerjoin=True).load_only(TextFile.file_type), ) - with 
tqdm.tqdm(total=self.num_files, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_files, disable=GLOBAL_CONFIG.quiet) as pbar: for file in files: utterance_count = len(file.utterances) diff --git a/montreal_forced_aligner/dictionary/multispeaker.py b/montreal_forced_aligner/dictionary/multispeaker.py index aa80ad4c..d75bd634 100644 --- a/montreal_forced_aligner/dictionary/multispeaker.py +++ b/montreal_forced_aligner/dictionary/multispeaker.py @@ -16,9 +16,9 @@ import pynini import pywrapfst import sqlalchemy.orm.session -import tqdm import yaml from sqlalchemy.orm import selectinload +from tqdm.rich import tqdm from montreal_forced_aligner.config import GLOBAL_CONFIG from montreal_forced_aligner.data import PhoneType, WordType @@ -620,7 +620,7 @@ def apply_phonological_rules(self) -> None: with self.session() as session: num_words = session.query(Word).count() logger.info("Applying phonological rules...") - with tqdm.tqdm(total=num_words, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=num_words, disable=GLOBAL_CONFIG.quiet) as pbar: new_pron_objs = [] rule_application_objs = [] dialect_ids = {d.name: d.id for d in session.query(Dialect).all()} diff --git a/montreal_forced_aligner/exceptions.py b/montreal_forced_aligner/exceptions.py index e2affd2b..ebb7748c 100644 --- a/montreal_forced_aligner/exceptions.py +++ b/montreal_forced_aligner/exceptions.py @@ -15,7 +15,7 @@ import requests.structures -from montreal_forced_aligner.helper import TerminalPrinter, comma_join +from montreal_forced_aligner.helper import comma_join if TYPE_CHECKING: from montreal_forced_aligner.dictionary.mixins import DictionaryMixin @@ -65,20 +65,17 @@ class MFAError(Exception): """ def __init__(self, base_error_message: str, *args, **kwargs): - self.printer = TerminalPrinter() self.message_lines: List[str] = [base_error_message] @property def message(self) -> str: """Formatted exception message""" - return "\n".join(self.printer.format_info_lines(self.message_lines)) + return "\n".join(self.message_lines) def __str__(self) -> str: """Output the error""" - message = self.printer.error_text(type(self).__name__) + ":" - self.printer.indent_level += 1 + message = type(self).__name__ + ":" message += "\n\n" + self.message - self.printer.indent_level -= 1 return message @@ -95,12 +92,12 @@ class PlatformError(MFAError): def __init__(self, functionality_name): super().__init__("") self.message_lines = [ - f"Functionality for {self.printer.emphasized_text(functionality_name)} is not available on {self.printer.error_text(sys.platform)}." + f"Functionality for {functionality_name} is not available on {sys.platform}." ] if sys.platform == "win32": self.message_lines.append("") self.message_lines.append( - f" If you'd like to use {self.printer.emphasized_text(functionality_name)} on Windows, please follow the MFA installation " + f" If you'd like to use {functionality_name} on Windows, please follow the MFA installation " f"instructions for the Windows Subsystem for Linux (WSL)." 
) @@ -129,39 +126,37 @@ def __init__( super().__init__("") if error_text: self.message_lines = [ - f"There was an error when invoking '{self.printer.error_text(binary_name)}':", + f"There was an error when invoking '{binary_name}':", error_text, "This likely indicates that MFA's dependencies were not correctly installed, or there is an issue with your Conda environment.", "If you are in the correct environment, please try re-creating the environment from scratch as a first step, i.e.:", - self.printer.pass_text( - "conda create -n aligner -c conda-forge montreal-forced-aligner" - ), + "conda create -n aligner -c conda-forge montreal-forced-aligner", ] else: - self.message_lines = [f"Could not find '{self.printer.error_text(binary_name)}'."] + self.message_lines = [f"Could not find '{binary_name}'."] self.message_lines.append( "Please ensure that you have installed MFA's conda dependencies and are in the correct environment." ) if open_fst: self.message_lines.append( - f"Please ensure that you are in an environment that has the {self.printer.emphasized_text('openfst')} conda package installed, " - f"or that the {self.printer.emphasized_text('openfst')} binaries are on your path if you compiled them yourself." + f"Please ensure that you are in an environment that has the {'openfst'} conda package installed, " + f"or that the {'openfst'} binaries are on your path if you compiled them yourself." ) elif open_blas: self.message_lines.append( - f"Try installing {self.printer.emphasized_text('openblas')} via system package manager or verify it's on your system path?" + f"Try installing {'openblas'} via system package manager or verify it's on your system path?" ) elif libc: self.message_lines.append( - f"You likely have a different version of {self.printer.emphasized_text('glibc')} than the packages binaries use. " - f"Try compiling {self.printer.emphasized_text('Kaldi')} on your machine and collecting the binaries via the " - f"{self.printer.pass_text('mfa thirdparty kaldi')} command." + f"You likely have a different version of {'glibc'} than the packages binaries use. " + f"Try compiling {'Kaldi'} on your machine and collecting the binaries via the " + f"{'mfa thirdparty kaldi'} command." ) elif sox: self.message_lines = [] self.message_lines.append( - f"Your version of {self.printer.emphasized_text('sox')} does not support the file format in your corpus. " - f"Try installing another version of {self.printer.emphasized_text('sox')} with support for {self.printer.error_text(binary_name)}." + f"Your version of {'sox'} does not support the file format in your corpus. " + f"Try installing another version of {'sox'} with support for {binary_name}." ) @@ -199,9 +194,7 @@ class ModelLoadError(ModelError): def __init__(self, path: typing.Union[str, Path]): super().__init__("") - self.message_lines = [ - f"The archive {self.printer.error_text(path)} could not be parsed as an MFA model." 
- ] + self.message_lines = [f"The archive {path} could not be parsed as an MFA model."] class ModelSaveError(ModelError): @@ -217,7 +210,7 @@ class ModelSaveError(ModelError): def __init__(self, path: Path): super().__init__("") self.message_lines = [ - f"The archive {self.printer.error_text(path)} already exists.", + f"The archive {path} already exists.", "Please specify --overwrite if you would like to overwrite it.", ] @@ -247,9 +240,9 @@ def __init__( rate_limit = headers["x-ratelimit-limit"] rate_limit_reset = datetime.datetime.fromtimestamp(int(headers["x-ratelimit-reset"])) self.message_lines = [ - f"Current hourly rate limit ({self.printer.error_text(rate_limit)} per hour) has been exceeded for the GitHub API.", + f"Current hourly rate limit ({rate_limit} per hour) has been exceeded for the GitHub API.", "You can increase it by providing a personal authentication token to via --github_token.", - f"The rate limit will reset at {self.printer.pass_text(rate_limit_reset)}", + f"The rate limit will reset at {rate_limit_reset}", ] else: self.message_lines = [ @@ -280,7 +273,7 @@ class PhoneMismatchError(DictionaryError): def __init__(self, missing_phones: Collection[str]): super().__init__("There were extra phones that were not in the dictionary: ") - missing_phones = [f"{self.printer.error_text(x)}" for x in sorted(missing_phones)] + missing_phones = [f"{x}" for x in sorted(missing_phones)] self.message_lines.append(comma_join(missing_phones)) @@ -291,7 +284,7 @@ class NoDefaultSpeakerDictionaryError(DictionaryError): def __init__(self): super().__init__("") - self.message_lines = [f'No "{self.printer.error_text("default")}" dictionary was found.'] + self.message_lines = [f'No "{"default"}" dictionary was found.'] class DictionaryPathError(DictionaryError): @@ -307,7 +300,7 @@ class DictionaryPathError(DictionaryError): def __init__(self, input_path: Path): super().__init__("") self.message_lines = [ - f"The specified path for the dictionary ({self.printer.error_text(input_path)}) was not found." + f"The specified path for the dictionary ({input_path}) was not found." ] @@ -324,7 +317,7 @@ class DictionaryFileError(DictionaryError): def __init__(self, input_path: Path): super().__init__("") self.message_lines = [ - f"The specified path for the dictionary ({self.printer.error_text(input_path)}) is not a file." + f"The specified path for the dictionary ({input_path}) is not a file." ] @@ -351,7 +344,7 @@ class CorpusReadError(CorpusError): def __init__(self, file_name: str): super().__init__("") - self.message_lines = [f"There was an error reading {self.printer.error_text(file_name)}."] + self.message_lines = [f"There was an error reading {file_name}."] class TextParseError(CorpusReadError): @@ -367,8 +360,7 @@ class TextParseError(CorpusReadError): def __init__(self, file_name: str): super().__init__("") self.message_lines = [ - f"There was an error decoding {self.printer.error_text(file_name)}, " - f"maybe try resaving it as utf8?" + f"There was an error decoding {file_name}, " f"maybe try resaving it as utf8?" 
] @@ -390,7 +382,7 @@ def __init__(self, file_name: str, error: str): self.error = error self.message_lines.extend( [ - f"Reading {self.printer.emphasized_text(file_name)} has the following error:", + f"Reading {file_name} has the following error:", "", "", self.error, @@ -424,7 +416,7 @@ def __init__(self, file_name: str, error: str): self.error = error self.message_lines.extend( [ - f"Reading {self.printer.emphasized_text(file_name)} has the following error:", + f"Reading {file_name} has the following error:", "", "", self.error, @@ -478,13 +470,13 @@ class AlignmentError(MFAError): def __init__(self, error_logs: List[str]): super().__init__("") self.message_lines = [ - f"There were {self.printer.error_text(len(error_logs))} job(s) with errors. " + f"There were {len(error_logs)} job(s) with errors. " f"For more information, please see:", "", "", ] for path in error_logs: - self.message_lines.append(self.printer.error_text(path)) + self.message_lines.append(path) class AlignmentExportError(AlignmentError): @@ -542,7 +534,7 @@ class PronunciationAcousticMismatchError(AlignerError): def __init__(self, missing_phones: Collection[str]): super().__init__("There were phones in the dictionary that do not have acoustic models: ") - missing_phones = [f"{self.printer.error_text(x)}" for x in sorted(missing_phones)] + missing_phones = [f"{x}" for x in sorted(missing_phones)] self.message_lines.append(comma_join(missing_phones)) @@ -563,7 +555,7 @@ def __init__(self, g2p_model: G2PModel, dictionary: DictionaryMixin): "There were graphemes in the corpus that are not covered by the G2P model:" ) missing_graphs = dictionary.graphemes - set(g2p_model.meta["graphemes"]) - missing_graphs = [f"{self.printer.error_text(x)}" for x in sorted(missing_graphs)] + missing_graphs = [f"{x}" for x in sorted(missing_graphs)] self.message_lines.append(comma_join(missing_graphs)) @@ -590,7 +582,7 @@ class FileArgumentNotFoundError(ArgumentError): def __init__(self, path: Path): super().__init__("") - self.message_lines = [f'Could not find "{self.printer.error_text(path)}".'] + self.message_lines = [f'Could not find "{path}".'] class PretrainedModelNotFoundError(ArgumentError): @@ -614,11 +606,9 @@ def __init__( extra = "" if model_type: extra += f" for {model_type}" - self.message_lines = [ - f'Could not find a model named "{self.printer.error_text(name)}"{extra}.' - ] + self.message_lines = [f'Could not find a model named "{name}"{extra}.'] if available: - available = [f"{self.printer.pass_text(x)}" for x in available] + available = [f"{x}" for x in available] self.message_lines.append(f"Available: {comma_join(available)}.") @@ -643,11 +633,9 @@ def __init__( extra = "" if model_type: extra += f" for {model_type}" - self.message_lines = [ - f'Could not find a model named "{self.printer.error_text(name)}"{extra}.' - ] + self.message_lines = [f'Could not find a model named "{name}"{extra}.'] if available: - available = [f"{self.printer.pass_text(x)}" for x in available] + available = [f"{x}" for x in available] self.message_lines.append(f"Available: {comma_join(available)}.") self.message_lines.append( "You can see all available models either on https://mfa-models.readthedocs.io/en/latest/ or https://github.com/MontrealCorpusTools/mfa-models/releases." 
@@ -672,8 +660,8 @@ class MultipleModelTypesFoundError(ArgumentError): def __init__(self, name: str, possible_model_types: List[str]): super().__init__("") - self.message_lines = [f'Found multiple model types for "{self.printer.error_text(name)}":'] - possible_model_types = [f"{self.printer.error_text(x)}" for x in possible_model_types] + self.message_lines = [f'Found multiple model types for "{name}":'] + possible_model_types = [f"{x}" for x in possible_model_types] self.message_lines.extend( [", ".join(possible_model_types), "Please specify a model type to inspect."] ) @@ -698,12 +686,10 @@ def __init__(self, name: str, model_type: str, extensions: List[str]): extra = "" if model_type: extra += f" for {model_type}" - self.message_lines = [ - f'The path "{self.printer.error_text(name)}" does not have the correct extensions{extra}.' - ] + self.message_lines = [f'The path "{name}" does not have the correct extensions{extra}.'] if extensions: - available = [f"{self.printer.pass_text(x)}" for x in extensions] + available = [f"{x}" for x in extensions] self.message_lines.append(f" Possible extensions: {comma_join(available)}.") @@ -721,11 +707,9 @@ class ModelTypeNotSupportedError(ArgumentError): def __init__(self, model_type, model_types): super().__init__("") - self.message_lines = [ - f'The model type "{self.printer.error_text(model_type)}" is not supported.' - ] + self.message_lines = [f'The model type "{model_type}" is not supported.'] if model_types: - model_types = [f"{self.printer.pass_text(x)}" for x in sorted(model_types)] + model_types = [f"{x}" for x in sorted(model_types)] self.message_lines.append(f" Possible model types: {comma_join(model_types)}.") @@ -745,8 +729,8 @@ class RootDirectoryError(ConfigError): def __init__(self, temporary_directory, variable): super().__init__("") self.message_lines = [ - f"Could not create a root MFA temporary directory (tried {self.printer.error_text(temporary_directory)}. ", - f"Please specify a write-able directory via the {self.printer.emphasized_text(variable)} environment variable.", + f"Could not create a root MFA temporary directory (tried {temporary_directory}. 
", + f"Please specify a write-able directory via the {variable} environment variable.", ] @@ -775,10 +759,8 @@ def __init__(self, error_dict: Dict[str, Exception]): super().__init__("The following Pynini alignment jobs encountered errors:") self.message_lines.extend(["", ""]) for k, v in error_dict.items(): - self.message_lines.append(self.printer.indent_string + self.printer.error_text(k)) - self.message_lines.append( - self.printer.indent_string + self.printer.emphasized_text(str(v)) - ) + self.message_lines.append(k) + self.message_lines.append(str(v)) class PyniniGenerationError(G2PError): @@ -790,10 +772,8 @@ def __init__(self, error_dict: Dict[str, Exception]): super().__init__("The following words had errors in running G2P:") self.message_lines.extend(["", ""]) for k, v in error_dict.items(): - self.message_lines.append(self.printer.indent_string + self.printer.error_text(k)) - self.message_lines.append( - self.printer.indent_string + self.printer.emphasized_text(str(v)) - ) + self.message_lines.append(k) + self.message_lines.append(str(v)) class PhonetisaurusSymbolError(G2PError): @@ -845,7 +825,7 @@ class MultiprocessingError(MFAError): def __init__(self, job_name: int, error_text: str): super().__init__(f"Job {job_name} encountered an error:") - self.message_lines = [f"Job {self.printer.error_text(job_name)} encountered an error:"] + self.message_lines = [f"Job {job_name} encountered an error:"] self.job_name = job_name self.message_lines.extend( [self.highlight_line(x) for x in error_text.splitlines(keepends=False)] @@ -865,8 +845,8 @@ def highlight_line(self, line: str) -> str: str Highlighted line """ - emph_replacement = self.printer.emphasized_text(r"\1") - err_replacement = self.printer.error_text(r"\1") + emph_replacement = r"\1" + err_replacement = r"\1" line = re.sub(r"File \"(.*)\"", f'File "{emph_replacement}"', line) line = re.sub(r"line (\d+)", f"line {err_replacement}", line) return line @@ -908,9 +888,7 @@ def refresh_message(self) -> None: for line in f: self.message_lines.append(line.strip()) if self.log_file: - self.message_lines.append( - f" For more details, please check {self.printer.error_text(self.log_file)}" - ) + self.message_lines.append(f" For more details, please check {self.log_file}") def append_error_log(self, error_log: str) -> None: """ diff --git a/montreal_forced_aligner/g2p/generator.py b/montreal_forced_aligner/g2p/generator.py index 64d1efb9..f36e6329 100644 --- a/montreal_forced_aligner/g2p/generator.py +++ b/montreal_forced_aligner/g2p/generator.py @@ -13,10 +13,10 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union import pynini -import tqdm from pynini import Fst, TokenType from pynini.lib import rewrite from pywrapfst import SymbolTable +from tqdm.rich import tqdm from montreal_forced_aligner.abc import DatabaseMixin, TopLevelMfaWorker from montreal_forced_aligner.config import GLOBAL_CONFIG @@ -412,7 +412,7 @@ def generate_pronunciations(self) -> Dict[str, List[str]]: to_return = {} skipped_words = 0 if num_words < 30 or GLOBAL_CONFIG.num_jobs == 1: - with tqdm.tqdm(total=num_words, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=num_words, disable=GLOBAL_CONFIG.quiet) as pbar: for word in self.words_to_g2p: w, m = clean_up_word(word, self.g2p_model.meta["graphemes"]) pbar.update(1) @@ -462,7 +462,7 @@ def generate_pronunciations(self) -> Dict[str, List[str]]: procs.append(p) p.start() num_words -= skipped_words - with tqdm.tqdm(total=num_words, disable=GLOBAL_CONFIG.quiet) as pbar: + with 
tqdm(total=num_words, disable=GLOBAL_CONFIG.quiet) as pbar: while True: try: word, result = return_queue.get(timeout=1) diff --git a/montreal_forced_aligner/g2p/phonetisaurus_trainer.py b/montreal_forced_aligner/g2p/phonetisaurus_trainer.py index 94c79293..a041bd87 100644 --- a/montreal_forced_aligner/g2p/phonetisaurus_trainer.py +++ b/montreal_forced_aligner/g2p/phonetisaurus_trainer.py @@ -14,9 +14,9 @@ import pynini import pywrapfst import sqlalchemy -import tqdm from pynini.lib import rewrite from sqlalchemy.orm import scoped_session, sessionmaker +from tqdm.rich import tqdm from montreal_forced_aligner.abc import MetaDict, TopLevelMfaWorker from montreal_forced_aligner.config import GLOBAL_CONFIG @@ -813,7 +813,7 @@ def initialize_alignments(self) -> None: symbols = {} job_symbols = {} symbol_id = 1 - with tqdm.tqdm( + with tqdm( total=self.g2p_num_training_pronunciations, disable=GLOBAL_CONFIG.quiet ) as pbar, self.session(autoflush=False, autocommit=False) as session: while True: @@ -921,9 +921,7 @@ def maximization(self, last_iteration=False) -> float: procs[-1].start() error_list = [] - with tqdm.tqdm( - total=self.g2p_num_training_pronunciations, disable=GLOBAL_CONFIG.quiet - ) as pbar: + with tqdm(total=self.g2p_num_training_pronunciations, disable=GLOBAL_CONFIG.quiet) as pbar: while True: try: result = return_queue.get(timeout=1) @@ -974,9 +972,7 @@ def expectation(self) -> None: procs[-1].start() mappings = {} zero = pynini.Weight.zero("log") - with tqdm.tqdm( - total=self.g2p_num_training_pronunciations, disable=GLOBAL_CONFIG.quiet - ) as pbar: + with tqdm(total=self.g2p_num_training_pronunciations, disable=GLOBAL_CONFIG.quiet) as pbar: while True: try: result = return_queue.get(timeout=1) @@ -1037,9 +1033,7 @@ def train_ngram_model(self) -> None: count_paths.append(args.far_path.with_suffix(".cnts")) procs[-1].start() - with tqdm.tqdm( - total=self.g2p_num_training_pronunciations, disable=GLOBAL_CONFIG.quiet - ) as pbar: + with tqdm(total=self.g2p_num_training_pronunciations, disable=GLOBAL_CONFIG.quiet) as pbar: while True: try: result = return_queue.get(timeout=1) @@ -1317,9 +1311,7 @@ def export_alignments(self) -> None: count_paths.append(args.far_path.with_suffix(".cnts")) procs[-1].start() - with tqdm.tqdm( - total=self.g2p_num_training_pronunciations, disable=GLOBAL_CONFIG.quiet - ) as pbar: + with tqdm(total=self.g2p_num_training_pronunciations, disable=GLOBAL_CONFIG.quiet) as pbar: while True: try: result = return_queue.get(timeout=1) diff --git a/montreal_forced_aligner/g2p/trainer.py b/montreal_forced_aligner/g2p/trainer.py index 6a4e9b8c..16e5765c 100644 --- a/montreal_forced_aligner/g2p/trainer.py +++ b/montreal_forced_aligner/g2p/trainer.py @@ -18,8 +18,8 @@ import pynini import pywrapfst -import tqdm from pynini import Fst +from tqdm.rich import tqdm from montreal_forced_aligner.abc import MetaDict, MfaWorker, TopLevelMfaWorker, TrainerMixin from montreal_forced_aligner.config import GLOBAL_CONFIG @@ -538,7 +538,7 @@ def _alignments(self) -> None: # Actually runs starts. 
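# --- Illustrative sketch, not part of the patch: the hunks above and below swap
# `import tqdm` / `tqdm.tqdm(...)` for the rich-backed `tqdm.rich.tqdm`, which keeps
# the same keyword arguments, so only the import changes at each call site. The
# `items` iterable and `quiet` flag here are invented for illustration.
from tqdm.rich import tqdm  # experimental rich front-end bundled with tqdm

def consume(items, quiet: bool = False) -> None:
    # Same pattern as the call sites in this diff: `total` sizes the bar, `disable` silences it.
    with tqdm(total=len(items), disable=quiet) as pbar:
        for _ in items:
            pbar.update(1)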
logger.info("Calculating alignments...") begin = time.time() - with tqdm.tqdm( + with tqdm( total=num_commands * self.num_iterations, disable=GLOBAL_CONFIG.quiet ) as pbar: for start in starts: diff --git a/montreal_forced_aligner/helper.py b/montreal_forced_aligner/helper.py index 54f96c1c..c0d3fb5e 100644 --- a/montreal_forced_aligner/helper.py +++ b/montreal_forced_aligner/helper.py @@ -10,19 +10,18 @@ import json import logging import re -import shutil -import sys import typing from contextlib import contextmanager from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type -import ansiwrap import dataclassy import numpy import yaml from Bio import pairwise2 -from colorama import Fore, Style +from rich.console import Console +from rich.logging import RichHandler +from rich.theme import Theme if TYPE_CHECKING: from montreal_forced_aligner.abc import MetaDict @@ -30,7 +29,6 @@ __all__ = [ - "TerminalPrinter", "comma_join", "make_safe", "make_scp_safe", @@ -44,13 +42,24 @@ "overlap_scoring", "align_phones", "split_phone_position", - "CustomFormatter", "configure_logger", "mfa_open", "load_configuration", ] +console = Console( + theme=Theme( + { + "logging.level.debug": "cyan", + "logging.level.info": "green", + "logging.level.warning": "yellow", + "logging.level.error": "red", + } + ) +) + + @contextmanager def mfa_open(path, mode="r", encoding="utf8", newline=""): if "r" in mode: @@ -185,422 +194,19 @@ def configure_logger(identifier: str, log_file: Optional[Path] = None) -> None: file_handler.setFormatter(formatter) logger.addHandler(file_handler) elif not config.current_profile.quiet: - handler = logging.StreamHandler(sys.stdout) + handler = RichHandler( + rich_tracebacks=True, log_time_format="", console=console, show_path=False + ) if config.current_profile.verbose: handler.setLevel(logging.DEBUG) logging.getLogger("sqlalchemy.engine").setLevel(logging.DEBUG) logging.getLogger("sqlalchemy.pool").setLevel(logging.DEBUG) else: handler.setLevel(logging.INFO) - handler.setFormatter(CustomFormatter()) + handler.setFormatter(logging.Formatter("%(message)s")) logger.addHandler(handler) -class CustomFormatter(logging.Formatter): - """ - Custom log formatter class for MFA to highlight messages and incorporate terminal options from - the global configuration - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - from montreal_forced_aligner.config import GLOBAL_CONFIG - - use_colors = GLOBAL_CONFIG.terminal_colors - red = "" - green = "" - yellow = "" - blue = "" - reset = "" - if use_colors: - red = Fore.RED - green = Fore.GREEN - yellow = Fore.YELLOW - blue = Fore.CYAN - reset = Style.RESET_ALL - - self.FORMATS = { - logging.DEBUG: (f"{blue}DEBUG{reset} - ", "%(message)s"), - logging.INFO: (f"{green}INFO{reset} - ", "%(message)s"), - logging.WARNING: (f"{yellow}WARNING{reset} - ", "%(message)s"), - logging.ERROR: (f"{red}ERROR{reset} - ", "%(message)s"), - logging.CRITICAL: (f"{red}CRITICAL{reset} - ", "%(message)s"), - } - - def format(self, record: logging.LogRecord): - """ - Format a given log message - - Parameters - ---------- - record: logging.LogRecord - Log record to format - - Returns - ------- - str - Formatted log message - """ - log_fmt = self.FORMATS.get(record.levelno) - return ansiwrap.fill( - record.getMessage(), - initial_indent=log_fmt[0], - subsequent_indent=" " * len(log_fmt[0]), - width=shutil.get_terminal_size().columns, - 
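# --- Illustrative sketch, not part of the patch: how the pieces added to helper.py
# above fit together. A module-level Console carries the level colours via a Theme,
# RichHandler draws the level prefix and handles wrapping, so the handler only needs
# a bare "%(message)s" formatter. The logger name and message here are invented.
import logging

from rich.console import Console
from rich.logging import RichHandler
from rich.theme import Theme

console = Console(theme=Theme({"logging.level.info": "green", "logging.level.error": "red"}))
handler = RichHandler(rich_tracebacks=True, log_time_format="", console=console, show_path=False)
handler.setFormatter(logging.Formatter("%(message)s"))

logger = logging.getLogger("demo")
logger.setLevel(logging.INFO)
logger.addHandler(handler)
logger.info("rich renders the level prefix, colour, and wrapping")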
) - - -class TerminalPrinter: - """ - Helper class to output colorized text - - Parameters - ---------- - print_function: Callable, optional - Function to print information, defaults to :func:`print` - - Attributes - ---------- - colors: dict[str, str] - Mapping of color names to terminal codes in colorama (or empty strings - if the global terminal_colors flag is set to False) - """ - - def __init__(self, print_function: typing.Callable = None): - if print_function is not None: - self.print_function = print_function - else: - self.print_function = print - from montreal_forced_aligner.config import GLOBAL_CONFIG - - self.colors = {} - self.colors["bright"] = "" - self.colors["green"] = "" - self.colors["red"] = "" - self.colors["blue"] = "" - self.colors["cyan"] = "" - self.colors["yellow"] = "" - self.colors["reset"] = "" - self.colors["normal"] = "" - self.indent_level = 0 - self.indent_size = 2 - if GLOBAL_CONFIG.terminal_colors: - self.colors["bright"] = Style.BRIGHT - self.colors["green"] = Fore.GREEN - self.colors["red"] = Fore.RED - self.colors["blue"] = Fore.BLUE - self.colors["cyan"] = Fore.CYAN - self.colors["yellow"] = Fore.YELLOW - self.colors["reset"] = Style.RESET_ALL - self.colors["normal"] = Style.NORMAL - - def error_text(self, text: Any) -> str: - """ - Highlight text as an error - - Parameters - ---------- - text: Any - Text to highlight - - Returns - ------- - str - Highlighted text - """ - return self.colorize(str(text), "red") - - def emphasized_text(self, text: Any) -> str: - """ - Highlight text as emphasis - - Parameters - ---------- - text: Any - Text to highlight - - Returns - ------- - str - Highlighted text - """ - return self.colorize(str(text), "bright") - - def pass_text(self, text: Any) -> str: - """ - Highlight text as good - - Parameters - ---------- - text: Any - Text to highlight - - Returns - ------- - str - Highlighted text - """ - return self.colorize(str(text), "green") - - def warning_text(self, text: Any) -> str: - """ - Highlight text as a warning - - Parameters - ---------- - text: Any - Text to highlight - - Returns - ------- - str - Highlighted text - """ - return self.colorize(str(text), "yellow") - - @property - def indent_string(self) -> str: - """Indent string to use in formatting the output messages""" - return " " * self.indent_size * self.indent_level - - def print_header(self, header: str) -> None: - """ - Print a section header - - Parameters - ---------- - header: str - Section header string - """ - self.indent_level = 0 - self.print_function("") - underline = "*" * len(header) - self.print_function(self.colorize(underline, "bright")) - self.print_function(self.colorize(header, "bright")) - self.print_function(self.colorize(underline, "bright")) - self.print_function("") - self.indent_level += 1 - - def print_sub_header(self, header: str) -> None: - """ - Print a subsection header - - Parameters - ---------- - header: str - Subsection header string - """ - underline = "=" * len(header) - self.print_function(self.indent_string + self.colorize(header, "bright")) - self.print_function(self.indent_string + self.colorize(underline, "bright")) - self.print_function("") - self.indent_level += 1 - - def print_end_section(self) -> None: - """Mark the end of a section""" - self.indent_level -= 1 - self.print_function("") - - def format_info_lines(self, lines: Union[list[str], str]) -> List[str]: - """ - Format lines - - Parameters - ---------- - lines: Union[list[str], str - Lines to format - - Returns - ------- - str - Formatted string - """ - 
if isinstance(lines, str): - lines = [lines] - - for i, line in enumerate(lines): - lines[i] = ansiwrap.fill( - str(line), - initial_indent=self.indent_string, - subsequent_indent=" " * self.indent_size * (self.indent_level + 1), - width=shutil.get_terminal_size().columns, - break_on_hyphens=False, - break_long_words=False, - drop_whitespace=False, - ) - return lines - - def print_info_lines(self, lines: Union[list[str], str]) -> None: - """ - Print formatted information lines - - Parameters - ---------- - lines: Union[list[str], str - Lines to format - """ - if isinstance(lines, str): - lines = [lines] - lines = self.format_info_lines(lines) - for line in lines: - self.print_function(line) - - def print_green_stat(self, stat: Any, text: str) -> None: - """ - Print a statistic in green - - Parameters - ---------- - stat: Any - Statistic to print - text: str - Other text to follow statistic - """ - self.print_function(self.indent_string + f"{self.colorize(stat, 'green')} {text}") - - def print_yellow_stat(self, stat, text) -> None: - """ - Print a statistic in yellow - - Parameters - ---------- - stat: Any - Statistic to print - text: str - Other text to follow statistic - """ - self.print_function(self.indent_string + f"{self.colorize(stat, 'yellow')} {text}") - - def print_red_stat(self, stat, text) -> None: - """ - Print a statistic in red - - Parameters - ---------- - stat: Any - Statistic to print - text: str - Other text to follow statistic - """ - self.print_function(self.indent_string + f"{self.colorize(stat, 'red')} {text}") - - def colorize(self, text: Any, color: str) -> str: - """ - Colorize a string - - Parameters - ---------- - text: Any - Text to colorize - color: str - Colorama code or empty string to wrap the text - - Returns - ------- - str - Colorized string - """ - return f"{self.colors[color]}{text}{self.colors['reset']}" - - def print_block(self, block: dict, starting_level: int = 1) -> None: - """ - Print a configuration block - - Parameters - ---------- - block: dict - Configuration options to output - starting_level: int - Starting indentation level - """ - for k, v in block.items(): - value_color = None - key_color = None - value = "" - if isinstance(k, tuple): - k, key_color = k - - if isinstance(v, tuple): - value, value_color = v - elif not isinstance(v, dict): - value = v - self.print_information_line(k, value, key_color, value_color, starting_level) - if isinstance(v, dict): - self.print_block(v, starting_level=starting_level + 1) - self.print_function("") - - def print_config(self, configuration: MetaDict) -> None: - """ - Pretty print a configuration - - Parameters - ---------- - configuration: dict[str, Any] - Configuration to print - """ - for k, v in configuration.items(): - if "name" in v: - name = v["name"] - name_color = None - if isinstance(name, tuple): - name, name_color = name - self.print_information_line(k, name, value_color=name_color, level=0) - if "data" in v: - self.print_block(v["data"]) - - def print_information_line( - self, - key: str, - value: Any, - key_color: Optional[str] = None, - value_color: Optional[str] = None, - level: int = 1, - ) -> None: - """ - Pretty print a given configuration line - - Parameters - ---------- - key: str - Configuration key - value: Any - Configuration value - key_color: str - Key color - value_color: str - Value color - level: int - Indentation level - """ - if key_color is None: - key_color = "bright" - if value_color is None: - value_color = "cyan" - if isinstance(value, bool): - if value: - value_color 
= "green" - else: - value_color = "red" - if isinstance(value, (list, tuple, set)): - value = comma_join([self.colorize(x, value_color) for x in sorted(value)]) - else: - value = self.colorize(str(value), value_color) - indent = (" " * level) + "-" - subsequent_indent = " " * (level + 1) - if key: - key = f" {key}:" - subsequent_indent += " " * (len(key)) - - self.print_function( - ansiwrap.fill( - f"{self.colorize(key, key_color)} {value}", - width=shutil.get_terminal_size().columns, - initial_indent=indent, - subsequent_indent=subsequent_indent, - ) - ) - - def comma_join(sequence: List[Any]) -> str: """ Helper function to combine a list into a human-readable expression with commas and a diff --git a/montreal_forced_aligner/ivector/trainer.py b/montreal_forced_aligner/ivector/trainer.py index baa14bcc..ed4bab42 100644 --- a/montreal_forced_aligner/ivector/trainer.py +++ b/montreal_forced_aligner/ivector/trainer.py @@ -11,7 +11,7 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Tuple -import tqdm +from tqdm.rich import tqdm from montreal_forced_aligner.abc import MetaDict, ModelExporterMixin, TopLevelMfaWorker from montreal_forced_aligner.acoustic_modeling.base import AcousticModelTrainingMixin @@ -240,7 +240,7 @@ def gmm_gselect(self) -> None: begin = time.time() logger.info("Selecting gaussians...") arguments = self.gmm_gselect_arguments() - with tqdm.tqdm( + with tqdm( total=int(self.num_current_utterances / 10), disable=GLOBAL_CONFIG.quiet ) as pbar: for _ in run_kaldi_function(GmmGselectFunction, arguments, pbar.update): @@ -323,7 +323,7 @@ def acc_global_stats(self) -> None: logger.info("Accumulating global stats...") arguments = self.acc_global_stats_arguments() - with tqdm.tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: for _ in run_kaldi_function(AccGlobalStatsFunction, arguments, pbar.update): pass @@ -539,7 +539,7 @@ def gauss_to_post(self) -> None: logger.info("Extracting posteriors...") arguments = self.gauss_to_post_arguments() - with tqdm.tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: for _ in run_kaldi_function(GaussToPostFunction, arguments, pbar.update): pass @@ -612,7 +612,7 @@ def acc_ivector_stats(self) -> None: logger.info("Accumulating ivector stats...") arguments = self.acc_ivector_stats_arguments() - with tqdm.tqdm(total=self.worker.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.worker.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: for _ in run_kaldi_function(AccIvectorStatsFunction, arguments, pbar.update): pass @@ -719,9 +719,9 @@ def compute_lda(self): lda_path = self.working_directory.joinpath("ivector_lda.mat") log_path = self.working_log_directory.joinpath("lda.log") utt2spk_path = os.path.join(self.corpus_output_directory, "utt2spk.scp") - with tqdm.tqdm( - total=self.worker.num_utterances, disable=GLOBAL_CONFIG.quiet - ) as pbar, mfa_open(log_path, "w") as log_file: + with tqdm(total=self.worker.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar, mfa_open( + log_path, "w" + ) as log_file: normalize_proc = subprocess.Popen( [ thirdparty_binary("ivector-normalize-length"), diff --git a/montreal_forced_aligner/language_modeling/trainer.py b/montreal_forced_aligner/language_modeling/trainer.py index 07bed55a..3ba087f7 100644 --- 
a/montreal_forced_aligner/language_modeling/trainer.py +++ b/montreal_forced_aligner/language_modeling/trainer.py @@ -11,11 +11,11 @@ from queue import Empty import sqlalchemy -import tqdm +from tqdm.rich import tqdm -from montreal_forced_aligner.abc import DatabaseMixin, TopLevelMfaWorker, TrainerMixin +from montreal_forced_aligner.abc import DatabaseMixin, MfaWorker, TopLevelMfaWorker, TrainerMixin from montreal_forced_aligner.config import GLOBAL_CONFIG -from montreal_forced_aligner.corpus.text_corpus import MfaWorker, TextCorpusMixin +from montreal_forced_aligner.corpus.text_corpus import TextCorpusMixin from montreal_forced_aligner.data import WordType, WorkflowType from montreal_forced_aligner.db import Dictionary, Utterance, Word from montreal_forced_aligner.dictionary.mixins import DictionaryMixin @@ -421,7 +421,7 @@ def train_large_lm(self) -> None: procs.append(p) p.start() count_paths.append(self.working_directory.joinpath(f"{j.id}.cnts")) - with tqdm.tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: while True: try: result = return_queue.get(timeout=1) diff --git a/montreal_forced_aligner/models.py b/montreal_forced_aligner/models.py index 2d692ae1..1af3f9fa 100644 --- a/montreal_forced_aligner/models.py +++ b/montreal_forced_aligner/models.py @@ -16,6 +16,7 @@ import requests import yaml +from rich.pretty import pprint from montreal_forced_aligner.abc import MfaModel, ModelExporterMixin from montreal_forced_aligner.data import PhoneSetType @@ -26,7 +27,7 @@ PronunciationAcousticMismatchError, RemoteModelNotFoundError, ) -from montreal_forced_aligner.helper import EnhancedJSONEncoder, TerminalPrinter, mfa_open +from montreal_forced_aligner.helper import EnhancedJSONEncoder, mfa_open if TYPE_CHECKING: from dataclasses import dataclass @@ -241,11 +242,9 @@ def generate_path( def pretty_print(self) -> None: """ - Pretty print the archive's meta data using TerminalPrinter + Pretty print the archive's meta data using rich """ - printer = TerminalPrinter() - configuration_data = {"Archive": {"name": (self.name, "green"), "data": self.meta}} - printer.print_config(configuration_data) + pprint({"Archive": {"name": self.name, "data": self.meta}}) @property def meta(self) -> dict: @@ -527,17 +526,9 @@ def pretty_print(self) -> None: """ Prints the metadata information to the terminal """ - from .utils import get_mfa_version - printer = TerminalPrinter() - configuration_data = {"Acoustic model": {"name": (self.name, "green"), "data": {}}} - version_color = "green" - if self.meta["version"] != get_mfa_version(): - version_color = "red" - configuration_data["Acoustic model"]["data"]["Version"] = ( - self.meta["version"], - version_color, - ) + configuration_data = {"Acoustic model": {"name": self.name, "data": {}}} + configuration_data["Acoustic model"]["data"]["Version"] = self.meta["version"] if "citation" in self.meta: configuration_data["Acoustic model"]["data"]["Citation"] = self.meta["citation"] @@ -554,9 +545,9 @@ def pretty_print(self) -> None: if self.meta["phones"]: configuration_data["Acoustic model"]["data"]["Phones"] = self.meta["phones"] else: - configuration_data["Acoustic model"]["data"]["Phones"] = ("None found!", "red") + configuration_data["Acoustic model"]["data"]["Phones"] = "None found!" 
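# --- Illustrative sketch, not part of the patch: in the models.py hunks above,
# rich.pretty.pprint replaces TerminalPrinter.print_config and renders nested dicts
# directly. The metadata values below are invented for illustration.
from rich.pretty import pprint

configuration_data = {
    "Acoustic model": {
        "name": "example_model",
        "data": {"Version": "2.2.0", "Phones": ["a", "b", "c"]},
    }
}
pprint(configuration_data)  # colourized, indented output with no manual formatting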
- printer.print_config(configuration_data) + pprint(configuration_data) def add_model(self, source: str) -> None: """ @@ -1233,12 +1224,11 @@ def add_meta_file(self, trainer: ModelExporterMixin) -> None: def pretty_print(self) -> None: """ - Pretty print the dictionary's metadata using TerminalPrinter + Pretty print the dictionary's metadata """ from montreal_forced_aligner.dictionary.multispeaker import MultispeakerDictionary - printer = TerminalPrinter() - configuration_data = {"Dictionary": {"name": (self.name, "green"), "data": self.meta}} + configuration_data = {"Dictionary": {"name": self.name, "data": self.meta}} temp_directory = self.dirname.joinpath("temp") if temp_directory.exists(): shutil.rmtree(temp_directory) @@ -1266,7 +1256,7 @@ def pretty_print(self) -> None: configuration_data["Dictionary"]["data"]["graphemes"] = sorted(graphemes) else: configuration_data["Dictionary"]["data"]["graphemes"] = f"{len(graphemes)} graphemes" - printer.print_config(configuration_data) + pprint(configuration_data) @classmethod def valid_extension(cls, filename: Path) -> bool: @@ -1418,7 +1408,6 @@ def __init__(self, token=None): if self.token is not None: self.token = environment_token self.synced_remote = False - self.printer = TerminalPrinter() self._cache_info = {} self.refresh_local() @@ -1534,28 +1523,19 @@ def print_local_models(self, model_type: typing.Optional[str] = None) -> None: """ self.refresh_local() if model_type is None: - self.printer.print_information_line("Available local models", "", level=0) + logger.info("Available local models") + data = {} for model_type, model_class in MODEL_TYPES.items(): - names = model_class.get_available_models() - if names: - self.printer.print_information_line(model_type, names, value_color="green") - else: - self.printer.print_information_line( - model_type, "No models found", value_color="yellow" - ) + data[model_type] = model_class.get_available_models() + pprint(data) else: - self.printer.print_information_line( - f"Available local {model_type} models", "", level=0 - ) + logger.info(f"Available local {model_type} models") model_class = MODEL_TYPES[model_type] names = model_class.get_available_models() if names: - for name in names: - self.printer.print_information_line("", name, value_color="green", level=1) + pprint(names) else: - self.printer.print_information_line( - "", "No models found", value_color="yellow", level=1 - ) + logger.error("No models found") def print_remote_models(self, model_type: typing.Optional[str] = None) -> None: """ @@ -1569,27 +1549,18 @@ def print_remote_models(self, model_type: typing.Optional[str] = None) -> None: if not self.synced_remote: self.refresh_remote() if model_type is None: - self.printer.print_information_line("Available models for download", "", level=0) + logger.info("Available models for download") + data = {} for model_type, release_data in self.remote_models.items(): - names = sorted(release_data.keys()) - if names: - self.printer.print_information_line(model_type, names, value_color="green") - else: - self.printer.print_information_line( - model_type, "No models found", value_color="red" - ) + data[model_type] = sorted(release_data.keys()) + pprint(data) else: - self.printer.print_information_line( - f"Available {model_type} models for download", "", level=0 - ) + logger.info(f"Available {model_type} models for download") names = sorted(self.remote_models[model_type].keys()) if names: - for name in names: - self.printer.print_information_line("", name, value_color="green", level=1) + 
pprint(names) else: - self.printer.print_information_line( - "", "No models found", value_color="yellow", level=1 - ) + logger.error("No models found") def download_model( self, model_type: str, model_name=typing.Optional[str], ignore_cache=False diff --git a/montreal_forced_aligner/tokenization/tokenizer.py b/montreal_forced_aligner/tokenization/tokenizer.py index 91a83cf9..9a547eb8 100644 --- a/montreal_forced_aligner/tokenization/tokenizer.py +++ b/montreal_forced_aligner/tokenization/tokenizer.py @@ -13,17 +13,17 @@ import pynini import pywrapfst import sqlalchemy -import tqdm from praatio import textgrid from pynini import Fst from pynini.lib import rewrite from pywrapfst import SymbolTable from sqlalchemy.orm import joinedload, selectinload +from tqdm.rich import tqdm from montreal_forced_aligner.abc import KaldiFunction, TopLevelMfaWorker from montreal_forced_aligner.alignment.multiprocessing import construct_output_path from montreal_forced_aligner.config import GLOBAL_CONFIG -from montreal_forced_aligner.corpus.text_corpus import TextCorpusMixin +from montreal_forced_aligner.corpus.acoustic_corpus import AcousticCorpusMixin from montreal_forced_aligner.data import MfaArguments, TextgridFormats from montreal_forced_aligner.db import File, Utterance, bulk_update from montreal_forced_aligner.dictionary.mixins import DictionaryMixin @@ -207,7 +207,7 @@ def _run(self) -> typing.Generator: yield u_id, tokenized_text -class CorpusTokenizer(TextCorpusMixin, TopLevelMfaWorker, DictionaryMixin): +class CorpusTokenizer(AcousticCorpusMixin, TopLevelMfaWorker, DictionaryMixin): """ Top-level worker for generating pronunciations from a corpus and a Pynini tokenizer model """ @@ -322,7 +322,7 @@ def tokenize_utterances(self) -> None: self.setup() logger.info("Tokenizing utterances...") args = self.tokenize_arguments() - with tqdm.tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: update_mapping = [] for utt_id, tokenized in run_kaldi_function(TokenizerFunction, args, pbar.update): update_mapping.append({"id": utt_id, "text": tokenized}) @@ -388,7 +388,7 @@ def tokenize_utterances(self) -> typing.Dict[str, str]: logger.info("Tokenizing utterances...") to_return = {} if num_utterances < 30 or GLOBAL_CONFIG.num_jobs == 1: - with tqdm.tqdm(total=num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: for utterance in self.utterances_to_tokenize: pbar.update(1) result = self.rewriter(utterance) @@ -410,7 +410,7 @@ def tokenize_utterances(self) -> typing.Dict[str, str]: ) procs.append(p) p.start() - with tqdm.tqdm(total=num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: while True: try: utterance, result = return_queue.get(timeout=1) diff --git a/montreal_forced_aligner/tokenization/trainer.py b/montreal_forced_aligner/tokenization/trainer.py index 5aba586e..0e735b79 100644 --- a/montreal_forced_aligner/tokenization/trainer.py +++ b/montreal_forced_aligner/tokenization/trainer.py @@ -13,7 +13,7 @@ from montreal_forced_aligner.abc import MetaDict, TopLevelMfaWorker from montreal_forced_aligner.config import GLOBAL_CONFIG -from montreal_forced_aligner.corpus.text_corpus import TextCorpusMixin +from montreal_forced_aligner.corpus.acoustic_corpus import AcousticCorpusMixin from montreal_forced_aligner.data import WorkflowType from montreal_forced_aligner.db import 
M2M2Job, M2MSymbol, Utterance from montreal_forced_aligner.dictionary.mixins import DictionaryMixin @@ -194,7 +194,7 @@ def run(self) -> None: del far_writer -class TokenizerMixin(TextCorpusMixin, G2PTrainer, DictionaryMixin, TopLevelMfaWorker): +class TokenizerMixin(AcousticCorpusMixin, G2PTrainer, DictionaryMixin, TopLevelMfaWorker): def __init__(self, **kwargs): super().__init__(**kwargs) self.training_graphemes = set() diff --git a/montreal_forced_aligner/transcription/transcriber.py b/montreal_forced_aligner/transcription/transcriber.py index ef15902c..ef755245 100644 --- a/montreal_forced_aligner/transcription/transcriber.py +++ b/montreal_forced_aligner/transcription/transcriber.py @@ -19,9 +19,9 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple import pywrapfst -import tqdm from praatio import textgrid from sqlalchemy.orm import joinedload, selectinload +from tqdm.rich import tqdm from montreal_forced_aligner.abc import TopLevelMfaWorker from montreal_forced_aligner.alignment.base import CorpusAligner @@ -205,7 +205,7 @@ def train_speaker_lms(self) -> None: os.makedirs(log_directory, exist_ok=True) logger.info("Compiling per speaker biased language models...") arguments = self.train_speaker_lm_arguments() - with tqdm.tqdm(total=self.num_speakers, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_speakers, disable=GLOBAL_CONFIG.quiet) as pbar: if GLOBAL_CONFIG.use_mp: error_dict = {} return_queue = mp.Queue() @@ -276,7 +276,7 @@ def lm_rescore(self) -> None: p = KaldiProcessWorker(i, return_queue, function, stopped) procs.append(p) p.start() - with tqdm.tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: while True: try: result = return_queue.get(timeout=1) @@ -304,7 +304,7 @@ def lm_rescore(self) -> None: else: for args in self.lm_rescore_arguments(): function = LmRescoreFunction(args) - with tqdm.tqdm(total=GLOBAL_CONFIG.num_jobs, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=GLOBAL_CONFIG.num_jobs, disable=GLOBAL_CONFIG.quiet) as pbar: for succeeded, failed in function.run(): if failed: logger.warning("Some lattices failed to be rescored") @@ -332,7 +332,7 @@ def carpa_lm_rescore(self) -> None: p = KaldiProcessWorker(i, return_queue, function, stopped) procs.append(p) p.start() - with tqdm.tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: while True: try: result = return_queue.get(timeout=1) @@ -360,7 +360,7 @@ def carpa_lm_rescore(self) -> None: else: for args in self.carpa_lm_rescore_arguments(): function = CarpaLmRescoreFunction(args) - with tqdm.tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: for succeeded, failed in function.run(): if failed: logger.warning("Some lattices failed to be rescored") @@ -387,7 +387,7 @@ def train_phone_lm(self): procs = [] count_paths = [] allowed_bigrams = collections.defaultdict(set) - with self.session() as session, tqdm.tqdm( + with self.session() as session, tqdm( total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet ) as pbar: @@ -878,7 +878,7 @@ def decode(self) -> None: Arguments for function """ logger.info("Generating lattices...") - with tqdm.tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: workflow = 
self.current_workflow arguments = self.decode_arguments(workflow.workflow_type) log_likelihood_sum = 0 @@ -912,7 +912,7 @@ def calc_initial_fmllr(self) -> None: """ logger.info("Calculating initial fMLLR transforms...") sum_errors = 0 - with tqdm.tqdm(total=self.num_speakers, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_speakers, disable=GLOBAL_CONFIG.quiet) as pbar: if GLOBAL_CONFIG.use_mp: error_dict = {} return_queue = mp.Queue() @@ -966,7 +966,7 @@ def lat_gen_fmllr(self) -> None: logger.info("Regenerating lattices with fMLLR transforms...") workflow = self.current_workflow arguments = self.lat_gen_fmllr_arguments(workflow.workflow_type) - with tqdm.tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar, mfa_open( + with tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar, mfa_open( self.working_log_directory.joinpath("lat_gen_fmllr_log_like.csv"), "w", encoding="utf8", @@ -1025,7 +1025,7 @@ def calc_final_fmllr(self) -> None: """ logger.info("Calculating final fMLLR transforms...") sum_errors = 0 - with tqdm.tqdm(total=self.num_speakers, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_speakers, disable=GLOBAL_CONFIG.quiet) as pbar: if GLOBAL_CONFIG.use_mp: error_dict = {} return_queue = mp.Queue() @@ -1078,7 +1078,7 @@ def fmllr_rescore(self) -> None: """ logger.info("Rescoring fMLLR lattices with final transform...") sum_errors = 0 - with tqdm.tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: if GLOBAL_CONFIG.use_mp: error_dict = {} return_queue = mp.Queue() @@ -1549,7 +1549,7 @@ def create_hclgs(self) -> None: p = KaldiProcessWorker(i, return_queue, function, stopped) procs.append(p) p.start() - with tqdm.tqdm(total=len(dict_arguments) * 7, disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=len(dict_arguments) * 7, disable=GLOBAL_CONFIG.quiet) as pbar: while True: try: result = return_queue.get(timeout=1) @@ -1582,7 +1582,7 @@ def create_hclgs(self) -> None: else: for args in dict_arguments: function = CreateHclgFunction(args) - with tqdm.tqdm(total=len(dict_arguments), disable=GLOBAL_CONFIG.quiet) as pbar: + with tqdm(total=len(dict_arguments), disable=GLOBAL_CONFIG.quiet) as pbar: for result in function.run(): if not isinstance(result, tuple): pbar.update(1) diff --git a/montreal_forced_aligner/vad/segmenter.py b/montreal_forced_aligner/vad/segmenter.py index b5d93d75..cd8df2e8 100644 --- a/montreal_forced_aligner/vad/segmenter.py +++ b/montreal_forced_aligner/vad/segmenter.py @@ -13,8 +13,8 @@ from typing import Dict, List, Optional import sqlalchemy -import tqdm from sqlalchemy.orm import joinedload, selectinload +from tqdm.rich import tqdm from montreal_forced_aligner.abc import FileExporterMixin, MetaDict, TopLevelMfaWorker from montreal_forced_aligner.config import GLOBAL_CONFIG @@ -230,7 +230,7 @@ def segment_vad_speechbrain(self) -> None: new_utts = [] kwargs = self.segmentation_options kwargs.pop("frame_shift") - with tqdm.tqdm( + with tqdm( total=self.num_utterances, disable=GLOBAL_CONFIG.quiet ) as pbar, self.session() as session: utt_index = session.query(sqlalchemy.func.max(Utterance.id)).scalar() @@ -293,7 +293,7 @@ def segment_vad_mfa(self) -> None: old_utts = set() new_utts = [] - with tqdm.tqdm( + with tqdm( total=self.num_utterances, disable=GLOBAL_CONFIG.quiet ) as pbar, self.session() as session: utterances = session.query( diff --git a/montreal_forced_aligner/validation/corpus_validator.py 
b/montreal_forced_aligner/validation/corpus_validator.py index afb17002..f1774d7f 100644 --- a/montreal_forced_aligner/validation/corpus_validator.py +++ b/montreal_forced_aligner/validation/corpus_validator.py @@ -22,12 +22,7 @@ from montreal_forced_aligner.data import WorkflowType from montreal_forced_aligner.db import Corpus, File, SoundFile, Speaker, TextFile, Utterance from montreal_forced_aligner.exceptions import ConfigError, KaldiProcessingError -from montreal_forced_aligner.helper import ( - TerminalPrinter, - comma_join, - load_configuration, - mfa_open, -) +from montreal_forced_aligner.helper import comma_join, load_configuration, mfa_open from montreal_forced_aligner.utils import log_kaldi_errors, run_mp, run_non_mp if TYPE_CHECKING: @@ -59,10 +54,6 @@ class ValidationMixin: :class:`~montreal_forced_aligner.alignment.base.CorpusAligner` For corpus, dictionary, and alignment parameters - Attributes - ---------- - printer: TerminalPrinter - Printer for output messages """ def __init__( @@ -80,7 +71,6 @@ def __init__( self.target_num_ngrams = target_num_ngrams self.order = order self.method = method - self.printer = TerminalPrinter(print_function=logger.info) @property def working_log_directory(self) -> str: @@ -107,25 +97,21 @@ def analyze_setup(self) -> None: ignored_count += len(self.decode_error_files) logger.debug(f"Ignored count calculation took {time.time() - begin:.3f} seconds") - self.printer.print_header("Corpus") - self.printer.print_green_stat(sound_file_count, "sound files") - self.printer.print_green_stat(text_file_count, "text files") + logger.info("Corpus") + logger.info(f"{sound_file_count} sound files") + logger.info(f"{text_file_count} text files") if len(self.no_transcription_files): - self.printer.print_yellow_stat( - len(self.no_transcription_files), - "sound files without corresponding transcriptions", + logger.warning( + f"{len(self.no_transcription_files)} sound files without corresponding transcriptions", ) if len(self.decode_error_files): - self.printer.print_red_stat(len(self.decode_error_files), "read errors for lab files") + logger.error(f"{len(self.decode_error_files)} read errors for lab files") if len(self.textgrid_read_errors): - self.printer.print_red_stat( - len(self.textgrid_read_errors), "read errors for TextGrid files" - ) + logger.error(f"{len(self.textgrid_read_errors)} read errors for TextGrid files") - self.printer.print_green_stat(self.num_speakers, "speakers") - self.printer.print_green_stat(self.num_utterances, "utterances") - self.printer.print_green_stat(total_duration, "seconds total duration") - print() + logger.info(f"{self.num_speakers} speakers") + logger.info(f"{self.num_utterances} utterances") + logger.info(f"{total_duration} seconds total duration") self.analyze_wav_errors() self.analyze_missing_features() self.analyze_files_with_no_transcription() @@ -136,14 +122,14 @@ def analyze_setup(self) -> None: if len(self.textgrid_read_errors): self.analyze_textgrid_read_errors() - self.printer.print_header("Dictionary") + logger.info("Dictionary") self.analyze_oovs() def analyze_oovs(self) -> None: """ Analyzes OOVs in the corpus and constructs message """ - self.printer.print_sub_header("Out of vocabulary words") + logger.info("Out of vocabulary words") output_dir = self.output_directory oov_path = os.path.join(output_dir, "oovs_found.txt") utterance_oov_path = os.path.join(output_dir, "utterance_oovs.txt") @@ -173,33 +159,24 @@ def analyze_oovs(self) -> None: self.oovs_found.update(oovs) if self.oovs_found: 
self.save_oovs_found(self.output_directory) - self.printer.print_yellow_stat(len(self.oovs_found), "OOV word types") - self.printer.print_yellow_stat(total_instances, "total OOV tokens") - lines = [ - "", - "For a full list of the word types, please see:", - "", - self.printer.indent_string + self.printer.colorize(oov_path, "bright"), - "", - "For a by-utterance breakdown of missing words, see:", - "", - self.printer.indent_string + self.printer.colorize(utterance_oov_path, "bright"), - "", - ] - self.printer.print_info_lines(lines) + logger.warning(f"{len(self.oovs_found)} OOV word types") + logger.warning(f"{total_instances} total OOV tokens") + logger.warning( + f"For a full list of the word types, please see: {oov_path}. " + f"For a by-utterance breakdown of missing words, see: {utterance_oov_path}" + ) else: - self.printer.print_info_lines( - f"There were {self.printer.colorize('no', 'yellow')} missing words from the dictionary. If you plan on using the a model trained " + logger.info( + "There were no missing words from the dictionary. If you plan on using a model trained " "on this dataset to align other datasets in the future, it is recommended that there be at " "least some missing words." ) - self.printer.print_end_section() def analyze_wav_errors(self) -> None: """ Analyzes any sound file issues in the corpus and constructs message """ - self.printer.print_sub_header("Sound file read errors") + logger.info("Sound file read errors") output_dir = self.output_directory wav_read_errors = self.sound_file_errors @@ -209,25 +186,20 @@ def analyze_wav_errors(self) -> None: for p in wav_read_errors: f.write(f"{p}\n") - self.printer.print_info_lines( - f"There were {self.printer.colorize(len(wav_read_errors), 'red')} issues reading sound files. " - f"Please see {self.printer.colorize(path, 'bright')} for a list." + logger.error( + f"There were {len(wav_read_errors)} issues reading sound files. " + f"Please see {path} for a list." ) else: - self.printer.print_info_lines( - f"There were {self.printer.colorize('no', 'green')} issues reading sound files." - ) - - self.printer.print_end_section() + logger.info("There were no issues reading sound files.") def analyze_missing_features(self) -> None: """ Analyzes issues in feature generation in the corpus and constructs message """ - self.printer.print_sub_header("Feature generation") + logger.info("Feature generation") if self.ignore_acoustics: - self.printer.print_info_lines("Acoustic feature generation was skipped.") - self.printer.print_end_section() + logger.info("Acoustic feature generation was skipped.") return output_dir = self.output_directory with self.session() as session: @@ -243,66 +215,57 @@ def analyze_missing_features(self) -> None: f.write(f"{relative_path + '/' + file_name},{begin},{end}\n") - self.printer.print_info_lines( - f"There were {self.printer.colorize(utterances.count(), 'red')} utterances missing features. " - f"Please see {self.printer.colorize(path, 'bright')} for a list." + logger.error( + f"There were {utterances.count()} utterances missing features. " + f"Please see {path} for a list." ) else: - self.printer.print_info_lines( - f"There were {self.printer.colorize('no', 'green')} utterances missing features." 
- ) - self.printer.print_end_section() + logger.info("There were no utterances missing features.") def analyze_files_with_no_transcription(self) -> None: """ Analyzes issues with sound files that have no transcription files in the corpus and constructs message """ - self.printer.print_sub_header("Files without transcriptions") + logger.info("Files without transcriptions") output_dir = self.output_directory if self.no_transcription_files: path = os.path.join(output_dir, "missing_transcriptions.csv") with mfa_open(path, "w") as f: for file_path in self.no_transcription_files: f.write(f"{file_path}\n") - self.printer.print_info_lines( - f"There were {self.printer.colorize(len(self.no_transcription_files), 'red')} sound files missing transcriptions. " - f"Please see {self.printer.colorize(path, 'bright')} for a list." + logger.error( + f"There were {len(self.no_transcription_files)} sound files missing transcriptions." ) + logger.error(f"Please see {path} for a list.") else: - self.printer.print_info_lines( - f"There were {self.printer.colorize('no', 'green')} sound files missing transcriptions." - ) - self.printer.print_end_section() + logger.info("There were no sound files missing transcriptions.") def analyze_transcriptions_with_no_wavs(self) -> None: """ Analyzes issues with transcription that have no sound files in the corpus and constructs message """ - self.printer.print_sub_header("Transcriptions without sound files") + logger.info("Transcriptions without sound files") output_dir = self.output_directory if self.transcriptions_without_wavs: path = os.path.join(output_dir, "transcriptions_missing_sound_files.csv") with mfa_open(path, "w") as f: for file_path in self.transcriptions_without_wavs: f.write(f"{file_path}\n") - self.printer.print_info_lines( - f"There were {self.printer.colorize(len(self.transcriptions_without_wavs), 'red')} transcription files missing sound files. " - f"Please see {self.printer.colorize(path, 'bright')} for a list." + logger.error( + f"There were {len(self.transcriptions_without_wavs)} transcription files missing sound files. " + f"Please see {path} for a list." ) else: - self.printer.print_info_lines( - f"There were {self.printer.colorize('no', 'green')} transcription files missing sound files." - ) - self.printer.print_end_section() + logger.info("There were no transcription files missing sound files.") def analyze_textgrid_read_errors(self) -> None: """ Analyzes issues with reading TextGrid files in the corpus and constructs message """ - self.printer.print_sub_header("TextGrid read errors") + logger.info("TextGrid read errors") output_dir = self.output_directory if self.textgrid_read_errors: path = os.path.join(output_dir, "textgrid_read_errors.txt") @@ -311,43 +274,31 @@ def analyze_textgrid_read_errors(self) -> None: f.write( f"The TextGrid file {e.file_name} gave the following error on load:\n\n{e}\n\n\n" ) - self.printer.print_info_lines( - [ - f"There were {self.printer.colorize(len(self.textgrid_read_errors), 'red')} TextGrid files that could not be loaded. " - "For details, please see:", - "", - self.printer.indent_string + self.printer.colorize(path, "bright"), - ] + logger.error( + f"There were {len(self.textgrid_read_errors)} TextGrid files that could not be loaded. " + f"For details, please see: {path}", ) else: - self.printer.print_info_lines( - f"There were {self.printer.colorize('no', 'green')} issues reading TextGrids." 
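# --- Illustrative sketch, not part of the patch: the validator hunks above and below
# replace TerminalPrinter stat/info helpers with plain logging calls at matching
# severities. Note that the logging methods take a single pre-formatted message rather
# than the (value, text) pairs the old print_*_stat helpers took. The helper name,
# count, and path below are invented for illustration.
import logging

logger = logging.getLogger("mfa")

def report_read_errors(error_count: int, path: str) -> None:
    if error_count:
        # print_red_stat / print_info_lines -> logger.error with one f-string
        logger.error(f"There were {error_count} files that could not be read. Please see {path} for a list.")
    else:
        # print_green_stat -> logger.info
        logger.info("There were no issues reading files.")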
- ) - - self.printer.print_end_section() + logger.info("There were no issues reading TextGrids.") def analyze_unreadable_text_files(self) -> None: """ Analyzes issues with reading text files in the corpus and constructs message """ - self.printer.print_sub_header("Text file read errors") + logger.info("Text file read errors") output_dir = self.output_directory if self.decode_error_files: path = os.path.join(output_dir, "utf8_read_errors.csv") with mfa_open(path, "w") as f: for file_path in self.decode_error_files: f.write(f"{file_path}\n") - self.printer.print_info_lines( - f"There were {self.printer.colorize(len(self.decode_error_files), 'red')} text files that could not be read. " - f"Please see {self.printer.colorize(path, 'bright')} for a list." + logger.error( + f"There were {len(self.decode_error_files)} text files that could not be read. " + f"Please see {path} for a list." ) else: - self.printer.print_info_lines( - f"There were {self.printer.colorize('no', 'green')} issues reading text files." - ) - - self.printer.print_end_section() + logger.info("There were no issues reading text files.") def compile_information(self) -> None: """ @@ -392,29 +343,22 @@ def compile_information(self) -> None: average_logdet_frames += data["logdet_frames"] average_logdet_sum += data["logdet"] * data["logdet_frames"] - self.printer.print_header("Alignment") + logger.info("Alignment") if not avg_like_frames: - logger.debug( - "No utterances were aligned, this likely indicates serious problems with the aligner." - ) - self.printer.print_red_stat(0, f"of {self.num_utterances} utterances were aligned") + logger.error(f"0 of {self.num_utterances} utterances were aligned") else: if too_short_count: - self.printer.print_red_stat( - too_short_count, "utterances were too short to be aligned" + logger.error( + f"{too_short_count} utterances were too short to be aligned" ) else: - self.printer.print_green_stat(0, "utterances were too short to be aligned") + logger.info("0 utterances were too short to be aligned") if beam_too_narrow_count: - logger.debug( - f"There were {beam_too_narrow_count} utterances that could not be aligned with " - f"the current beam settings." - ) - self.printer.print_yellow_stat( - beam_too_narrow_count, "utterances that need a larger beam to align" + logger.warning( + f"{beam_too_narrow_count} utterances that need a larger beam to align" ) else: - self.printer.print_green_stat(0, "utterances that need a larger beam to align") + logger.info("0 utterances that need a larger beam to align") num_utterances = self.num_utterances with self.session() as session: @@ -433,18 +377,13 @@ def compile_information(self) -> None: f.write( f"{u.file.name},{u.begin},{u.end},{u.duration},{utt_length_words}\n" ) - self.printer.print_info_lines( - [ - f"There were {self.printer.colorize(unaligned_count, 'red')} unaligned utterances out of {self.printer.colorize(self.num_utterances, 'bright')} after initial training. " f"For details, please see:", - "", - self.printer.indent_string + self.printer.colorize(path, "bright"), - ] + logger.error( + f"There were {unaligned_count} unaligned utterances out of {self.num_utterances} after initial training. 
" + f"For details, please see: {path}", ) - - self.printer.print_green_stat( - num_utterances - beam_too_narrow_count - too_short_count, - "utterances were successfully aligned", + successful_utterances = num_utterances - beam_too_narrow_count - too_short_count + logger.info( + f"{successful_utterances} utterances were successfully aligned", ) average_log_like = avg_like_sum / avg_like_frames if average_logdet_sum: @@ -462,39 +401,38 @@ def test_utterance_transcriptions(self) -> None: :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` If there were any errors in running Kaldi binaries """ - logger.info("Checking utterance transcriptions...") try: self.train_speaker_lms() self.transcribe(WorkflowType.per_speaker_transcription) - self.printer.print_header("Test transcriptions") + logger.info("Test transcriptions") ser, wer, cer = self.compute_wer() if ser < 0.3: - self.printer.print_green_stat(f"{ser*100:.2f}%", "sentence error rate") + logger.info(f"{ser*100:.2f}% sentence error rate") elif ser < 0.8: - self.printer.print_yellow_stat(f"{ser*100:.2f}%", "sentence error rate") + logger.warning(f"{ser*100:.2f}% sentence error rate") else: - self.printer.print_red_stat(f"{ser*100:.2f}%", "sentence error rate") + logger.error(f"{ser*100:.2f}% sentence error rate") if wer < 0.25: - self.printer.print_green_stat(f"{wer*100:.2f}%", "word error rate") + logger.info(f"{wer*100:.2f}% word error rate") elif wer < 0.75: - self.printer.print_yellow_stat(f"{wer*100:.2f}%", "word error rate") + logger.warning(f"{wer*100:.2f}% word error rate") else: - self.printer.print_red_stat(f"{wer*100:.2f}%", "word error rate") + logger.error(f"{wer*100:.2f}% word error rate") if cer < 0.25: - self.printer.print_green_stat(f"{cer*100:.2f}%", "character error rate") + logger.info(f"{cer*100:.2f}% character error rate") elif cer < 0.75: - self.printer.print_yellow_stat(f"{cer*100:.2f}%", "character error rate") + logger.warning(f"{cer*100:.2f}% character error rate") else: - self.printer.print_red_stat(f"{cer*100:.2f}%", "character error rate") + logger.error(f"{cer*100:.2f}% character error rate") self.save_transcription_evaluation(self.output_directory) out_path = os.path.join(self.output_directory, "transcription_evaluation.csv") - print(f"See {self.printer.colorize(out_path, 'bright')} for more details.") + logger.info(f"See {out_path} for more details.") except Exception as e: if isinstance(e, KaldiProcessingError): @@ -658,9 +596,9 @@ def validate(self) -> None: self.analyze_setup() logger.debug(f"Setup took {time.time() - begin:.3f} seconds") if self.ignore_acoustics: - self.printer.print_info_lines("Skipping test alignments.") + logger.info("Skipping test alignments.") return - self.printer.print_header("Training") + logger.info("Training") self.train() if self.test_transcriptions: self.test_utterance_transcriptions() @@ -756,26 +694,13 @@ def validate(self) -> None: def analyze_missing_phones(self) -> None: """Analyzes dictionary and acoustic model for phones in the dictionary that don't have acoustic models""" - self.printer.print_sub_header("Acoustic model compatibility") + logger.info("Acoustic model compatibility") if self.excluded_pronunciation_count: - self.printer.print_yellow_stat( - len(self.excluded_phones), "phones not in acoustic model" - ) - self.printer.print_yellow_stat( - self.excluded_pronunciation_count, "ignored pronunciations" - ) + logger.warning(f"{len(self.excluded_phones)} phones not in acoustic model") + logger.warning(f"{self.excluded_pronunciation_count} ignored 
pronunciations") - phone_string = [self.printer.colorize(x, "red") for x in sorted(self.excluded_phones)] - self.printer.print_info_lines( - [ - "", - "Phones missing acoustic models:", - "", - self.printer.indent_string + comma_join(phone_string), - ] + logger.error( + f"Phones missing acoustic models: {comma_join(sorted(self.excluded_phones))}" ) else: - self.printer.print_info_lines( - f"There were {self.printer.colorize('no', 'green')} phones in the dictionary without acoustic models." - ) - self.printer.print_end_section() + logger.info("There were no phones in the dictionary without acoustic models.") diff --git a/requirements.txt b/requirements.txt index 29df5501..2ad786a5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,8 +3,8 @@ tqdm pyyaml librosa requests -colorama biopython dataclassy sqlalchemy -ansiwrap +rich +rich-click diff --git a/rtd_environment.yml b/rtd_environment.yml index 513d91fe..cf330d92 100644 --- a/rtd_environment.yml +++ b/rtd_environment.yml @@ -6,8 +6,6 @@ dependencies: - librosa - tqdm - requests - - colorama - - ansiwrap - pyyaml - praatio=6.0.0 - dataclassy @@ -35,6 +33,8 @@ dependencies: - kneed - matplotlib - seaborn + - rich + - rich-click - pip: - sphinx-needs - sphinxcontrib-plantuml diff --git a/setup.cfg b/setup.cfg index 31bca5f5..33821bf1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,12 +36,9 @@ keywords = phonology [options] packages = find: install_requires = - ansiwrap biopython biopython<=1.79 click - click - colorama dataclassy kneed librosa @@ -51,6 +48,8 @@ install_requires = praatio>=5.0 pyyaml requests + rich + rich-click scikit-learn seaborn sqlalchemy>=1.4 diff --git a/tests/conftest.py b/tests/conftest.py index 68d196f4..e319b5ba 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -597,6 +597,27 @@ def swedish_dir(corpus_root_dir, wav_dir, lab_dir): return path +@pytest.fixture() +def japanese_cv_dir(corpus_root_dir, wav_dir, lab_dir): + path = corpus_root_dir.joinpath("test_japanese_cv") + path.mkdir(parents=True, exist_ok=True) + names = [ + ( + "02a8841a00d7624", + [ + "common_voice_ja_24511055", + ], + ) + ] + for s, files in names: + s_dir = path.joinpath(s) + s_dir.mkdir(parents=True, exist_ok=True) + for name in files: + shutil.copyfile(wav_dir.joinpath(name + ".mp3"), s_dir.joinpath(name + ".mp3")) + shutil.copyfile(lab_dir.joinpath(name + ".lab"), s_dir.joinpath(name + ".lab")) + return path + + @pytest.fixture() def basic_corpus_txt_dir(corpus_root_dir, wav_dir, lab_dir): path = corpus_root_dir.joinpath("test_basic_txt") diff --git a/tests/data/lab/common_voice_ja_24511055.lab b/tests/data/lab/common_voice_ja_24511055.lab new file mode 100644 index 00000000..9df88fa4 --- /dev/null +++ b/tests/data/lab/common_voice_ja_24511055.lab @@ -0,0 +1 @@ +真っ昼間なのにキャンプの外れの電柱に電球がともっていた diff --git a/tests/data/wav/common_voice_ja_24511055.mp3 b/tests/data/wav/common_voice_ja_24511055.mp3 new file mode 100644 index 00000000..c466769c Binary files /dev/null and b/tests/data/wav/common_voice_ja_24511055.mp3 differ diff --git a/tests/test_commandline_tokenize.py b/tests/test_commandline_tokenize.py index 5cf0b32d..3c242a11 100644 --- a/tests/test_commandline_tokenize.py +++ b/tests/test_commandline_tokenize.py @@ -5,11 +5,11 @@ from montreal_forced_aligner.command_line.mfa import mfa_cli -def test_tokenize_pretrained(japanese_tokenizer_model, japanese_dir, temp_dir, generated_dir): +def test_tokenize_pretrained(japanese_tokenizer_model, japanese_cv_dir, temp_dir, generated_dir): out_directory = 
generated_dir.joinpath("japanese_tokenized") command = [ "tokenize", - japanese_dir, + japanese_cv_dir, japanese_tokenizer_model, out_directory, "-t",