From f4af2e6b5a0d783979472ecec6f202575961b6d1 Mon Sep 17 00:00:00 2001 From: sudoskys Date: Thu, 26 Sep 2024 18:41:45 +0800 Subject: [PATCH] :art: refactor: update type alias to TextLLMModel and add COLORS_LLM list --- src/novelai_python/_enum.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/novelai_python/_enum.py b/src/novelai_python/_enum.py index 8f6504c..991909f 100644 --- a/src/novelai_python/_enum.py +++ b/src/novelai_python/_enum.py @@ -62,21 +62,32 @@ class TextTokenizerGroup(object): TextLLMModel.ERATO: TextTokenizerGroup.LLAMA3, } +COLORS_LLM = [ + TextLLMModel.BLUE, + TextLLMModel.RED, + TextLLMModel.GREEN, + TextLLMModel.PURPLE, + TextLLMModel.PINK, + TextLLMModel.YELLOW, + TextLLMModel.WHITE, + TextLLMModel.BLACK, +] -def get_llm_group(model: TextLLMModelTypeAlias) -> Optional[TextTokenizerGroup]: + +def get_llm_group(model: TextLLMModel) -> Optional[TextTokenizerGroup]: if isinstance(model, str): model = TextLLMModel(model) return TOKENIZER_MODEL_MAP.get(model, None) -def get_tokenizer_model(model: TextLLMModelTypeAlias) -> str: +def get_tokenizer_model(model: TextLLMModel) -> str: if isinstance(model, str): model = TextLLMModel(model) group = TOKENIZER_MODEL_MAP.get(model, TextTokenizerGroup.GPT2) return group -def get_tokenizer_model_url(model: TextLLMModelTypeAlias) -> str: +def get_tokenizer_model_url(model: TextLLMModel) -> str: model_name = get_tokenizer_model(model) if not model_name.endswith(".def"): model_name = f"{model_name}.def"