Skip to content

Commit

Permalink
Merge pull request #330 from jhj0517/fix/compute-type
Browse files Browse the repository at this point in the history
Fix/compute type
  • Loading branch information
jhj0517 authored Oct 14, 2024
2 parents bc6b2e9 + 87cbb02 commit 6ae85bd
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 14 deletions.
3 changes: 1 addition & 2 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,7 @@ def launch(self):
tb_api_key = gr.Textbox(label="Your Auth Key (API KEY)", value=deepl_params["api_key"])
with gr.Row():
dd_source_lang = gr.Dropdown(label="Source Language", value=deepl_params["source_lang"],
choices=list(
self.deepl_api.available_source_langs.keys()))
choices=list(self.deepl_api.available_source_langs.keys()))
dd_target_lang = gr.Dropdown(label="Target Language", value=deepl_params["target_lang"],
choices=list(self.deepl_api.available_target_langs.keys()))
with gr.Row():
Expand Down
9 changes: 4 additions & 5 deletions modules/translation/nllb_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import os

from modules.utils.paths import TRANSLATION_OUTPUT_DIR, NLLB_MODELS_DIR
from modules.translation.translation_base import TranslationBase
import modules.translation.translation_base as base


class NLLBInference(TranslationBase):
class NLLBInference(base.TranslationBase):
def __init__(self,
model_dir: str = NLLB_MODELS_DIR,
output_dir: str = TRANSLATION_OUTPUT_DIR
Expand All @@ -29,7 +29,7 @@ def translate(self,
text,
max_length=max_length
)
return result[0]['translation_text']
return result[0]["translation_text"]

def update_model(self,
model_size: str,
Expand All @@ -41,8 +41,7 @@ def validate_language(lang: str) -> str:
if lang in NLLB_AVAILABLE_LANGS:
return NLLB_AVAILABLE_LANGS[lang]
elif lang not in NLLB_AVAILABLE_LANGS.values():
raise ValueError(
f"Language '{lang}' is not supported. Use one of: {list(NLLB_AVAILABLE_LANGS.keys())}")
raise ValueError(f"Language '{lang}' is not supported. Use one of: {list(NLLB_AVAILABLE_LANGS.keys())}")
return lang

src_lang = validate_language(src_lang)
Expand Down
11 changes: 9 additions & 2 deletions modules/translation/translation_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import List
from datetime import datetime

import modules.translation.nllb_inference as nllb
from modules.whisper.whisper_parameter import *
from modules.utils.subtitle_manager import *
from modules.utils.files_manager import load_yaml, save_yaml
Expand Down Expand Up @@ -166,11 +167,17 @@ def cache_parameters(model_size: str,
tgt_lang: str,
max_length: int,
add_timestamp: bool):
def validate_lang(lang: str):
if lang in list(nllb.NLLB_AVAILABLE_LANGS.values()):
flipped = {value: key for key, value in nllb.NLLB_AVAILABLE_LANGS.items()}
return flipped[lang]
return lang

cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
cached_params["translation"]["nllb"] = {
"model_size": model_size,
"source_lang": src_lang,
"target_lang": tgt_lang,
"source_lang": validate_lang(src_lang),
"target_lang": validate_lang(tgt_lang),
"max_length": max_length,
}
cached_params["translation"]["add_timestamp"] = add_timestamp
Expand Down
2 changes: 0 additions & 2 deletions modules/whisper/faster_whisper_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ def __init__(self,
self.model_paths = self.get_model_paths()
self.device = self.get_device()
self.available_models = self.model_paths.keys()
self.available_compute_types = ctranslate2.get_supported_compute_types(
"cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")

def transcribe(self,
audio: Union[str, BinaryIO, np.ndarray],
Expand Down
1 change: 0 additions & 1 deletion modules/whisper/insanely_fast_whisper_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def __init__(self,
openai_models = whisper.available_models()
distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
self.available_models = openai_models + distil_models
self.available_compute_types = ["float16"]

def transcribe(self,
audio: Union[str, np.ndarray, torch.Tensor],
Expand Down
19 changes: 17 additions & 2 deletions modules/whisper/whisper_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import torch
import whisper
import ctranslate2
import gradio as gr
import torchaudio
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -47,8 +48,8 @@ def __init__(self,
self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
self.device = self.get_device()
self.available_compute_types = ["float16", "float32"]
self.current_compute_type = "float16" if self.device == "cuda" else "float32"
self.available_compute_types = self.get_available_compute_type()
self.current_compute_type = self.get_compute_type()

@abstractmethod
def transcribe(self,
Expand Down Expand Up @@ -371,6 +372,20 @@ def transcribe_youtube(self,
finally:
self.release_cuda_memory()

def get_compute_type(self):
    """Pick the compute type to use by preference order.

    Prefers "float16", then "float32"; if neither is supported on this
    device, falls back to the first entry of ``self.available_compute_types``.
    (Raises IndexError if that list is empty, matching prior behavior.)
    """
    for preferred in ("float16", "float32"):
        if preferred in self.available_compute_types:
            return preferred
    return self.available_compute_types[0]

def get_available_compute_type(self):
    """Return the compute types ctranslate2 supports for the active device.

    Queries "cuda" when ``self.device`` is "cuda", otherwise "cpu",
    and returns the supported types as a list.
    """
    target = "cuda" if self.device == "cuda" else "cpu"
    return list(ctranslate2.get_supported_compute_types(target))

@staticmethod
def generate_and_write_file(file_name: str,
transcribed_segments: list,
Expand Down

0 comments on commit 6ae85bd

Please sign in to comment.