feat: move vocabs to external storage #73

Merged · 4 commits · Jun 6, 2024
9 changes: 8 additions & 1 deletion pyproject.toml
@@ -34,7 +34,14 @@ classifiers = [
"Programming Language :: Python :: 3 :: Only",
]
keywords = ["llm", "materials", "chemistry"]
dependencies = ["transformers", "slices", "robocrys", "matminer", "keras<3"]
dependencies = [
"transformers",
"slices",
"robocrys",
"matminer",
"keras<3",
"pystow",
]
[project.urls]
Homepage = "https://github.com/lamalab-org/xtal2txt"
Issues = "https://github.com/lamalab-org/xtal2txt/issues"
105 changes: 79 additions & 26 deletions src/xtal2txt/tokenizer.py
@@ -14,26 +14,69 @@
)

from typing import List
from xtal2txt.utils import xtal2txt_storage


THIS_DIR = os.path.dirname(os.path.abspath(__file__))

SLICE_VOCAB = os.path.join(THIS_DIR, "vocabs", "slice_vocab.txt")
SLICE_RT_VOCAB = os.path.join(THIS_DIR, "vocabs", "slice_vocab_rt.txt")
SLICE_VOCAB = str(
xtal2txt_storage.ensure(
url="https://zenodo.org/records/11484062/files/slice_vocab.txt?download=1"
)
)
SLICE_RT_VOCAB = str(
xtal2txt_storage.ensure(
url="https://zenodo.org/records/11484062/files/slice_vocab_rt.txt?download=1"
)
)

COMPOSITION_VOCAB = os.path.join(THIS_DIR, "vocabs", "composition_vocab.txt")
COMPOSITION_RT_VOCAB = os.path.join(THIS_DIR, "vocabs", "composition_vocab_rt.txt")
COMPOSITION_VOCAB = str(
xtal2txt_storage.ensure(
url="https://zenodo.org/records/11484062/files/composition_vocab.txt?download=1"
)
)
COMPOSITION_RT_VOCAB = str(
xtal2txt_storage.ensure(
url="https://zenodo.org/records/11484062/files/composition_vocab_rt.txt?download=1"
)
)

CIF_VOCAB = os.path.join(THIS_DIR, "vocabs", "cif_vocab.json")
CIF_RT_VOCAB = os.path.join(THIS_DIR, "vocabs", "cif_vocab_rt.json")
CIF_VOCAB = str(
xtal2txt_storage.ensure(
url="https://zenodo.org/records/11484062/files/cif_vocab.json?download=1"
)
)
CIF_RT_VOCAB = str(
xtal2txt_storage.ensure(
url="https://zenodo.org/records/11484062/files/cif_vocab_rt.json?download=1"
)
)

CRYSTAL_LLM_VOCAB = os.path.join(THIS_DIR, "vocabs", "crystal_llm_vocab.json")
CRYSTAL_LLM_RT_VOCAB = os.path.join(THIS_DIR, "vocabs", "crystal_llm_vocab_rt.json")
CRYSTAL_LLM_VOCAB = str(
xtal2txt_storage.ensure(
url="https://zenodo.org/records/11484062/files/crystal_llm_vocab.json?download=1"
)
)
CRYSTAL_LLM_RT_VOCAB = str(
xtal2txt_storage.ensure(
url="https://zenodo.org/records/11484062/files/crystal_llm_vocab_rt.json?download=1"
)
)

SMILES_VOCAB = os.path.join(THIS_DIR, "vocabs", "smiles_vocab.json")
SMILES_RT_VOCAB = os.path.join(THIS_DIR, "vocabs", "smiles_vocab_rt.json")
SMILES_VOCAB = str(
xtal2txt_storage.ensure(
url="https://zenodo.org/records/11484062/files/smiles_vocab.json?download=1"
)
)
SMILES_RT_VOCAB = str(
xtal2txt_storage.ensure(
url="https://zenodo.org/records/11484062/files/smiles_vocab_rt.json?download=1"
)
)

ROBOCRYS_VOCAB = os.path.join(THIS_DIR, "vocabs", "robocrys_vocab.json")
ROBOCRYS_VOCAB = str(
xtal2txt_storage.ensure(
url="https://zenodo.org/records/11484062/files/robocrys_vocab.json?download=1"
)
)


class NumTokenizer:
@@ -203,9 +246,11 @@ def convert_tokens_to_string(self, tokens):
if self.special_num_tokens:
return "".join(
[
token
if not (token.startswith("_") and token.endswith("_"))
else token.split("_")[1]
(
token
if not (token.startswith("_") and token.endswith("_"))
else token.split("_")[1]
)
for token in tokens
]
)
@@ -272,9 +317,11 @@ def save_vocabulary(self, save_directory, filename_prefix=None):

vocab_file = os.path.join(
save_directory,
f"{index + 1}-{filename_prefix}.json"
if filename_prefix
else f"{index + 1}.json",
(
f"{index + 1}-{filename_prefix}.json"
if filename_prefix
else f"{index + 1}.json"
),
)

with open(vocab_file, "w", encoding="utf-8") as f:
@@ -335,17 +382,20 @@ def convert_tokens_to_string(self, tokens):
if self.special_num_tokens:
return " ".join(
[
token
if not (token.startswith("_") and token.endswith("_"))
else token.split("_")[1]
(
token
if not (token.startswith("_") and token.endswith("_"))
else token.split("_")[1]
)
for token in tokens
]
)
return " ".join(tokens).rstrip()

def token_analysis(self, list_of_tokens):
"""Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. The
token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token."""
token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token.
"""
analysis_masks = ANALYSIS_MASK_TOKENS
token_type = SLICE_ANALYSIS_DICT
return [
@@ -377,7 +427,8 @@ def __init__(

def token_analysis(self, list_of_tokens):
"""Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. The
token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token."""
token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token.
"""
analysis_masks = ANALYSIS_MASK_TOKENS
token_type = COMPOSITION_ANALYSIS_DICT
return [
@@ -409,7 +460,8 @@ def __init__(

def token_analysis(self, list_of_tokens):
"""Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. The
token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token."""
token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token.
"""
analysis_masks = ANALYSIS_MASK_TOKENS
token_type = CIF_ANALYSIS_DICT
return [
@@ -441,7 +493,8 @@ def __init__(

def token_analysis(self, list_of_tokens):
"""Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. The
token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token."""
token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token.
"""
analysis_masks = ANALYSIS_MASK_TOKENS
token_type = CRYSTAL_LLM_ANALYSIS_DICT
return [
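Note that these `ensure` calls sit at module level, so the first `import xtal2txt.tokenizer` fetches every vocab file from Zenodo and later imports reuse the local cache. A minimal sketch of what the new import implies (assuming network access on first run and the default pystow cache root):

```python
# First import triggers the Zenodo downloads via pystow; subsequent
# imports read from the cache under ~/.data/xtal2txt (or $PYSTOW_HOME).
from xtal2txt.tokenizer import SLICE_VOCAB  # path string to the cached vocab file

with open(SLICE_VOCAB) as f:
    print(f.readline().strip())  # first entry of the SLICE vocabulary
```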
3 changes: 3 additions & 0 deletions src/xtal2txt/utils.py
@@ -0,0 +1,3 @@
import pystow

xtal2txt_storage = pystow.module("xtal2txt")
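These three lines are the whole storage layer: `pystow.module("xtal2txt")` creates a namespaced cache directory, and each `ensure(url=...)` call downloads a file once and then returns the local `Path` on every subsequent call. A minimal sketch of the pattern (the URL is one of the Zenodo links from the diff above; pystow's default root `~/.data` can be redirected with the `PYSTOW_HOME` environment variable):

```python
import pystow

storage = pystow.module("xtal2txt")  # cache namespace, e.g. ~/.data/xtal2txt
vocab_path = storage.ensure(
    url="https://zenodo.org/records/11484062/files/slice_vocab.txt?download=1"
)  # downloads on the first call, returns a pathlib.Path to the cached copy
print(vocab_path)
```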
1 change: 0 additions & 1 deletion src/xtal2txt/vocabs/1.json

This file was deleted.

185 changes: 0 additions & 185 deletions src/xtal2txt/vocabs/cif_vocab.json

This file was deleted.
