From b2bc253d00799b29fe8e411c18b1d71e0cef9394 Mon Sep 17 00:00:00 2001 From: n0w0f Date: Sun, 11 Aug 2024 00:26:19 +0200 Subject: [PATCH 01/20] feat: scripts prep classification --- revision-scripts/5fold_split.py | 38 +++++++++ revision-scripts/matbench_is_metal.py | 83 ++++++++++++++++++ revision-scripts/mp_classification.py | 58 +++++++++++++ revision-scripts/prep_json.py | 94 +++++++++++++++++++++ revision-scripts/prep_rep.py | 117 ++++++++++++++++++++++++++ revision-scripts/text_rep.py | 116 +++++++++++++++++++++++++ 6 files changed, 506 insertions(+) create mode 100644 revision-scripts/5fold_split.py create mode 100644 revision-scripts/matbench_is_metal.py create mode 100644 revision-scripts/mp_classification.py create mode 100644 revision-scripts/prep_json.py create mode 100644 revision-scripts/prep_rep.py create mode 100644 revision-scripts/text_rep.py diff --git a/revision-scripts/5fold_split.py b/revision-scripts/5fold_split.py new file mode 100644 index 0000000..2342aaa --- /dev/null +++ b/revision-scripts/5fold_split.py @@ -0,0 +1,38 @@ +import json +import os +import random +from sklearn.model_selection import KFold +import fire + +def split_dataset(input_json, output_dir, n_splits=5, random_state=42): + # Load the data + with open(input_json, 'r') as f: + data = json.load(f) + + # Create KFold object + kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state) + + # Ensure output directory exists + os.makedirs(output_dir, exist_ok=True) + + # Perform the split + for fold, (train_index, test_index) in enumerate(kf.split(data), 1): + train_data = [data[i] for i in train_index] + test_data = [data[i] for i in test_index] + + # Save train data + train_file = os.path.join(output_dir, f'train_mp_classification_fold_{fold}.json') + with open(train_file, 'w') as f: + json.dump(train_data, f, indent=2) + + # Save test data + test_file = os.path.join(output_dir, f'test_mp_classification_fold_{fold}.json') + with open(test_file, 'w') as f: + json.dump(test_data, f, indent=2) + + print(f"Fold {fold} created: {train_file} and {test_file}") + + print("Dataset splitting completed.") + +if __name__ == "__main__": + fire.Fire(split_dataset) \ No newline at end of file diff --git a/revision-scripts/matbench_is_metal.py b/revision-scripts/matbench_is_metal.py new file mode 100644 index 0000000..0949444 --- /dev/null +++ b/revision-scripts/matbench_is_metal.py @@ -0,0 +1,83 @@ +import json +import os + +import hydra +from matbench.bench import MatbenchBenchmark +from omegaconf import DictConfig + +# Check if the specified benchmark exists +available_benchmarks = [ + # "matbench_dielectric", + # "matbench_expt_gap", + # "matbench_expt_is_metal", + # "matbench_glass", + # "matbench_mp_e_form", + # "matbench_mp_gap", + "matbench_mp_is_metal", + # "matbench_phonons", + # "matbench_steels", +] + + +def convert_structure_to_serializable(pymatgen_structure): + # Assuming Structure has 'data' and 'metadata' attributes + cif_content = pymatgen_structure.to(fmt="cif") + return cif_content + + +@hydra.main(version_base=None, config_path="../conf", config_name="config") +def main(cfg: DictConfig) -> None: + mb = MatbenchBenchmark(autoload=False) + benchmarks = cfg.matbench.benchmarks.dataset + path = cfg.matbench.path.save_path + print(path) + if not os.path.exists(path): + os.mkdir(path) + else: + print(f"Directory '{path}' already exists.") + for benchmark_name in benchmarks: + if benchmark_name not in available_benchmarks: + raise ValueError( + f"Invalid benchmark name. 
Available benchmarks: {', '.join(available_benchmarks)}" + ) + + for benchmark_name in benchmarks: + benchmark = getattr(mb, benchmark_name) + benchmark.load() + + for fold in benchmark.folds: + # Get train inputs and outputs + train_inputs, train_outputs = benchmark.get_train_and_val_data(fold) + test_inputs = benchmark.get_test_data(fold) + + # Create the train data + train_data = [ + { + "structure": convert_structure_to_serializable(train_inputs[index]), + "labels": train_outputs[index], + } + for index in range(len(train_inputs)) + ] + + # Save the train data as a JSON file + train_dataset_name = f"train_{benchmark_name}_{fold}.json" + with open(f"{path}/{train_dataset_name}", "w") as train_file: + json.dump(train_data, train_file) + + print(f"Train data saved to {path}/{train_dataset_name}") + + test_data = [ + convert_structure_to_serializable(test_inputs[index]) + for index in range(len(test_inputs)) + ] + + # Save the test data as a JSON file + test_dataset_name = f"test_{benchmark_name}_{fold}.json" + with open(f"{path}/{test_dataset_name}", "w") as test_file: + json.dump(test_data, test_file) + + print(f"Test data saved to {path}/{test_dataset_name}") + + +if __name__ == "__main__": + main() diff --git a/revision-scripts/mp_classification.py b/revision-scripts/mp_classification.py new file mode 100644 index 0000000..03d5cca --- /dev/null +++ b/revision-scripts/mp_classification.py @@ -0,0 +1,58 @@ +import lmdb +import pickle +import json +import os +from pymatgen.core import Structure +import fire + +class Dataset: + def __init__(self, lmdb_path, max_readers=1): + self.env = lmdb.open(lmdb_path, + subdir=False, + readonly=True, + lock=False, + readahead=False, + meminit=False, + max_readers=max_readers) + self.txn = self.env.begin() + + def __len__(self): + return self.txn.stat()['entries'] + + def get(self, index): + id = f"{index}".encode("ascii") + datapoint = pickle.loads(self.txn.get(id)) + return datapoint + +def create_json_from_lmdb(lmdb_path, output_dir): + dataset = Dataset(lmdb_path) + output_data = [] + + for i in range(len(dataset)): + d = dataset.get(i) + + # Convert structure to CIF + structure = d['structure'] + cif = structure.to(fmt="cif") + + entry = { + "structure": cif, + "is_stable": d['is_stable'], + "is_metal": d['is_metal'], + "is_magnetic": d['is_magnetic'] + } + + output_data.append(entry) + + # Ensure output directory exists + os.makedirs(output_dir, exist_ok=True) + + # Write to JSON file + output_file = os.path.join(output_dir, "mp_test.json") + with open(output_file, 'w') as f: + json.dump(output_data, f, indent=2) + + print(f"JSON file created: {output_file}") + +if __name__ == "__main__": + fire.Fire(create_json_from_lmdb) \ No newline at end of file diff --git a/revision-scripts/prep_json.py b/revision-scripts/prep_json.py new file mode 100644 index 0000000..015e5fc --- /dev/null +++ b/revision-scripts/prep_json.py @@ -0,0 +1,94 @@ +import json +import os +from matbench.bench import MatbenchBenchmark +import numpy as np + +# Define the available benchmarks +available_benchmarks = [ + "matbench_mp_is_metal", +] + +def convert_structure_to_serializable(pymatgen_structure): + """ + Convert a pymatgen Structure object to a serializable format (CIF). + """ + return pymatgen_structure.to(fmt="cif") + +def convert_label_to_serializable(label): + """ + Convert labels to 0 or 1, specifically converting numpy booleans to Python integers. 
+ """ + return int(label) + +def download_benchmark_data(benchmark_name, save_path): + """ + Download and save the Matbench benchmark data as JSON files. + + Args: + benchmark_name (str): The name of the benchmark to download. + save_path (str): The directory path where the JSON files will be saved. + """ + if benchmark_name not in available_benchmarks: + raise ValueError( + f"Invalid benchmark name. Available benchmarks: {', '.join(available_benchmarks)}" + ) + + # Load the MatbenchBenchmark + mb = MatbenchBenchmark(autoload=False) + + # Create the save directory if it does not exist + if not os.path.exists(save_path): + os.mkdir(save_path) + else: + print(f"Directory '{save_path}' already exists.") + + # Load the benchmark data + benchmark = getattr(mb, benchmark_name) + benchmark.load() + + # Process each fold in the benchmark + for fold in benchmark.folds: + # Get train inputs and outputs + train_inputs, train_outputs = benchmark.get_train_and_val_data(fold) + test_inputs = benchmark.get_test_data(fold) + + # Create the train data + train_data = [ + { + "mbid": index, # Add material ID (index) + "structure": convert_structure_to_serializable(train_inputs[index]), + "labels": convert_label_to_serializable(train_outputs[index]), # Convert bool to 0 or 1 + } + for index in train_inputs.index + ] + + # Save the train data as a JSON file + train_dataset_name = f"train_{benchmark_name}_{fold}.json" + with open(os.path.join(save_path, train_dataset_name), "w") as train_file: + json.dump(train_data, train_file) + + print(f"Train data saved to {save_path}/{train_dataset_name}") + + # Create the test data + test_data = [ + { + "mbid": index, # Add material ID (index) + "structure": convert_structure_to_serializable(test_inputs[index]) + } + for index in test_inputs.index + ] + + # Save the test data as a JSON file + test_dataset_name = f"test_{benchmark_name}_{fold}.json" + with open(os.path.join(save_path, test_dataset_name), "w") as test_file: + json.dump(test_data, test_file) + + print(f"Test data saved to {save_path}/{test_dataset_name}") + +if __name__ == "__main__": + # Define the benchmark name and the directory to save the data + benchmark_name = "matbench_mp_is_metal" + save_path = "./benchmark_data_is_metal" + + # Download and save the benchmark data + download_benchmark_data(benchmark_name, save_path) diff --git a/revision-scripts/prep_rep.py b/revision-scripts/prep_rep.py new file mode 100644 index 0000000..5c277d7 --- /dev/null +++ b/revision-scripts/prep_rep.py @@ -0,0 +1,117 @@ +import json +import fire + +from concurrent.futures import ProcessPoolExecutor, TimeoutError +import multiprocessing +from functools import partial +from xtal2txt.core import TextRep + +from typing import List, Dict + + +def read_json(json_file: str) -> List[Dict]: + """Read JSON data from a file. + + Args: + json_file (str): The path to the JSON file. + + Returns: + List[Dict]: A list of dictionaries containing the JSON data. 
+ """ + with open(json_file, 'r') as file: + data = json.load(file) + return data + + + + +def process_entry_train_matbench(entry: dict, timeout: int) -> dict: + + try: + text_reps = TextRep.from_input(entry["structure"]).get_requested_text_reps(["local_env","slice","composition","cif_symmetrized","cif_p1","crystal_llm_rep", "atoms","atoms_params", "zmatrix", "wyckoff_rep", "mbid" ]) # Use get_all_text_reps to get various text representations # Add chemical formula to the dictionary + text_reps['is_stable'] = int(entry["is_stable"]) + text_reps["is_magnetic"] = int(entry["is_magnetic"]) + text_reps["is_metal"] = int(entry["is_metal"]) + return text_reps # Return the entire dictionary + except TimeoutError: + print("Timeout error processing a row") + return None + except Exception as e: + print(f"Error processing a row: {e}") + return None + + +def process_entry_test_matbench(entry: List, timeout: int) -> dict: + # Ensure the give_slice function and necessary data are picklable + try: + text_reps = TextRep.from_input(entry["structure"]).get_requested_text_reps(["local_env","slice","composition","cif_symmetrized","cif_p1","crystal_llm_rep", "atoms","atoms_params", "zmatrix", "wyckoff_rep", "mbid" ]) # Use get_all_text_reps to get various text representations # Add chemical formula to the dictionary + # Use get_all_text_reps to get various text representations # Add chemical formula to the dictionary + text_reps["mbid"] = entry["mbid"] + return text_reps # Return the entire dictionary + except TimeoutError: + print("Timeout error processing a row") + return None + except Exception as e: + print(f"Error processing a row: {e}") + return None + + +def process_batch(num_workers, batch, timeout, process_entry_func): + + process_entry_with_timeout = partial(process_entry_func, timeout=timeout) + + with ProcessPoolExecutor(max_workers=num_workers) as executor: + results = list(executor.map(process_entry_with_timeout, batch)) + + return [result for result in results if result is not None] + + + +def process_json_to_json(json_file: str, output_json_file: str, log_file_path: str,process_entry: str = 'test', num_workers: int = 48, timeout: int = 600, save_interval: int = 100, last_processed_entry: int = 0): + + num_cpus = multiprocessing.cpu_count() + print(num_workers) + + process_entry_funcs = { + 'test': process_entry_test_matbench, + 'train': process_entry_train_matbench + } + # Get the selected function + process_entry_func = process_entry_funcs[process_entry] + + print(f"json file: {json_file}") + print(f"number of cpus: {num_cpus}") + print(f"number of workers: {num_workers}") + print(f"last processed entry: {last_processed_entry}") + print(f"save_interval: {save_interval}") + + data = read_json(json_file) + batch_size = num_workers * 4 + + if last_processed_entry > 0: + data = data[last_processed_entry:] + + batch_iterator = (data[i:i + batch_size] for i in range(0, len(data), batch_size)) + + for i, batch_data in enumerate(batch_iterator, start=1): + batch_results = process_batch(num_workers,batch_data, timeout, process_entry_func) + + # Append batch_results to the output JSON file + with open(output_json_file, 'a') as f: + for result in batch_results: + json.dump(result, f) + f.write('\n') + + last_processed_entry += len(batch_data) + if i % save_interval == 0: + with open(log_file_path, "w") as log_file: + log_file.write(f"Last processed entry index: {last_processed_entry}\n") + log_file.write(f"Last processed batch number: {i}\n") + + print(f"Finished !!! 
logging at {log_file_path}") + + +if __name__ == "__main__": + fire.Fire(process_json_to_json) + + diff --git a/revision-scripts/text_rep.py b/revision-scripts/text_rep.py new file mode 100644 index 0000000..0c025e1 --- /dev/null +++ b/revision-scripts/text_rep.py @@ -0,0 +1,116 @@ +import json +import fire + +from concurrent.futures import ProcessPoolExecutor, TimeoutError +import multiprocessing +from functools import partial +from xtal2txt.core import TextRep + +from typing import List, Dict + + +def read_json(json_file: str) -> List[Dict]: + """Read JSON data from a file. + + Args: + json_file (str): The path to the JSON file. + + Returns: + List[Dict]: A list of dictionaries containing the JSON data. + """ + with open(json_file, 'r') as file: + data = json.load(file) + return data + + + + +def process_entry_train_matbench(entry: dict, timeout: int) -> dict: + + try: + text_reps = TextRep.from_input(entry["structure"]).get_requested_text_reps(["local_env","slice","composition","cif_symmetrized","cif_p1","crystal_llm_rep", "atoms","atoms_params", "zmatrix", "wyckoff_rep", "mbid" ]) # Use get_all_text_reps to get various text representations # Add chemical formula to the dictionary + text_reps['labels'] = entry["labels"] + text_reps["mbid"] = entry["mbid"] + return text_reps # Return the entire dictionary + except TimeoutError: + print("Timeout error processing a row") + return None + except Exception as e: + print(f"Error processing a row: {e}") + return None + + +def process_entry_test_matbench(entry: List, timeout: int) -> dict: + # Ensure the give_slice function and necessary data are picklable + try: + text_reps = TextRep.from_input(entry["structure"]).get_requested_text_reps(["local_env","slice","composition","cif_symmetrized","cif_p1","crystal_llm_rep", "atoms","atoms_params", "zmatrix", "wyckoff_rep", "mbid" ]) # Use get_all_text_reps to get various text representations # Add chemical formula to the dictionary + # Use get_all_text_reps to get various text representations # Add chemical formula to the dictionary + text_reps["mbid"] = entry["mbid"] + return text_reps # Return the entire dictionary + except TimeoutError: + print("Timeout error processing a row") + return None + except Exception as e: + print(f"Error processing a row: {e}") + return None + + +def process_batch(num_workers, batch, timeout, process_entry_func): + + process_entry_with_timeout = partial(process_entry_func, timeout=timeout) + + with ProcessPoolExecutor(max_workers=num_workers) as executor: + results = list(executor.map(process_entry_with_timeout, batch)) + + return [result for result in results if result is not None] + + + +def process_json_to_json(json_file: str, output_json_file: str, log_file_path: str,process_entry: str = 'test', num_workers: int = 48, timeout: int = 600, save_interval: int = 100, last_processed_entry: int = 0): + + num_cpus = multiprocessing.cpu_count() + print(num_workers) + + process_entry_funcs = { + 'test': process_entry_test_matbench, + 'train': process_entry_train_matbench + } + # Get the selected function + process_entry_func = process_entry_funcs[process_entry] + + print(f"json file: {json_file}") + print(f"number of cpus: {num_cpus}") + print(f"number of workers: {num_workers}") + print(f"last processed entry: {last_processed_entry}") + print(f"save_interval: {save_interval}") + + data = read_json(json_file) + batch_size = num_workers * 4 + + if last_processed_entry > 0: + data = data[last_processed_entry:] + + batch_iterator = (data[i:i + batch_size] for i in range(0, 
len(data), batch_size)) + + for i, batch_data in enumerate(batch_iterator, start=1): + batch_results = process_batch(num_workers,batch_data, timeout, process_entry_func) + + # Append batch_results to the output JSON file + with open(output_json_file, 'a') as f: + for result in batch_results: + json.dump(result, f) + f.write('\n') + + last_processed_entry += len(batch_data) + if i % save_interval == 0: + with open(log_file_path, "w") as log_file: + log_file.write(f"Last processed entry index: {last_processed_entry}\n") + log_file.write(f"Last processed batch number: {i}\n") + + print(f"Finished !!! logging at {log_file_path}") + + +if __name__ == "__main__": + fire.Fire(process_json_to_json) + + From 59fc2872b3706c4dcaec7c2e64809e6c51478e3f Mon Sep 17 00:00:00 2001 From: n0w0f Date: Sun, 11 Aug 2024 22:42:41 +0200 Subject: [PATCH 02/20] chore: configs for bg and form --- conf/benchmark.yaml | 53 +++++++++------- conf/bg/atoms.yaml | 19 ++++++ conf/bg/atoms_params.yaml | 17 +++++ conf/bg/cifp1.yaml | 17 +++++ conf/bg/cifpsym.yaml | 17 +++++ conf/bg/composition.yaml | 17 +++++ conf/bg/crystal_llm.yaml | 16 +++++ conf/bg/local_env.yaml | 17 +++++ conf/bg/slices.yaml | 17 +++++ conf/bg/zmatrix.yaml | 17 +++++ conf/bg2m/atoms.yaml | 13 ++++ conf/bg2m/atoms_params.yaml | 13 ++++ conf/bg2m/cifp1.yaml | 13 ++++ conf/bg2m/cifsymmetrized.yaml | 13 ++++ conf/bg2m/composition.yaml | 13 ++++ conf/bg2m/crystal_llm.yaml | 13 ++++ conf/bg2m/local_env.yaml | 13 ++++ conf/bg2m/slice.yaml | 13 ++++ conf/bg2m/zmatrix.yaml | 13 ++++ conf/form/atoms.yaml | 19 ++++++ conf/form/atoms_params.yaml | 17 +++++ conf/form/cifp1.yaml | 17 +++++ conf/form/cifpsym.yaml | 17 +++++ conf/form/composition.yaml | 17 +++++ conf/form/crystal_llm.yaml | 16 +++++ conf/form/local_env.yaml | 17 +++++ conf/form/slices.yaml | 17 +++++ conf/form/zmatrix.yaml | 17 +++++ conf/form_energy.yaml | 19 ++++++ conf/model/benchmark_example.yaml | 2 +- conf/model/formation_energy.yaml | 100 ++++++++++++++++++++++++++++++ src/mattext/models/score.py | 2 + 32 files changed, 578 insertions(+), 23 deletions(-) create mode 100644 conf/bg/atoms.yaml create mode 100644 conf/bg/atoms_params.yaml create mode 100644 conf/bg/cifp1.yaml create mode 100644 conf/bg/cifpsym.yaml create mode 100644 conf/bg/composition.yaml create mode 100644 conf/bg/crystal_llm.yaml create mode 100644 conf/bg/local_env.yaml create mode 100644 conf/bg/slices.yaml create mode 100644 conf/bg/zmatrix.yaml create mode 100644 conf/bg2m/atoms.yaml create mode 100644 conf/bg2m/atoms_params.yaml create mode 100644 conf/bg2m/cifp1.yaml create mode 100644 conf/bg2m/cifsymmetrized.yaml create mode 100644 conf/bg2m/composition.yaml create mode 100644 conf/bg2m/crystal_llm.yaml create mode 100644 conf/bg2m/local_env.yaml create mode 100644 conf/bg2m/slice.yaml create mode 100644 conf/bg2m/zmatrix.yaml create mode 100644 conf/form/atoms.yaml create mode 100644 conf/form/atoms_params.yaml create mode 100644 conf/form/cifp1.yaml create mode 100644 conf/form/cifpsym.yaml create mode 100644 conf/form/composition.yaml create mode 100644 conf/form/crystal_llm.yaml create mode 100644 conf/form/local_env.yaml create mode 100644 conf/form/slices.yaml create mode 100644 conf/form/zmatrix.yaml create mode 100644 conf/form_energy.yaml create mode 100644 conf/model/formation_energy.yaml diff --git a/conf/benchmark.yaml b/conf/benchmark.yaml index 56f4217..0c8a299 100644 --- a/conf/benchmark.yaml +++ b/conf/benchmark.yaml @@ -1,24 +1,33 @@ - hydra: - job: - name: benchmark - run: - dir: 
${hydra:runtime.cwd}/outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} - sweep: - dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} - subdir: ${hydra.job.override_dirname} - - - - defaults: - - model: none - - - - runs: - - - - name: benchmark_run - tasks: [benchmark] - +hydra: + job: + name: benchmark + run: + dir: ${hydra:runtime.cwd}/outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} + sweep: + dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} + subdir: ${hydra.job.override_dirname} + + # launcher: + # _target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher + # submitit_folder: ${hydra.sweep.dir}/.submitit/%j + # timeout_min: 3600 + # mem_gb: 160 + # nodes: 1 + # #gpus_per_task: 1 + # gres: gpu:1 + # #gpus_per_node: 2 + # name: ${hydra.job.name} + # partition: 'gpu' + # additional_parameters: + # nodelist: 'gpu[008,013-017]' + # tasks_per_node: 1 + +defaults: +- model: none +# - override hydra/launcher: submitit_slurm + +runs: + - name: benchmark_run + tasks: [benchmark] \ No newline at end of file diff --git a/conf/bg/atoms.yaml b/conf/bg/atoms.yaml new file mode 100644 index 0000000..4c183f8 --- /dev/null +++ b/conf/bg/atoms.yaml @@ -0,0 +1,19 @@ +# @package _global_ +model: + representation: atom_sequences + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-atom-seq-2m + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 32 + training_arguments: + per_device_train_batch_size: 1024 + path: + pretrained_checkpoint: n0w0f/MatText-atom-seq-2m + + \ No newline at end of file diff --git a/conf/bg/atoms_params.yaml b/conf/bg/atoms_params.yaml new file mode 100644 index 0000000..728685e --- /dev/null +++ b/conf/bg/atoms_params.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: atom_sequences_plusplus + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-atom-seq-plusplus-2m + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 32 + training_arguments: + per_device_train_batch_size: 1024 + + \ No newline at end of file diff --git a/conf/bg/cifp1.yaml b/conf/bg/cifp1.yaml new file mode 100644 index 0000000..633f5de --- /dev/null +++ b/conf/bg/cifp1.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: cif_p1 + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-cifp1-2m + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 1024 + training_arguments: + per_device_train_batch_size: 64 + path: + pretrained_checkpoint: n0w0f/MatText-cifp1-2m \ No newline at end of file diff --git a/conf/bg/cifpsym.yaml b/conf/bg/cifpsym.yaml new file mode 100644 index 0000000..6175580 --- /dev/null +++ b/conf/bg/cifpsym.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: cif_symmetrized + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-cifsymmetrized-2m + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 1024 + training_arguments: + per_device_train_batch_size: 64 + path: + pretrained_checkpoint: n0w0f/MatText-cifsymmetrized-2m \ No newline at end of file diff --git a/conf/bg/composition.yaml b/conf/bg/composition.yaml new file mode 100644 index 0000000..7e52344 --- /dev/null +++ b/conf/bg/composition.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: composition + dataset: "bandgap" + 
dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-composition-2m + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 32 + training_arguments: + per_device_train_batch_size: 1024 + + \ No newline at end of file diff --git a/conf/bg/crystal_llm.yaml b/conf/bg/crystal_llm.yaml new file mode 100644 index 0000000..e0750e0 --- /dev/null +++ b/conf/bg/crystal_llm.yaml @@ -0,0 +1,16 @@ +# @package _global_ +model: + representation: crystal_text_llm + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: /home/so87pot/n0w0f/structllm_ckpt/alpaca_ckpt/checkpoint-393000 + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 512 + training_arguments: + per_device_train_batch_size: 64 + \ No newline at end of file diff --git a/conf/bg/local_env.yaml b/conf/bg/local_env.yaml new file mode 100644 index 0000000..b26e598 --- /dev/null +++ b/conf/bg/local_env.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: local_env + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: /home/so87pot/n0w0f/structllm_ckpt/santiago_ckpt_rt/checkpoint-95000 + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 512 + training_arguments: + per_device_train_batch_size: 64 + path: + pretrained_checkpoint: /home/so87pot/n0w0f/structllm_ckpt/santiago_ckpt_rt/checkpoint-95000 \ No newline at end of file diff --git a/conf/bg/slices.yaml b/conf/bg/slices.yaml new file mode 100644 index 0000000..1076dfe --- /dev/null +++ b/conf/bg/slices.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: slices + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-slices-2m + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 512 + training_arguments: + per_device_train_batch_size: 64 + path: + pretrained_checkpoint: n0w0f/MatText-slices-2m \ No newline at end of file diff --git a/conf/bg/zmatrix.yaml b/conf/bg/zmatrix.yaml new file mode 100644 index 0000000..f25472f --- /dev/null +++ b/conf/bg/zmatrix.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: zmatrix + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-zmatrix-2m + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 512 + training_arguments: + per_device_train_batch_size: 64 + path: + pretrained_checkpoint: n0w0f/MatText-zmatrix-2m \ No newline at end of file diff --git a/conf/bg2m/atoms.yaml b/conf/bg2m/atoms.yaml new file mode 100644 index 0000000..1d55ace --- /dev/null +++ b/conf/bg2m/atoms.yaml @@ -0,0 +1,13 @@ +# @package _global_ +model: + representation: atoms_params + logging: + wandb_project: 2m_intel_ft + + finetune: + model_name: 2m_intel_ft + context_length: 32 + training_arguments: + per_device_train_batch_size: 1024 + path: + pretrained_checkpoint: /work/so87pot/mattext/megaloop/checkpoints/checkpoints/atoms_params_pt_30k_atoms/checkpoint-1000 diff --git a/conf/bg2m/atoms_params.yaml b/conf/bg2m/atoms_params.yaml new file mode 100644 index 0000000..1d55ace --- /dev/null +++ b/conf/bg2m/atoms_params.yaml @@ -0,0 +1,13 @@ +# @package _global_ +model: + representation: atoms_params + logging: + wandb_project: 2m_intel_ft + + finetune: + model_name: 2m_intel_ft + context_length: 32 + training_arguments: + 
per_device_train_batch_size: 1024 + path: + pretrained_checkpoint: /work/so87pot/mattext/megaloop/checkpoints/checkpoints/atoms_params_pt_30k_atoms/checkpoint-1000 diff --git a/conf/bg2m/cifp1.yaml b/conf/bg2m/cifp1.yaml new file mode 100644 index 0000000..ad74f90 --- /dev/null +++ b/conf/bg2m/cifp1.yaml @@ -0,0 +1,13 @@ +# @package _global_ +model: + representation: cif_p1 + logging: + wandb_project: 2m_intel_ft + + finetune: + model_name: 2m_intel_ft + context_length: 1024 + training_arguments: + per_device_train_batch_size: 32 + path: + pretrained_checkpoint: /work/so87pot/mattext/megaloop2/checkpoints/checkpoints/cif_p1_pt_30k_rt_2/checkpoint-46000 diff --git a/conf/bg2m/cifsymmetrized.yaml b/conf/bg2m/cifsymmetrized.yaml new file mode 100644 index 0000000..e7cc55b --- /dev/null +++ b/conf/bg2m/cifsymmetrized.yaml @@ -0,0 +1,13 @@ +# @package _global_ +model: + representation: cif_symmetrized + logging: + wandb_project: 2m_intel_ft + + finetune: + model_name: 2m_intel_ft + context_length: 1024 + training_arguments: + per_device_train_batch_size: 32 + path: + pretrained_checkpoint: /work/so87pot/mattext/megaloop2/checkpoints/checkpoints/cif_symmetrized_pt_30k_rt/checkpoint-45000 diff --git a/conf/bg2m/composition.yaml b/conf/bg2m/composition.yaml new file mode 100644 index 0000000..3783298 --- /dev/null +++ b/conf/bg2m/composition.yaml @@ -0,0 +1,13 @@ +# @package _global_ +model: + representation: composition + logging: + wandb_project: 2m_intel_ft + + finetune: + model_name: 2m_intel_ft + context_length: 32 + training_arguments: + per_device_train_batch_size: 1024 + path: + pretrained_checkpoint: /work/so87pot/mattext/megaloop2/checkpoints/checkpoints/composition_pt_30k_rt/checkpoint-1000 diff --git a/conf/bg2m/crystal_llm.yaml b/conf/bg2m/crystal_llm.yaml new file mode 100644 index 0000000..9f97208 --- /dev/null +++ b/conf/bg2m/crystal_llm.yaml @@ -0,0 +1,13 @@ +# @package _global_ +model: + representation: crystal_llm_rep + logging: + wandb_project: 2m_intel_ft + + finetune: + model_name: 2m_intel_ft + context_length: 512 + training_arguments: + per_device_train_batch_size: 64 + path: + pretrained_checkpoint: /work/so87pot/mattext/megaloop2/checkpoints/checkpoints/crystal_llm_rep_pt_30k_rt/checkpoint-11000 diff --git a/conf/bg2m/local_env.yaml b/conf/bg2m/local_env.yaml new file mode 100644 index 0000000..cbb1363 --- /dev/null +++ b/conf/bg2m/local_env.yaml @@ -0,0 +1,13 @@ +# @package _global_ +model: + representation: zmatrix + logging: + wandb_project: 2m_intel_ft + + finetune: + model_name: 2m_intel_ft + context_length: 512 + training_arguments: + per_device_train_batch_size: 64 + path: + pretrained_checkpoint: /work/so87pot/mattext/megaloop/checkpoints/checkpoints/atoms_params_pt_30k_atoms/checkpoint-1000 diff --git a/conf/bg2m/slice.yaml b/conf/bg2m/slice.yaml new file mode 100644 index 0000000..1fe01e1 --- /dev/null +++ b/conf/bg2m/slice.yaml @@ -0,0 +1,13 @@ +# @package _global_ +model: + representation: slice + logging: + wandb_project: 2m_intel_ft + + finetune: + model_name: 2m_intel_ft + context_length: 512 + training_arguments: + per_device_train_batch_size: 64 + path: + pretrained_checkpoint: /work/so87pot/mattext/megaloop2/checkpoints/checkpoints/slice_pt_30k_rt/checkpoint-23000 diff --git a/conf/bg2m/zmatrix.yaml b/conf/bg2m/zmatrix.yaml new file mode 100644 index 0000000..cbb1363 --- /dev/null +++ b/conf/bg2m/zmatrix.yaml @@ -0,0 +1,13 @@ +# @package _global_ +model: + representation: zmatrix + logging: + wandb_project: 2m_intel_ft + + finetune: + model_name: 
2m_intel_ft + context_length: 512 + training_arguments: + per_device_train_batch_size: 64 + path: + pretrained_checkpoint: /work/so87pot/mattext/megaloop/checkpoints/checkpoints/atoms_params_pt_30k_atoms/checkpoint-1000 diff --git a/conf/form/atoms.yaml b/conf/form/atoms.yaml new file mode 100644 index 0000000..6923edd --- /dev/null +++ b/conf/form/atoms.yaml @@ -0,0 +1,19 @@ +# @package _global_ +model: + representation: atom_sequences + dataset: "form_energy" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-atom-seq-2m + logging: + wandb_project: revision-form + + finetune: + model_name: revision-form + context_length: 32 + training_arguments: + per_device_train_batch_size: 1024 + path: + pretrained_checkpoint: n0w0f/MatText-atom-seq-2m + + \ No newline at end of file diff --git a/conf/form/atoms_params.yaml b/conf/form/atoms_params.yaml new file mode 100644 index 0000000..8d3ca79 --- /dev/null +++ b/conf/form/atoms_params.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: atom_sequences_plusplus + dataset: "form_energy" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-atom-seq-plusplus-2m + logging: + wandb_project: revision-form + + finetune: + model_name: revision-form + context_length: 32 + training_arguments: + per_device_train_batch_size: 1024 + + \ No newline at end of file diff --git a/conf/form/cifp1.yaml b/conf/form/cifp1.yaml new file mode 100644 index 0000000..221da0b --- /dev/null +++ b/conf/form/cifp1.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: cif_p1 + dataset: "form_energy" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-cifp1-2m + logging: + wandb_project: revision-form + + finetune: + model_name: revision-form + context_length: 1024 + training_arguments: + per_device_train_batch_size: 64 + path: + pretrained_checkpoint: n0w0f/MatText-cifp1-2m \ No newline at end of file diff --git a/conf/form/cifpsym.yaml b/conf/form/cifpsym.yaml new file mode 100644 index 0000000..0dccf71 --- /dev/null +++ b/conf/form/cifpsym.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: cif_symmetrized + dataset: "form_energy" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-cifsymmetrized-2m + logging: + wandb_project: revision-form + + finetune: + model_name: revision-form + context_length: 1024 + training_arguments: + per_device_train_batch_size: 64 + path: + pretrained_checkpoint: n0w0f/MatText-cifsymmetrized-2m \ No newline at end of file diff --git a/conf/form/composition.yaml b/conf/form/composition.yaml new file mode 100644 index 0000000..61e2211 --- /dev/null +++ b/conf/form/composition.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: composition + dataset: "form_energy" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText−composition−2m + logging: + wandb_project: revision-form + + finetune: + model_name: revision-form + context_length: 32 + training_arguments: + per_device_train_batch_size: 1024 + + \ No newline at end of file diff --git a/conf/form/crystal_llm.yaml b/conf/form/crystal_llm.yaml new file mode 100644 index 0000000..4e831b0 --- /dev/null +++ b/conf/form/crystal_llm.yaml @@ -0,0 +1,16 @@ +# @package _global_ +model: + representation: crystal_text_llm + dataset: "form_energy" + dataset_type: matbench + special_num_token: False + checkpoint: /home/so87pot/n0w0f/structllm_ckpt/alpaca_ckpt/checkpoint-393000 + logging: + wandb_project: revision-form 
+ + finetune: + model_name: revision-form + context_length: 512 + training_arguments: + per_device_train_batch_size: 64 + \ No newline at end of file diff --git a/conf/form/local_env.yaml b/conf/form/local_env.yaml new file mode 100644 index 0000000..c59cc86 --- /dev/null +++ b/conf/form/local_env.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: local_env + dataset: "form_energy" + dataset_type: matbench + special_num_token: False + checkpoint: /home/so87pot/n0w0f/structllm_ckpt/santiago_ckpt_rt/checkpoint-95000 + logging: + wandb_project: revision-form + + finetune: + model_name: revision-form + context_length: 512 + training_arguments: + per_device_train_batch_size: 64 + path: + pretrained_checkpoint: /home/so87pot/n0w0f/structllm_ckpt/santiago_ckpt_rt/checkpoint-95000 \ No newline at end of file diff --git a/conf/form/slices.yaml b/conf/form/slices.yaml new file mode 100644 index 0000000..1dc7e3c --- /dev/null +++ b/conf/form/slices.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: slices + dataset: "form_energy" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-slices-2m + logging: + wandb_project: revision-form + + finetune: + model_name: revision-form + context_length: 512 + training_arguments: + per_device_train_batch_size: 64 + path: + pretrained_checkpoint: n0w0f/MatText-slices-2m \ No newline at end of file diff --git a/conf/form/zmatrix.yaml b/conf/form/zmatrix.yaml new file mode 100644 index 0000000..02a38bc --- /dev/null +++ b/conf/form/zmatrix.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: zmatrix + dataset: "form_energy" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-zmatrix-2m + logging: + wandb_project: revision-form + + finetune: + model_name: revision-form + context_length: 512 + training_arguments: + per_device_train_batch_size: 64 + path: + pretrained_checkpoint: n0w0f/MatText-zmatrix-2m \ No newline at end of file diff --git a/conf/form_energy.yaml b/conf/form_energy.yaml new file mode 100644 index 0000000..00ff258 --- /dev/null +++ b/conf/form_energy.yaml @@ -0,0 +1,19 @@ + + +hydra: + job: + name: formation_energy + run: + dir: ${hydra:runtime.cwd}/outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} + sweep: + dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} + subdir: ${hydra.job.override_dirname} + + +defaults: +- model: none + + +runs: + - name: benchmark_run + tasks: [benchmark] \ No newline at end of file diff --git a/conf/model/benchmark_example.yaml b/conf/model/benchmark_example.yaml index 68ef9c6..6ffd0f3 100644 --- a/conf/model/benchmark_example.yaml +++ b/conf/model/benchmark_example.yaml @@ -59,7 +59,7 @@ finetune: training_arguments: output_dir: "${model.finetune.path.output_dir}" overwrite_output_dir: True - num_train_epochs: 1 + num_train_epochs: 50 per_device_train_batch_size: 1024 save_strategy: "epoch" evaluation_strategy: "epoch" diff --git a/conf/model/formation_energy.yaml b/conf/model/formation_energy.yaml new file mode 100644 index 0000000..6ffd0f3 --- /dev/null +++ b/conf/model/formation_energy.yaml @@ -0,0 +1,100 @@ +representation: ??? +special_num_token: False +dataset: ??? +dataset_type: ??? +fold : 5 +data_repository: "n0w0f/MatText" +checkpoint: ??? 
+special_tokens: + { + "unk_token": "[UNK]", + "pad_token": "[PAD]", + "cls_token": "[CLS]", + "sep_token": "[SEP]", + "mask_token": "[MASK]", + "eos_token": "[EOS]", + "bos_token": "[BOS]", + } + +logging: + wandb_project: test-benchmark + wandb_log_model: "checkpoint" + +finetune: + model_name: test-benchmark + freeze_base_model: False + dataset_name: "${model.dataset}-train-${model.dataset_type}" + exp_name: + [ + "train_${model.representation}_${model.finetune.dataset_name}_0", + "train_${model.representation}_${model.finetune.dataset_name}_1", + "train_${model.representation}_${model.finetune.dataset_name}_2", + "train_${model.representation}_${model.finetune.dataset_name}_3", + "train_${model.representation}_${model.finetune.dataset_name}_4", + ] + + path: + pretrained_checkpoint: "${model.checkpoint}" + + finetune_data_rootpath: results # <--- Change this to the path of the finetune data + finetune_traindata: + [ + # "kvrh-train-filtered", + ] + + finetune_testdata: + root_path: "${hydra:runtime.cwd}/../../results/${now:%Y-%m-%d}/${now:%H-%M-%S}/${model.finetune.model_name}" # <--- Change this to the path where chkpoints and logs will be saved + output_dir: "${model.finetune.path.root_path}/checkpoints/${model.finetune.exp_name}" + logging_dir: "${model.finetune.path.root_path}/logs/${model.finetune.exp_name}" + finetuned_modelname: "${model.finetune.path.root_path}/checkpoints/finetuned_${model.finetune.exp_name}" + + context_length: 32 + dataprep_seed: 42 + callbacks: + early_stopping: True + custom_logger: True + early_stopping_patience: 10 + early_stopping_threshold: 0.001 + + training_arguments: + output_dir: "${model.finetune.path.output_dir}" + overwrite_output_dir: True + num_train_epochs: 50 + per_device_train_batch_size: 1024 + save_strategy: "epoch" + evaluation_strategy: "epoch" + logging_strategy: "epoch" + logging_first_step: True + save_steps: 3 # Number of epochs before saving + report_to: "wandb" + save_total_limit: 5 + learning_rate: 2e-4 + logging_steps: 1 + eval_steps: 1 + seed: 42 + load_best_model_at_end: True + +inference: + benchmark_dataset: "${model.dataset}-test-${model.dataset_type}" + context_length: "${model.finetune.context_length}" + exp_name: + [ + "test_${model.representation}_${model.finetune.dataset_name}_0", + "test_${model.representation}_${model.finetune.dataset_name}_1", + "test_${model.representation}_${model.finetune.dataset_name}_2", + "test_${model.representation}_${model.finetune.dataset_name}_3", + "test_${model.representation}_${model.finetune.dataset_name}_4", + ] + path: + pretrained_checkpoint: [] + test_data_rootpath: # <--- Change this to the path of the finetune data + test_data: + [ + # "kvrh-train-filtered", + ] + root_path: "/home/so87pot/n0w0f/mattext/src/mattext/models/predictions" # <--- Change this to the path where predictions will be saved + output_dir: "${model.inference.path.root_path}/checkpoints/${model.inference.exp_name}" + logging_dir: "${model.inference.path.root_path}/logs/${model.inference.exp_name}" + predictions: "${model.inference.path.root_path}/checkpoints/inference${model.inference.exp_name}" + + benchmark_save_file: "${model.finetune.path.root_path}" diff --git a/src/mattext/models/score.py b/src/mattext/models/score.py index baacdbe..732d1c2 100644 --- a/src/mattext/models/score.py +++ b/src/mattext/models/score.py @@ -15,6 +15,8 @@ "kvrh": "matbench_log_kvrh", "gvrh": "matbench_log_gvrh", "perovskites": "matbench_perovskites", + "bandgap" : "matbench_mp_gap", + "form_energy": "matbench_mp_e_form" } 
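# MatText dataset keys mapped to the corresponding Matbench task names (e.g. "bandgap" -> "matbench_mp_gap"), presumably looked up when scoring predictions.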
MATMINER_COLUMNS = { From 7ea9c0b369973b7a0399205d12c2efbf2bf3ece0 Mon Sep 17 00:00:00 2001 From: n0w0f Date: Mon, 12 Aug 2024 21:42:56 +0200 Subject: [PATCH 03/20] configs for llama-run --- conf/bandgap.yaml | 33 +++++++ conf/bg/cifp1.yaml | 2 +- conf/bg/crystal_llm.yaml | 2 +- conf/bg/local_env.yaml | 2 +- conf/bg/slices.yaml | 2 +- conf/bg/zmatrix.yaml | 2 +- conf/form/atoms_params.yaml | 2 +- conf/form/composition.yaml | 4 +- conf/form/crystal_llm.yaml | 4 +- conf/form/local_env.yaml | 6 +- conf/form/slices.yaml | 2 +- conf/llama_8b_bg/atoms.yaml | 8 ++ conf/llama_8b_bg/atoms_params.yaml | 11 +++ conf/llama_8b_bg/cifp1.yaml | 8 ++ conf/llama_8b_bg/cifpsym.yaml | 7 ++ conf/llama_8b_bg/composition.yaml | 8 ++ conf/llama_8b_bg/crystal_llm.yaml | 7 ++ conf/llama_8b_bg/local_env.yaml | 7 ++ conf/llama_8b_bg/slices.yaml | 7 ++ conf/llama_8b_bg/zmatrix.yaml | 7 ++ conf/llm_sft.yaml | 3 - conf/model/llama_8b.yaml | 133 +++++++++++++++++++++++++++++ 22 files changed, 250 insertions(+), 17 deletions(-) create mode 100644 conf/bandgap.yaml create mode 100644 conf/llama_8b_bg/atoms.yaml create mode 100644 conf/llama_8b_bg/atoms_params.yaml create mode 100644 conf/llama_8b_bg/cifp1.yaml create mode 100644 conf/llama_8b_bg/cifpsym.yaml create mode 100644 conf/llama_8b_bg/composition.yaml create mode 100644 conf/llama_8b_bg/crystal_llm.yaml create mode 100644 conf/llama_8b_bg/local_env.yaml create mode 100644 conf/llama_8b_bg/slices.yaml create mode 100644 conf/llama_8b_bg/zmatrix.yaml create mode 100644 conf/model/llama_8b.yaml diff --git a/conf/bandgap.yaml b/conf/bandgap.yaml new file mode 100644 index 0000000..130b8fd --- /dev/null +++ b/conf/bandgap.yaml @@ -0,0 +1,33 @@ + + +hydra: + job: + name: bandgap + run: + dir: ${hydra:runtime.cwd}/outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} + sweep: + dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} + subdir: ${hydra.job.override_dirname} + + # launcher: + # _target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher + # submitit_folder: ${hydra.sweep.dir}/.submitit/%j + # timeout_min: 3600 + # mem_gb: 160 + # nodes: 1 + # #gpus_per_task: 1 + # gres: gpu:1 + # #gpus_per_node: 2 + # name: ${hydra.job.name} + # partition: 'gpu' + # additional_parameters: + # nodelist: 'gpu[008,013-017]' + # tasks_per_node: 1 + +defaults: +- model: none +# - override hydra/launcher: submitit_slurm + +runs: + - name: benchmark_run + tasks: [benchmark] \ No newline at end of file diff --git a/conf/bg/cifp1.yaml b/conf/bg/cifp1.yaml index 633f5de..51bed8e 100644 --- a/conf/bg/cifp1.yaml +++ b/conf/bg/cifp1.yaml @@ -12,6 +12,6 @@ model: model_name: revision-bg context_length: 1024 training_arguments: - per_device_train_batch_size: 64 + per_device_train_batch_size: 128 path: pretrained_checkpoint: n0w0f/MatText-cifp1-2m \ No newline at end of file diff --git a/conf/bg/crystal_llm.yaml b/conf/bg/crystal_llm.yaml index e0750e0..ce787ac 100644 --- a/conf/bg/crystal_llm.yaml +++ b/conf/bg/crystal_llm.yaml @@ -12,5 +12,5 @@ model: model_name: revision-bg context_length: 512 training_arguments: - per_device_train_batch_size: 64 + per_device_train_batch_size: 256 \ No newline at end of file diff --git a/conf/bg/local_env.yaml b/conf/bg/local_env.yaml index b26e598..15a3667 100644 --- a/conf/bg/local_env.yaml +++ b/conf/bg/local_env.yaml @@ -12,6 +12,6 @@ model: model_name: revision-bg context_length: 512 training_arguments: - per_device_train_batch_size: 64 + per_device_train_batch_size: 256 path: pretrained_checkpoint: 
/home/so87pot/n0w0f/structllm_ckpt/santiago_ckpt_rt/checkpoint-95000 \ No newline at end of file diff --git a/conf/bg/slices.yaml b/conf/bg/slices.yaml index 1076dfe..4a447e7 100644 --- a/conf/bg/slices.yaml +++ b/conf/bg/slices.yaml @@ -12,6 +12,6 @@ model: model_name: revision-bg context_length: 512 training_arguments: - per_device_train_batch_size: 64 + per_device_train_batch_size: 256 path: pretrained_checkpoint: n0w0f/MatText-slices-2m \ No newline at end of file diff --git a/conf/bg/zmatrix.yaml b/conf/bg/zmatrix.yaml index f25472f..5c1e96f 100644 --- a/conf/bg/zmatrix.yaml +++ b/conf/bg/zmatrix.yaml @@ -12,6 +12,6 @@ model: model_name: revision-bg context_length: 512 training_arguments: - per_device_train_batch_size: 64 + per_device_train_batch_size: 256 path: pretrained_checkpoint: n0w0f/MatText-zmatrix-2m \ No newline at end of file diff --git a/conf/form/atoms_params.yaml b/conf/form/atoms_params.yaml index 8d3ca79..42d2740 100644 --- a/conf/form/atoms_params.yaml +++ b/conf/form/atoms_params.yaml @@ -12,6 +12,6 @@ model: model_name: revision-form context_length: 32 training_arguments: - per_device_train_batch_size: 1024 + per_device_train_batch_size: 2048 \ No newline at end of file diff --git a/conf/form/composition.yaml b/conf/form/composition.yaml index 61e2211..4a2ab67 100644 --- a/conf/form/composition.yaml +++ b/conf/form/composition.yaml @@ -4,7 +4,7 @@ model: dataset: "form_energy" dataset_type: matbench special_num_token: False - checkpoint: n0w0f/MatText−composition−2m + checkpoint: n0w0f/MatText-composition-2m logging: wandb_project: revision-form @@ -12,6 +12,6 @@ model: model_name: revision-form context_length: 32 training_arguments: - per_device_train_batch_size: 1024 + per_device_train_batch_size: 2048 \ No newline at end of file diff --git a/conf/form/crystal_llm.yaml b/conf/form/crystal_llm.yaml index 4e831b0..667968d 100644 --- a/conf/form/crystal_llm.yaml +++ b/conf/form/crystal_llm.yaml @@ -4,7 +4,7 @@ model: dataset: "form_energy" dataset_type: matbench special_num_token: False - checkpoint: /home/so87pot/n0w0f/structllm_ckpt/alpaca_ckpt/checkpoint-393000 + checkpoint: /home/so87pot/n0w0f/structllm_ckpt/alpaca_ckpt/cllm/checkpoint-393000 logging: wandb_project: revision-form @@ -12,5 +12,5 @@ model: model_name: revision-form context_length: 512 training_arguments: - per_device_train_batch_size: 64 + per_device_train_batch_size: 256 \ No newline at end of file diff --git a/conf/form/local_env.yaml b/conf/form/local_env.yaml index c59cc86..0113a76 100644 --- a/conf/form/local_env.yaml +++ b/conf/form/local_env.yaml @@ -4,7 +4,7 @@ model: dataset: "form_energy" dataset_type: matbench special_num_token: False - checkpoint: /home/so87pot/n0w0f/structllm_ckpt/santiago_ckpt_rt/checkpoint-95000 + checkpoint: /home/so87pot/n0w0f/structllm_ckpt/alpaca_ckpt/local_env/checkpoint-381000 logging: wandb_project: revision-form @@ -12,6 +12,6 @@ model: model_name: revision-form context_length: 512 training_arguments: - per_device_train_batch_size: 64 + per_device_train_batch_size: 256 path: - pretrained_checkpoint: /home/so87pot/n0w0f/structllm_ckpt/santiago_ckpt_rt/checkpoint-95000 \ No newline at end of file + pretrained_checkpoint: /home/so87pot/n0w0f/structllm_ckpt/alpaca_ckpt/local_env/checkpoint-381000 \ No newline at end of file diff --git a/conf/form/slices.yaml b/conf/form/slices.yaml index 1dc7e3c..9b21975 100644 --- a/conf/form/slices.yaml +++ b/conf/form/slices.yaml @@ -12,6 +12,6 @@ model: model_name: revision-form context_length: 512 training_arguments: - 
per_device_train_batch_size: 64 + per_device_train_batch_size: 128 path: pretrained_checkpoint: n0w0f/MatText-slices-2m \ No newline at end of file diff --git a/conf/llama_8b_bg/atoms.yaml b/conf/llama_8b_bg/atoms.yaml new file mode 100644 index 0000000..1f243c1 --- /dev/null +++ b/conf/llama_8b_bg/atoms.yaml @@ -0,0 +1,8 @@ +# @package _global_ +model: + representation: atom_sequences + dataset: "bandgap" + dataset_type: matbench + logging: + wandb_project: llama-7B-ft + diff --git a/conf/llama_8b_bg/atoms_params.yaml b/conf/llama_8b_bg/atoms_params.yaml new file mode 100644 index 0000000..471187f --- /dev/null +++ b/conf/llama_8b_bg/atoms_params.yaml @@ -0,0 +1,11 @@ +# @package _global_ +model: + representation: atom_sequences_plusplus + dataset: "bandgap" + dataset_type: matbench + logging: + wandb_project: llama-7B-ft + + + + \ No newline at end of file diff --git a/conf/llama_8b_bg/cifp1.yaml b/conf/llama_8b_bg/cifp1.yaml new file mode 100644 index 0000000..dd9bc39 --- /dev/null +++ b/conf/llama_8b_bg/cifp1.yaml @@ -0,0 +1,8 @@ +# @package _global_ +model: + representation: cif_p1 + dataset: "bandgap" + dataset_type: matbench + logging: + wandb_project: llama-7B-ft + diff --git a/conf/llama_8b_bg/cifpsym.yaml b/conf/llama_8b_bg/cifpsym.yaml new file mode 100644 index 0000000..d03844e --- /dev/null +++ b/conf/llama_8b_bg/cifpsym.yaml @@ -0,0 +1,7 @@ +# @package _global_ +model: + representation: cif_symmetrized + dataset: "bandgap" + dataset_type: matbench + logging: + wandb_project: llama-7B-ft diff --git a/conf/llama_8b_bg/composition.yaml b/conf/llama_8b_bg/composition.yaml new file mode 100644 index 0000000..0e2b791 --- /dev/null +++ b/conf/llama_8b_bg/composition.yaml @@ -0,0 +1,8 @@ +# @package _global_ +model: + representation: composition + dataset: "bandgap" + dataset_type: matbench + logging: + wandb_project: llama-7B-ft + \ No newline at end of file diff --git a/conf/llama_8b_bg/crystal_llm.yaml b/conf/llama_8b_bg/crystal_llm.yaml new file mode 100644 index 0000000..e0cf6f0 --- /dev/null +++ b/conf/llama_8b_bg/crystal_llm.yaml @@ -0,0 +1,7 @@ +# @package _global_ +model: + representation: crystal_text_llm + dataset: "bandgap" + dataset_type: matbench + logging: + wandb_project: llama-7B-ft \ No newline at end of file diff --git a/conf/llama_8b_bg/local_env.yaml b/conf/llama_8b_bg/local_env.yaml new file mode 100644 index 0000000..8b83900 --- /dev/null +++ b/conf/llama_8b_bg/local_env.yaml @@ -0,0 +1,7 @@ +# @package _global_ +model: + representation: local_env + dataset: "bandgap" + dataset_type: matbench + logging: + wandb_project: llama-7B-ft diff --git a/conf/llama_8b_bg/slices.yaml b/conf/llama_8b_bg/slices.yaml new file mode 100644 index 0000000..32fe01f --- /dev/null +++ b/conf/llama_8b_bg/slices.yaml @@ -0,0 +1,7 @@ +# @package _global_ +model: + representation: slices + dataset: "bandgap" + dataset_type: matbench + logging: + wandb_project: llama-7B-ft diff --git a/conf/llama_8b_bg/zmatrix.yaml b/conf/llama_8b_bg/zmatrix.yaml new file mode 100644 index 0000000..401309b --- /dev/null +++ b/conf/llama_8b_bg/zmatrix.yaml @@ -0,0 +1,7 @@ +# @package _global_ +model: + representation: zmatrix + dataset: "bandgap" + dataset_type: matbench + logging: + wandb_project: llama-7B-ft diff --git a/conf/llm_sft.yaml b/conf/llm_sft.yaml index 74ce6ed..434b756 100644 --- a/conf/llm_sft.yaml +++ b/conf/llm_sft.yaml @@ -15,9 +15,6 @@ runs: - - - - name: llama_sft_run tasks: [llama_sft] diff --git a/conf/model/llama_8b.yaml b/conf/model/llama_8b.yaml new file mode 100644 index 
0000000..44e7254 --- /dev/null +++ b/conf/model/llama_8b.yaml @@ -0,0 +1,133 @@ +representation: ??? +add_special_tokens: False +dataset: ??? +dataset_type: ??? +fold : 5 +data_repository: "n0w0f/MatText" +checkpoint: "meta-llama/Meta-Llama-3-8B-Instruct" +special_tokens: { + "unk_token": "[UNK]", + "pad_token": "[PAD]", + "cls_token": "[CLS]", + "sep_token": "[SEP]", + "mask_token": "[MASK]", + "eos_token": "[EOS]", + "bos_token": "[BOS]", +} + +REPRESENTATION_MAP : { + "cif_p1" : "cif_p1", + "Slice" : "slice", + } + +PROPERTY_MAP : { + "gvrh" : "shear modulus (in GPa)", + "kvrh" : "bulk modulus (in GPa)", + "dielectric" : "refractive index", + "perovskites" : "formation energy (in eV)",} + +MATERIAL_MAP : { + "gvrh" : "material", + "kvrh" : "material", + "dielectric" : "dielectric material", + "perovskites" : "perovskite material", } + + +logging: + wandb_project : test-llama + wandb_log_model : "checkpoint" + +finetune: + model_name: test-llama + freeze_base_model: False + dataprep_seed: 42 + dataset_name: "${model.dataset}-train-${model.dataset_type}" + benchmark_dataset: "${model.dataset}-test-${model.dataset_type}" + exp_name: [ + "train_${model.representation}_${model.finetune.dataset_name}", + ] + + + path: + pretrained_checkpoint: "${model.checkpoint}" + + + finetune_data_rootpath: "/work/so87pot/material_db/all_1" # <--- Change this to the path of the finetune data + finetune_traindata: [ + "${model.finetune.path.finetune_data_rootpath}/train_${model.finetune.dataset_name}_2.json", + ] + + finetune_testdata: [ + "${model.finetune.path.finetune_data_rootpath}/test_${model.finetune.dataset_name}_2.json", + ] + + root_path: "${hydra:runtime.cwd}/../../results/${now:%Y-%m-%d}/${now:%H-%M-%S}/${model.finetune.model_name}" + output_dir: "${model.finetune.path.root_path}/checkpoints/${model.finetune.exp_name}" + logging_dir: "${model.finetune.path.root_path}/logs/${model.finetune.exp_name}" + finetuned_modelname: "${model.finetune.path.root_path}/checkpoints/finetuned_${model.finetune.exp_name}" + + context_length: 1024 + callbacks: + early_stopping: False + custom_logger: False + early_stopping_patience: 5 + early_stopping_threshold: 0.001 + generation: + n_epochs: 1 + output_dir: "${model.finetune.path.output_dir}" + + bnb_config: + use_4bit: True + use_8bit: False + bnb_4bit_compute_dtype: "float16" + bnb_4bit_quant_type: "nf4" + use_nested_quant: False + + lora_config: + r: 32 + lora_alpha: 64 + lora_dropout: 0.05 + bias: "none" + task_type: "CAUSAL_LM" + #target_modules: ['q_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj', 'k_proj', 'v_proj'] # Choose all linear layers from the model + + + training_arguments: + output_dir: "${model.finetune.path.output_dir}" + bf16: True + fp16: False + overwrite_output_dir: True + dataloader_num_workers: 2 + num_train_epochs: 5 + per_device_train_batch_size: 8 + per_device_eval_batch_size: 8 + save_strategy: "steps" + do_eval: True + evaluation_strategy: 'steps' + logging_strategy: 'steps' + logging_first_step: True + save_steps: 20 # Number of epochs before saving + report_to: "wandb" + save_total_limit: 2 + logging_steps: 10 + eval_steps: 10 + seed: 42 + load_best_model_at_end: True + # Number of update steps to accumulate the gradients for + gradient_accumulation_steps : 4 + # Enable gradient checkpointing + gradient_checkpointing : True + # Maximum gradient normal (gradient clipping) + max_grad_norm : 0.3 + # Initial learning rate (AdamW optimizer) + learning_rate : 3e-4 # 0.0005 crystal-llm + # Weight decay to apply to all layers except 
bias/LayerNorm weights + weight_decay : 0.001 + # Optimizer to use + optim : "paged_adamw_32bit" + # Learning rate schedule + lr_scheduler_type : "cosine" + # Ratio of steps for a linear warmup (from 0 to learning rate) + warmup_ratio : 0.03 + warmup_steps : 10 + eval_accumulation_steps : 4 From 94cc3e9d063ad36823853f5a1ff36f5ff1f7bffb Mon Sep 17 00:00:00 2001 From: n0w0f Date: Tue, 13 Aug 2024 02:50:08 +0200 Subject: [PATCH 04/20] chore: update config --- conf/llama_8b_bg/atoms.yaml | 2 +- conf/llama_8b_bg/atoms_params.yaml | 2 +- conf/llama_8b_bg/cifp1.yaml | 2 +- conf/llama_8b_bg/cifpsym.yaml | 2 +- conf/llama_8b_bg/composition.yaml | 2 +- conf/llama_8b_bg/crystal_llm.yaml | 2 +- conf/llama_8b_bg/local_env.yaml | 2 +- conf/llama_8b_bg/slices.yaml | 2 +- conf/llama_8b_bg/zmatrix.yaml | 2 +- conf/model/llama_8b.yaml | 8 ++++++-- 10 files changed, 15 insertions(+), 11 deletions(-) diff --git a/conf/llama_8b_bg/atoms.yaml b/conf/llama_8b_bg/atoms.yaml index 1f243c1..6406987 100644 --- a/conf/llama_8b_bg/atoms.yaml +++ b/conf/llama_8b_bg/atoms.yaml @@ -2,7 +2,7 @@ model: representation: atom_sequences dataset: "bandgap" - dataset_type: matbench + dataset_type: filtered logging: wandb_project: llama-7B-ft diff --git a/conf/llama_8b_bg/atoms_params.yaml b/conf/llama_8b_bg/atoms_params.yaml index 471187f..efe4430 100644 --- a/conf/llama_8b_bg/atoms_params.yaml +++ b/conf/llama_8b_bg/atoms_params.yaml @@ -2,7 +2,7 @@ model: representation: atom_sequences_plusplus dataset: "bandgap" - dataset_type: matbench + dataset_type: filtered logging: wandb_project: llama-7B-ft diff --git a/conf/llama_8b_bg/cifp1.yaml b/conf/llama_8b_bg/cifp1.yaml index dd9bc39..826af92 100644 --- a/conf/llama_8b_bg/cifp1.yaml +++ b/conf/llama_8b_bg/cifp1.yaml @@ -2,7 +2,7 @@ model: representation: cif_p1 dataset: "bandgap" - dataset_type: matbench + dataset_type: filtered logging: wandb_project: llama-7B-ft diff --git a/conf/llama_8b_bg/cifpsym.yaml b/conf/llama_8b_bg/cifpsym.yaml index d03844e..86addd9 100644 --- a/conf/llama_8b_bg/cifpsym.yaml +++ b/conf/llama_8b_bg/cifpsym.yaml @@ -2,6 +2,6 @@ model: representation: cif_symmetrized dataset: "bandgap" - dataset_type: matbench + dataset_type: filtered logging: wandb_project: llama-7B-ft diff --git a/conf/llama_8b_bg/composition.yaml b/conf/llama_8b_bg/composition.yaml index 0e2b791..5289b4b 100644 --- a/conf/llama_8b_bg/composition.yaml +++ b/conf/llama_8b_bg/composition.yaml @@ -2,7 +2,7 @@ model: representation: composition dataset: "bandgap" - dataset_type: matbench + dataset_type: filtered logging: wandb_project: llama-7B-ft \ No newline at end of file diff --git a/conf/llama_8b_bg/crystal_llm.yaml b/conf/llama_8b_bg/crystal_llm.yaml index e0cf6f0..61b5d3b 100644 --- a/conf/llama_8b_bg/crystal_llm.yaml +++ b/conf/llama_8b_bg/crystal_llm.yaml @@ -2,6 +2,6 @@ model: representation: crystal_text_llm dataset: "bandgap" - dataset_type: matbench + dataset_type: filtered logging: wandb_project: llama-7B-ft \ No newline at end of file diff --git a/conf/llama_8b_bg/local_env.yaml b/conf/llama_8b_bg/local_env.yaml index 8b83900..7a25734 100644 --- a/conf/llama_8b_bg/local_env.yaml +++ b/conf/llama_8b_bg/local_env.yaml @@ -2,6 +2,6 @@ model: representation: local_env dataset: "bandgap" - dataset_type: matbench + dataset_type: filtered logging: wandb_project: llama-7B-ft diff --git a/conf/llama_8b_bg/slices.yaml b/conf/llama_8b_bg/slices.yaml index 32fe01f..b680d22 100644 --- a/conf/llama_8b_bg/slices.yaml +++ b/conf/llama_8b_bg/slices.yaml @@ -2,6 +2,6 @@ model: representation: 
slices dataset: "bandgap" - dataset_type: matbench + dataset_type: filtered logging: wandb_project: llama-7B-ft diff --git a/conf/llama_8b_bg/zmatrix.yaml b/conf/llama_8b_bg/zmatrix.yaml index 401309b..94734f0 100644 --- a/conf/llama_8b_bg/zmatrix.yaml +++ b/conf/llama_8b_bg/zmatrix.yaml @@ -2,6 +2,6 @@ model: representation: zmatrix dataset: "bandgap" - dataset_type: matbench + dataset_type: filtered logging: wandb_project: llama-7B-ft diff --git a/conf/model/llama_8b.yaml b/conf/model/llama_8b.yaml index 44e7254..8f175f6 100644 --- a/conf/model/llama_8b.yaml +++ b/conf/model/llama_8b.yaml @@ -24,13 +24,17 @@ PROPERTY_MAP : { "gvrh" : "shear modulus (in GPa)", "kvrh" : "bulk modulus (in GPa)", "dielectric" : "refractive index", - "perovskites" : "formation energy (in eV)",} + "perovskites" : "formation energy (in eV)", + "bandgap" : "bandgap (in eV)", + "form_energy" : "formation energy (in eV)",} MATERIAL_MAP : { "gvrh" : "material", "kvrh" : "material", "dielectric" : "dielectric material", - "perovskites" : "perovskite material", } + "perovskites" : "perovskite material", + "bandgap" : "material ", + "form_energy" : "material", } logging: From f03074ebf32df02ea22c6614b5127f3f717ed38e Mon Sep 17 00:00:00 2001 From: n0w0f Date: Tue, 13 Aug 2024 23:14:28 +0200 Subject: [PATCH 05/20] feat: add classification --- conf/classification.yaml | 33 +++++ conf/is_metal/atoms.yaml | 19 +++ conf/is_metal/atoms_params.yaml | 17 +++ conf/is_metal/cifp1.yaml | 17 +++ conf/is_metal/cifpsym.yaml | 17 +++ conf/is_metal/composition.yaml | 17 +++ conf/is_metal/crystal_llm.yaml | 16 +++ conf/is_metal/local_env.yaml | 17 +++ conf/is_metal/slices.yaml | 17 +++ conf/is_metal/zmatrix.yaml | 17 +++ conf/model/classification_example.yaml | 100 +++++++++++++ src/mattext/main.py | 11 +- src/mattext/models/benchmark.py | 113 ++++++++++++++- src/mattext/models/classification.py | 188 +++++++++++++++++++++++++ src/mattext/models/score.py | 113 ++++++++++++++- 15 files changed, 708 insertions(+), 4 deletions(-) create mode 100644 conf/classification.yaml create mode 100644 conf/is_metal/atoms.yaml create mode 100644 conf/is_metal/atoms_params.yaml create mode 100644 conf/is_metal/cifp1.yaml create mode 100644 conf/is_metal/cifpsym.yaml create mode 100644 conf/is_metal/composition.yaml create mode 100644 conf/is_metal/crystal_llm.yaml create mode 100644 conf/is_metal/local_env.yaml create mode 100644 conf/is_metal/slices.yaml create mode 100644 conf/is_metal/zmatrix.yaml create mode 100644 conf/model/classification_example.yaml create mode 100644 src/mattext/models/classification.py diff --git a/conf/classification.yaml b/conf/classification.yaml new file mode 100644 index 0000000..c5acbd4 --- /dev/null +++ b/conf/classification.yaml @@ -0,0 +1,33 @@ + + +hydra: + job: + name: is_metal + run: + dir: ${hydra:runtime.cwd}/outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} + sweep: + dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} + subdir: ${hydra.job.override_dirname} + + # launcher: + # _target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher + # submitit_folder: ${hydra.sweep.dir}/.submitit/%j + # timeout_min: 3600 + # mem_gb: 160 + # nodes: 1 + # #gpus_per_task: 1 + # gres: gpu:1 + # #gpus_per_node: 2 + # name: ${hydra.job.name} + # partition: 'gpu' + # additional_parameters: + # nodelist: 'gpu[008,013-017]' + # tasks_per_node: 1 + +defaults: +- model: none +# - override hydra/launcher: submitit_slurm + +runs: + - name: classification_run + tasks: [classification] \ No newline at end of file diff --git 
a/conf/is_metal/atoms.yaml b/conf/is_metal/atoms.yaml new file mode 100644 index 0000000..4c183f8 --- /dev/null +++ b/conf/is_metal/atoms.yaml @@ -0,0 +1,19 @@ +# @package _global_ +model: + representation: atom_sequences + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-atom-seq-2m + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 32 + training_arguments: + per_device_train_batch_size: 1024 + path: + pretrained_checkpoint: n0w0f/MatText-atom-seq-2m + + \ No newline at end of file diff --git a/conf/is_metal/atoms_params.yaml b/conf/is_metal/atoms_params.yaml new file mode 100644 index 0000000..728685e --- /dev/null +++ b/conf/is_metal/atoms_params.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: atom_sequences_plusplus + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-atom-seq-plusplus-2m + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 32 + training_arguments: + per_device_train_batch_size: 1024 + + \ No newline at end of file diff --git a/conf/is_metal/cifp1.yaml b/conf/is_metal/cifp1.yaml new file mode 100644 index 0000000..51bed8e --- /dev/null +++ b/conf/is_metal/cifp1.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: cif_p1 + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-cifp1-2m + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 1024 + training_arguments: + per_device_train_batch_size: 128 + path: + pretrained_checkpoint: n0w0f/MatText-cifp1-2m \ No newline at end of file diff --git a/conf/is_metal/cifpsym.yaml b/conf/is_metal/cifpsym.yaml new file mode 100644 index 0000000..6175580 --- /dev/null +++ b/conf/is_metal/cifpsym.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: cif_symmetrized + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-cifsymmetrized-2m + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 1024 + training_arguments: + per_device_train_batch_size: 64 + path: + pretrained_checkpoint: n0w0f/MatText-cifsymmetrized-2m \ No newline at end of file diff --git a/conf/is_metal/composition.yaml b/conf/is_metal/composition.yaml new file mode 100644 index 0000000..7f66ae7 --- /dev/null +++ b/conf/is_metal/composition.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: composition + dataset: "is-metal" + dataset_type: filtered + special_num_token: False + checkpoint: n0w0f/MatText-composition-2m + logging: + wandb_project: revision-bg-filtered + + finetune: + model_name: revision-bg-filtered + context_length: 32 + training_arguments: + per_device_train_batch_size: 1024 + + \ No newline at end of file diff --git a/conf/is_metal/crystal_llm.yaml b/conf/is_metal/crystal_llm.yaml new file mode 100644 index 0000000..ce787ac --- /dev/null +++ b/conf/is_metal/crystal_llm.yaml @@ -0,0 +1,16 @@ +# @package _global_ +model: + representation: crystal_text_llm + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: /home/so87pot/n0w0f/structllm_ckpt/alpaca_ckpt/checkpoint-393000 + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 512 + training_arguments: + per_device_train_batch_size: 256 + \ No newline at end of file diff --git 
a/conf/is_metal/local_env.yaml b/conf/is_metal/local_env.yaml new file mode 100644 index 0000000..15a3667 --- /dev/null +++ b/conf/is_metal/local_env.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: local_env + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: /home/so87pot/n0w0f/structllm_ckpt/santiago_ckpt_rt/checkpoint-95000 + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 512 + training_arguments: + per_device_train_batch_size: 256 + path: + pretrained_checkpoint: /home/so87pot/n0w0f/structllm_ckpt/santiago_ckpt_rt/checkpoint-95000 \ No newline at end of file diff --git a/conf/is_metal/slices.yaml b/conf/is_metal/slices.yaml new file mode 100644 index 0000000..4a447e7 --- /dev/null +++ b/conf/is_metal/slices.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: slices + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-slices-2m + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 512 + training_arguments: + per_device_train_batch_size: 256 + path: + pretrained_checkpoint: n0w0f/MatText-slices-2m \ No newline at end of file diff --git a/conf/is_metal/zmatrix.yaml b/conf/is_metal/zmatrix.yaml new file mode 100644 index 0000000..5c1e96f --- /dev/null +++ b/conf/is_metal/zmatrix.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: zmatrix + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-zmatrix-2m + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 512 + training_arguments: + per_device_train_batch_size: 256 + path: + pretrained_checkpoint: n0w0f/MatText-zmatrix-2m \ No newline at end of file diff --git a/conf/model/classification_example.yaml b/conf/model/classification_example.yaml new file mode 100644 index 0000000..dd96a4e --- /dev/null +++ b/conf/model/classification_example.yaml @@ -0,0 +1,100 @@ +representation: ??? +special_num_token: False +dataset: ??? +dataset_type: ??? +fold : 5 +data_repository: "n0w0f/MatText" +checkpoint: ??? 
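+# Note: keys marked "???" are mandatory OmegaConf fields that are intentionally
+# left unset in this template; a representation-specific config (for example the
+# files under conf/is_metal/, declared with "# @package _global_") is expected to
+# supply model.representation, model.dataset, model.dataset_type and model.checkpoint.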
+special_tokens: + { + "unk_token": "[UNK]", + "pad_token": "[PAD]", + "cls_token": "[CLS]", + "sep_token": "[SEP]", + "mask_token": "[MASK]", + "eos_token": "[EOS]", + "bos_token": "[BOS]", + } + +logging: + wandb_project: classification + wandb_log_model: "checkpoint" + +finetune: + model_name: classification + freeze_base_model: False + dataset_name: "${model.dataset}-train-${model.dataset_type}" + exp_name: + [ + "train_${model.representation}_${model.finetune.dataset_name}_0", + "train_${model.representation}_${model.finetune.dataset_name}_1", + "train_${model.representation}_${model.finetune.dataset_name}_2", + "train_${model.representation}_${model.finetune.dataset_name}_3", + "train_${model.representation}_${model.finetune.dataset_name}_4", + ] + + path: + pretrained_checkpoint: "${model.checkpoint}" + + finetune_data_rootpath: results # <--- Change this to the path of the finetune data + finetune_traindata: + [ + # "kvrh-train-filtered", + ] + + finetune_testdata: + root_path: "${hydra:runtime.cwd}/../../results/${now:%Y-%m-%d}/${now:%H-%M-%S}/${model.finetune.model_name}" # <--- Change this to the path where chkpoints and logs will be saved + output_dir: "${model.finetune.path.root_path}/checkpoints/${model.finetune.exp_name}" + logging_dir: "${model.finetune.path.root_path}/logs/${model.finetune.exp_name}" + finetuned_modelname: "${model.finetune.path.root_path}/checkpoints/finetuned_${model.finetune.exp_name}" + + context_length: 32 + dataprep_seed: 42 + callbacks: + early_stopping: True + custom_logger: True + early_stopping_patience: 10 + early_stopping_threshold: 0.001 + + training_arguments: + output_dir: "${model.finetune.path.output_dir}" + overwrite_output_dir: True + num_train_epochs: 2 + per_device_train_batch_size: 1024 + save_strategy: "epoch" + evaluation_strategy: "epoch" + logging_strategy: "epoch" + logging_first_step: True + save_steps: 3 # Number of epochs before saving + report_to: "wandb" + save_total_limit: 5 + learning_rate: 2e-4 + logging_steps: 1 + eval_steps: 1 + seed: 42 + load_best_model_at_end: True + +inference: + benchmark_dataset: "${model.dataset}-test-${model.dataset_type}" + context_length: "${model.finetune.context_length}" + exp_name: + [ + "test_${model.representation}_${model.finetune.dataset_name}_0", + "test_${model.representation}_${model.finetune.dataset_name}_1", + "test_${model.representation}_${model.finetune.dataset_name}_2", + "test_${model.representation}_${model.finetune.dataset_name}_3", + "test_${model.representation}_${model.finetune.dataset_name}_4", + ] + path: + pretrained_checkpoint: [] + test_data_rootpath: # <--- Change this to the path of the finetune data + test_data: + [ + # "kvrh-train-filtered", + ] + root_path: "/home/so87pot/n0w0f/mattext/src/mattext/models/predictions" # <--- Change this to the path where predictions will be saved + output_dir: "${model.inference.path.root_path}/checkpoints/${model.inference.exp_name}" + logging_dir: "${model.inference.path.root_path}/logs/${model.inference.exp_name}" + predictions: "${model.inference.path.root_path}/checkpoints/inference${model.inference.exp_name}" + + benchmark_save_file: "${model.finetune.path.root_path}" diff --git a/src/mattext/main.py b/src/mattext/main.py index 91120a6..5412053 100644 --- a/src/mattext/main.py +++ b/src/mattext/main.py @@ -6,7 +6,7 @@ from hydra import utils from omegaconf import DictConfig -from mattext.models.benchmark import Matbenchmark +from mattext.models.benchmark import Matbenchmark, MatbenchmarkClassification from 
mattext.models.finetune import FinetuneModel from mattext.models.inference import Benchmark from mattext.models.llama import FinetuneLLama @@ -22,6 +22,9 @@ def __init__(self): def run_task(self, run: list, task_cfg: DictConfig, local_rank=None) -> None: if "benchmark" in run: self.run_benchmarking(task_cfg) + + if "classification" in run: + self.run_classification(task_cfg) if "inference" in run: self.run_inference(task_cfg) @@ -44,11 +47,17 @@ def run_task(self, run: list, task_cfg: DictConfig, local_rank=None) -> None: if "potential" in run: self.run_potential(task_cfg) + def run_benchmarking(self, task_cfg: DictConfig, local_rank=None) -> None: print("Finetuning and testing on matbench dataset") matbench_predictor = Matbenchmark(task_cfg) matbench_predictor.run_benchmarking(local_rank=local_rank) + def run_classification(self, task_cfg: DictConfig, local_rank=None) -> None: + print(f"Finetuning and testing on classification task") + matbench_predictor = MatbenchmarkClassification(task_cfg) + matbench_predictor.run_benchmarking(local_rank=local_rank) + def run_qmof(self, task_cfg: DictConfig, local_rank=None) -> None: print("Finetuning on qmof") matbench_predictor = Matbenchmark(task_cfg) diff --git a/src/mattext/models/benchmark.py b/src/mattext/models/benchmark.py index 93e4734..0c9519d 100644 --- a/src/mattext/models/benchmark.py +++ b/src/mattext/models/benchmark.py @@ -5,6 +5,7 @@ from matbench.bench import MatbenchBenchmark from omegaconf import DictConfig +from mattext.models.classification import FinetuneClassificationModel from mattext.models.finetune import FinetuneModel from mattext.models.predict import Inference from mattext.models.score import MATTEXT_MATBENCH, MatTextTask @@ -57,7 +58,7 @@ def run_benchmarking(self, local_rank=None) -> None: Exception: If an error occurs during inference for a finetuned checkpoint. """ - if self.task_type == "matbench": + if self.task_type == "matbench" or self.task_type == "classification": mb = MatbenchBenchmark(autoload=False) task = getattr(mb, MATTEXT_MATBENCH[self.task]) task.load() @@ -104,7 +105,7 @@ def run_benchmarking(self, local_rank=None) -> None: predictions, prediction_ids = predict.predict() print(len(prediction_ids), len(predictions)) - if self.task_type == "matbench": + if self.task_type == "matbench" or self.task_type == "classification": task.record(i, predictions) else: task.record_fold( @@ -128,3 +129,111 @@ def run_benchmarking(self, local_rank=None) -> None: # Get final results after recording all folds # final_results = task.get_final_results() # print(final_results) + + +class MatbenchmarkClassification: + """ + Class to perform predictions on Matbench datasets. + + Args: + - task_cfg (DictConfig): Configuration dictionary containing task parameters. + """ + + def __init__(self, task_cfg: DictConfig): + """ + Initializes the object with the given task configuration. + + Parameters: + task_cfg (DictConfig): The configuration dictionary containing task parameters. 
+ + Returns: + None + """ + self.task_cfg = task_cfg + self.representation = self.task_cfg.model.representation + self.task = self.task_cfg.model.dataset + self.task_type = self.task_cfg.model.dataset_type + self.benchmark = self.task_cfg.model.inference.benchmark_dataset + self.exp_names = self.task_cfg.model.finetune.exp_name + self.test_exp_names = self.task_cfg.model.inference.exp_name + self.train_data = self.task_cfg.model.finetune.dataset_name + self.test_data = self.task_cfg.model.inference.benchmark_dataset + self.benchmark_save_path = self.task_cfg.model.inference.benchmark_save_file + + # override wandb project name & tokenizer + self.wandb_project = self.task_cfg.model.logging.wandb_project + + def run_benchmarking(self, local_rank=None) -> None: + """ + Runs benchmarking on the specified dataset. + + Args: + local_rank (int, optional): The local rank for distributed training. Defaults to None. + + Returns: + None + + Raises: + Exception: If an error occurs during inference for a finetuned checkpoint. + + """ + + task = MatTextTask(task_name=self.task) + + for i, (exp_name, test_name) in enumerate( + zip(self.exp_names, self.test_exp_names) + ): + print( + f"Running training on {self.train_data}, and testing on {self.test_data} for fold {i}" + ) + wandb.init( + config=dict(self.task_cfg.model.finetune), + project=self.task_cfg.model.logging.wandb_project, + name=exp_name, + ) + fold_name = fold_key_namer(i) + print("-------------------------") + print(fold_name) + print("-------------------------") + + exp_cfg = self.task_cfg.copy() + exp_cfg.model.finetune.exp_name = exp_name + exp_cfg.model.finetune.path.finetune_traindata = self.train_data + + finetuner = FinetuneClassificationModel(exp_cfg, local_rank, fold=fold_name) + ckpt = finetuner.finetune() + print("-------------------------") + print(ckpt) + print("-------------------------") + + wandb.init( + config=dict(self.task_cfg.model.inference), + project=self.task_cfg.model.logging.wandb_project, + name=test_name, + ) + + exp_cfg.model.inference.path.test_data = self.test_data + exp_cfg.model.inference.path.pretrained_checkpoint = ckpt + + try: + predict = Inference(exp_cfg, fold=fold_name) + predictions, prediction_ids = predict.predict() + print(len(prediction_ids), len(predictions)) + task.record_fold( + fold=i, prediction_ids=prediction_ids, predictions=predictions + ) + + except Exception as e: + print( + f"Error occurred during inference for finetuned checkpoint '{exp_name}':" + ) + print(traceback.format_exc()) + + if not os.path.exists(self.benchmark_save_path): + os.makedirs(self.benchmark_save_path) + + file_name = os.path.join( + self.benchmark_save_path, + f"mattext_benchmark_{self.representation}_{self.benchmark}.json", + ) + task.to_file(file_name) diff --git a/src/mattext/models/classification.py b/src/mattext/models/classification.py new file mode 100644 index 0000000..630890a --- /dev/null +++ b/src/mattext/models/classification.py @@ -0,0 +1,188 @@ +from functools import partial +from typing import Any, Dict, List + +import numpy as np +import wandb +from datasets import DatasetDict, load_dataset +from omegaconf import DictConfig +from sklearn.metrics import ( + accuracy_score, + precision_recall_fscore_support, + roc_auc_score, +) +from sklearn.preprocessing import label_binarize +from torch import nn +from transformers import ( + AutoModelForSequenceClassification, + EarlyStoppingCallback, + Trainer, + TrainerCallback, + TrainingArguments, +) + +from mattext.models.utils import ( + 
CustomWandbCallback_FineTune, + EvaluateFirstStepCallback, + TokenizerMixin, +) + + +class FinetuneClassificationModel(TokenizerMixin): + """Class to perform finetuning of a language model. + Initialize the FinetuneModel. + + Args: + cfg (DictConfig): Configuration for the fine-tuning. + local_rank (int, optional): Local rank for distributed training. Defaults to None. + """ + + def __init__(self, cfg: DictConfig, local_rank=None, fold="fold_0") -> None: + super().__init__( + cfg=cfg.model.representation, + special_tokens=cfg.model.special_tokens, + special_num_token=cfg.model.special_num_token, + ) + self.fold = fold + self.local_rank = local_rank + self.representation = cfg.model.representation + self.data_repository = cfg.model.data_repository + self.cfg = cfg.model.finetune + self.context_length: int = self.cfg.context_length + self.callbacks = self.cfg.callbacks + self.tokenized_dataset = self._prepare_datasets( + self.cfg.path.finetune_traindata + ) + + def _prepare_datasets(self, subset: str) -> DatasetDict: + """ + Prepare training and validation datasets. + + Args: + train_df (pd.DataFrame): DataFrame containing training data. + + Returns: + DatasetDict: Dictionary containing training and validation datasets. + """ + + def replace_none(example, replacement="[PAD]"): + for key, value in example.items(): + if value is None: + example[key] = replacement + return example + + ds = load_dataset(self.data_repository, subset) + dataset = ds[self.fold].train_test_split(shuffle=True, test_size=0.2, seed=42) + dataset = dataset.filter( + lambda example: example[self.representation] is not None + ) + return dataset.map( + partial( + self._tokenize_pad_and_truncate, context_length=self.context_length + ), + batched=True, + ) + + def _callbacks(self) -> List[TrainerCallback]: + """Returns a list of callbacks for early stopping, and custom logging.""" + callbacks = [] + + if self.callbacks.early_stopping: + callbacks.append( + EarlyStoppingCallback( + early_stopping_patience=self.callbacks.early_stopping_patience, + early_stopping_threshold=self.callbacks.early_stopping_threshold, + ) + ) + + if self.callbacks.custom_logger: + callbacks.append(CustomWandbCallback_FineTune()) + + callbacks.append(EvaluateFirstStepCallback) + + return callbacks + + def _compute_metrics(self, p: Any) -> Dict[str, float]: + preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions + preds_argmax = np.argmax(preds, axis=1) + labels = p.label_ids + precision, recall, f1, _ = precision_recall_fscore_support(labels, preds_argmax, average='weighted') + acc = accuracy_score(labels, preds_argmax) + + # Compute ROC AUC + n_classes = preds.shape[1] + if n_classes == 2: + # Binary classification + roc_auc = roc_auc_score(labels, preds[:, 1]) + else: + # Multi-class classification + labels_binarized = label_binarize(labels, classes=range(n_classes)) + roc_auc = roc_auc_score(labels_binarized, preds, average='weighted', multi_class='ovr') + + return { + 'accuracy': acc, + 'f1': f1, + 'precision': precision, + 'recall': recall, + 'roc_auc': roc_auc + } + + def finetune(self) -> None: + """ + Perform fine-tuning of the language model. 
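+
+        In outline (see the code below): the pretrained checkpoint is loaded
+        into AutoModelForSequenceClassification with two labels, the base model
+        is optionally frozen, the best checkpoint is selected on the weighted
+        F1 score, and the directory of the saved fine-tuned model is returned.
+
+        Example (illustrative only; assumes a fully composed Hydra config `cfg`):
+
+            finetuner = FinetuneClassificationModel(cfg, fold="fold_0")
+            checkpoint_path = finetuner.finetune()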
+ """ + + pretrained_ckpt = self.cfg.path.pretrained_checkpoint + + config_train_args = self.cfg.training_arguments + callbacks = self._callbacks() + + training_args = TrainingArguments( + **config_train_args, + metric_for_best_model="f1", # or "accuracy", depending on your preference + greater_is_better=True, + ) + + model = AutoModelForSequenceClassification.from_pretrained( + pretrained_ckpt, num_labels=2, ignore_mismatched_sizes=False + ) + + if self.cfg.freeze_base_model: + for param in model.base_model.parameters(): + param.requires_grad = False + + if self.local_rank is not None: + model = model.to(self.local_rank) + model = nn.parallel.DistributedDataParallel( + model, device_ids=[self.local_rank] + ) + else: + model = model.to("cuda") + + trainer = Trainer( + model=model, + args=training_args, + data_collator=None, + compute_metrics=self._compute_metrics, + tokenizer=self._wrapped_tokenizer, + train_dataset=self.tokenized_dataset["train"], + eval_dataset=self.tokenized_dataset["test"], + callbacks=callbacks, + ) + + wandb.log({"Training Arguments": str(config_train_args)}) + wandb.log({"model_summary": str(model)}) + + trainer.train() + + eval_result = trainer.evaluate(eval_dataset=self.tokenized_dataset["test"]) + wandb.log(eval_result) + + model.save_pretrained(self.cfg.path.finetuned_modelname) + wandb.finish() + return self.cfg.path.finetuned_modelname + + def evaluate(self): + """ + Evaluate the fine-tuned model on the test dataset. + """ + ckpt = self.finetune() diff --git a/src/mattext/models/score.py b/src/mattext/models/score.py index 732d1c2..6904ecf 100644 --- a/src/mattext/models/score.py +++ b/src/mattext/models/score.py @@ -7,22 +7,30 @@ import pandas as pd from matbench.data_ops import load from sklearn.metrics import ( + accuracy_score, mean_absolute_error, mean_squared_error, + precision_recall_fscore_support, + roc_auc_score, ) +from sklearn.preprocessing import label_binarize MATTEXT_MATBENCH = { "kvrh": "matbench_log_kvrh", "gvrh": "matbench_log_gvrh", "perovskites": "matbench_perovskites", "bandgap" : "matbench_mp_gap", - "form_energy": "matbench_mp_e_form" + "form_energy": "matbench_mp_e_form", + "is-metal": "matbench_mp_is_metal", } MATMINER_COLUMNS = { "kvrh": "log10(K_VRH)", "gvrh": "log10(G_VRH)", "perovskites": "e_form", + "is-metal": "is_metal", + "bandgap": "gap pbe", + "form_energy": "e_form", } METRIC_MAP = { @@ -37,6 +45,7 @@ def fold_key_namer(fold_key): def load_true_scores(dataset, mbids): data_frame = load(MATTEXT_MATBENCH[dataset]) + print(MATMINER_COLUMNS) scores = [] for mbid in mbids: # Get the score for the mbid @@ -133,3 +142,105 @@ def _json_serializable(obj): if isinstance(obj, (np.ndarray, pd.Series)): return obj.tolist() raise TypeError(f"Type {type(obj)} not serializable") + + + + +@dataclass +class MatTextClassificationTask: + task_name: str + num_folds: int = 5 + num_classes: int = 2 + folds_results: Dict[int, Dict[str, Any]] = field(default_factory=dict) + recorded_folds: List[int] = field(default_factory=list) + + def record_fold( + self, fold: int, prediction_ids: List[str], predictions: List[float] + ): + if fold in self.recorded_folds: + raise ValueError(f"Fold {fold} has already been recorded.") + + true_labels = self.load_true_labels(self.task_name, prediction_ids) + pred_labels = np.argmax(predictions, axis=1) + + accuracy = accuracy_score(true_labels, pred_labels) + precision, recall, f1, _ = precision_recall_fscore_support(true_labels, pred_labels, average='weighted') + roc_auc = roc_auc_score(true_labels, predictions[:, 
1]) + + # Compute ROC AUC + # if self.num_classes == 2: + # roc_auc = roc_auc_score(true_labels, predictions[:, 1]) + # else: + # true_labels_binarized = label_binarize(true_labels, classes=range(self.num_classes)) + # roc_auc = roc_auc_score(true_labels_binarized, predictions, average='weighted', multi_class='ovr') + + self.folds_results[fold] = { + "prediction_ids": prediction_ids, + "predictions": predictions, + "true_labels": true_labels, + "accuracy": accuracy, + "precision": precision, + "recall": recall, + "f1": f1, + "roc_auc": roc_auc + } + self.recorded_folds.append(fold) + + def get_final_results(self): + if len(self.recorded_folds) < self.num_folds: + raise ValueError( + f"All {self.num_folds} folds must be recorded before getting final results." + ) + metrics = ['accuracy', 'precision', 'recall', 'f1', 'roc_auc'] + final_scores = {metric: [] for metric in metrics} + + for fold in range(self.num_folds): + for metric in metrics: + final_scores[metric].append(self.folds_results[fold][metric]) + + return { + f"mean_{metric}": np.mean(scores) for metric, scores in final_scores.items() + } | { + f"std_{metric}": np.std(scores) for metric, scores in final_scores.items() + } + + def to_file(self, file_path: str): + final_results = ( + self.get_final_results() + if len(self.recorded_folds) == self.num_folds + else {} + ) + data_to_save = asdict(self) + data_to_save["final_results"] = final_results + with open(file_path, "w") as f: + json.dump(data_to_save, f, default=self._json_serializable) + + @staticmethod + def from_file(file_path: str): + with open(file_path) as f: + data = json.load(f) + task = MatTextClassificationTask(task_name=data["task_name"], num_classes=data["num_classes"]) + task.folds_results = data["folds_results"] + task.recorded_folds = data["recorded_folds"] + return task + + @staticmethod + def _prepare_for_serialization(obj): + if isinstance(obj, dict): + return { + k: MatTextClassificationTask._prepare_for_serialization(v) for k, v in obj.items() + } + elif isinstance(obj, (list, pd.Series, np.ndarray)): + return MatTextClassificationTask._prepare_for_serialization(obj.tolist()) + else: + return obj + + @staticmethod + def _json_serializable(obj): + if isinstance(obj, (np.ndarray, pd.Series)): + return obj.tolist() + raise TypeError(f"Type {type(obj)} not serializable") + + # @staticmethod + # def load_true_labels(dataset, mbids): + # raise NotImplementedError("load_true_labels method needs to be implemented") \ No newline at end of file From ea18cf370e0b0173cdcdb501e237ca3024f27d9a Mon Sep 17 00:00:00 2001 From: n0w0f Date: Tue, 13 Aug 2024 23:54:43 +0200 Subject: [PATCH 06/20] chore: update --- src/mattext/models/benchmark.py | 9 ++- src/mattext/models/predict.py | 136 +++++++++++++++++++++++++++++++- 2 files changed, 140 insertions(+), 5 deletions(-) diff --git a/src/mattext/models/benchmark.py b/src/mattext/models/benchmark.py index 0c9519d..81ad93d 100644 --- a/src/mattext/models/benchmark.py +++ b/src/mattext/models/benchmark.py @@ -7,7 +7,7 @@ from mattext.models.classification import FinetuneClassificationModel from mattext.models.finetune import FinetuneModel -from mattext.models.predict import Inference +from mattext.models.predict import Inference, InferenceClassification from mattext.models.score import MATTEXT_MATBENCH, MatTextTask from mattext.models.utils import fold_key_namer @@ -216,11 +216,12 @@ def run_benchmarking(self, local_rank=None) -> None: exp_cfg.model.inference.path.pretrained_checkpoint = ckpt try: - predict = Inference(exp_cfg, 
fold=fold_name) - predictions, prediction_ids = predict.predict() + inference = InferenceClassification(exp_cfg, fold=fold_name) + predictions, prediction_ids = inference.predict() + # task.record_fold(fold=i, prediction_ids=prediction_ids, predictions=predictions.values) print(len(prediction_ids), len(predictions)) task.record_fold( - fold=i, prediction_ids=prediction_ids, predictions=predictions + fold=i, prediction_ids=prediction_ids, predictions=predictions.values ) except Exception as e: diff --git a/src/mattext/models/predict.py b/src/mattext/models/predict.py index a7228d7..7b4f0f7 100644 --- a/src/mattext/models/predict.py +++ b/src/mattext/models/predict.py @@ -1,10 +1,17 @@ from functools import partial -from typing import List +from typing import List, Tuple +import numpy as np import pandas as pd import torch from datasets import DatasetDict, load_dataset from omegaconf import DictConfig +from sklearn.metrics import ( + accuracy_score, + precision_recall_fscore_support, + roc_auc_score, +) +from sklearn.preprocessing import label_binarize from transformers import AutoModelForSequenceClassification, Trainer, TrainerCallback from mattext.models.utils import CustomWandbCallback_Inference, TokenizerMixin @@ -81,3 +88,130 @@ def predict(self): self.prediction_ids = prediction_ids return pd.Series(predictions.predictions.flatten()), prediction_ids + + +class InferenceClassification(TokenizerMixin): + """Class to perform inference on a language model with a sequence classification head for classification tasks.""" + + def __init__(self, cfg: DictConfig, fold="fold_0"): + super().__init__( + cfg=cfg.model.representation, + special_tokens=cfg.model.special_tokens, + special_num_token=cfg.model.special_num_token, + ) + self.fold = fold + self.representation = cfg.model.representation + self.data_repository = cfg.model.data_repository + self.dataset_name = cfg.model.finetune.dataset_name + self.cfg = cfg.model.inference + self.context_length: int = self.cfg.context_length + self.num_labels = cfg.model.num_labels + self.tokenized_test_datasets = self._prepare_datasets(self.cfg.path.test_data) + self.prediction_ids = None + + def _prepare_datasets(self, path: str) -> DatasetDict: + """ + Prepare test datasets. + + Args: + path (str): Path to the test data. + + Returns: + DatasetDict: Dictionary containing the test dataset. + """ + dataset = load_dataset(self.data_repository, path) + filtered_dataset = dataset[self.fold].filter( + lambda example: example[self.representation] is not None + ) + + return filtered_dataset.map( + partial( + self._tokenize_pad_and_truncate, context_length=self.context_length + ), + batched=True, + ) + + def _callbacks(self) -> List[TrainerCallback]: + """Returns a list of callbacks for logging.""" + return [CustomWandbCallback_Inference()] + + def predict(self) -> Tuple[pd.DataFrame, List[str]]: + """ + Perform prediction on the test dataset. + + Returns: + Tuple[pd.DataFrame, List[str]]: A tuple containing the predictions as a DataFrame + and the prediction IDs as a list. 
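+
+        Note: the raw logits from the sequence-classification head are turned
+        into class probabilities with a softmax, and the returned DataFrame has
+        one column per class ("class_0", "class_1", ...), in prediction order.
+
+        Example (illustrative only; assumes a fully composed Hydra config `cfg`):
+
+            predictor = InferenceClassification(cfg, fold="fold_0")
+            probabilities, mbids = predictor.predict()
+            predicted_labels = probabilities.values.argmax(axis=1)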
+ """ + pretrained_ckpt = self.cfg.path.pretrained_checkpoint + callbacks = self._callbacks() + + model = AutoModelForSequenceClassification.from_pretrained( + pretrained_ckpt, num_labels=self.num_labels, ignore_mismatched_sizes=False + ) + + trainer = Trainer( + model=model.to("cuda"), data_collator=None, callbacks=callbacks + ) + + predictions = trainer.predict(self.tokenized_test_datasets) + for callback in callbacks: + callback.on_predict_end( + None, None, None, model, predictions + ) # Manually trigger callback + torch.cuda.empty_cache() + + prediction_ids = self.tokenized_test_datasets["mbid"] + self.prediction_ids = prediction_ids + + # Convert predictions to probabilities + probabilities = torch.nn.functional.softmax( + torch.from_numpy(predictions.predictions), dim=-1 + ).numpy() + + # Create a DataFrame with prediction probabilities + prediction_df = pd.DataFrame( + probabilities, columns=[f"class_{i}" for i in range(self.num_labels)] + ) + + return prediction_df, prediction_ids + + def evaluate(self, true_labels: List[int]) -> dict: + """ + Evaluate the model's predictions against true labels. + + Args: + true_labels (List[int]): The true labels for the test set. + + Returns: + dict: A dictionary containing evaluation metrics. + """ + + predictions, _ = self.predict() + pred_labels = np.argmax(predictions.values, axis=1) + + accuracy = accuracy_score(true_labels, pred_labels) + precision, recall, f1, _ = precision_recall_fscore_support( + true_labels, pred_labels, average="weighted" + ) + + if self.num_labels == 2: + roc_auc = roc_auc_score(true_labels, predictions.iloc[:, 1]) + else: + true_labels_binarized = label_binarize( + true_labels, classes=range(self.num_labels) + ) + roc_auc = roc_auc_score( + true_labels_binarized, + predictions, + average="weighted", + multi_class="ovr", + ) + + return { + "accuracy": accuracy, + "precision": precision, + "recall": recall, + "f1": f1, + "roc_auc": roc_auc, + } From 3cafe208d2ba69c4ee2d61f79749804eb8cc9d1b Mon Sep 17 00:00:00 2001 From: n0w0f Date: Wed, 14 Aug 2024 01:07:56 +0200 Subject: [PATCH 07/20] fix: classification benchmarking --- src/mattext/models/benchmark.py | 7 ++++--- src/mattext/models/predict.py | 5 +++-- src/mattext/models/score.py | 30 +++++++----------------------- 3 files changed, 14 insertions(+), 28 deletions(-) diff --git a/src/mattext/models/benchmark.py b/src/mattext/models/benchmark.py index 81ad93d..38695b1 100644 --- a/src/mattext/models/benchmark.py +++ b/src/mattext/models/benchmark.py @@ -8,7 +8,7 @@ from mattext.models.classification import FinetuneClassificationModel from mattext.models.finetune import FinetuneModel from mattext.models.predict import Inference, InferenceClassification -from mattext.models.score import MATTEXT_MATBENCH, MatTextTask +from mattext.models.score import MATTEXT_MATBENCH, MatTextTask, MatTextClassificationTask from mattext.models.utils import fold_key_namer @@ -103,7 +103,6 @@ def run_benchmarking(self, local_rank=None) -> None: try: predict = Inference(exp_cfg, fold=fold_name) predictions, prediction_ids = predict.predict() - print(len(prediction_ids), len(predictions)) if self.task_type == "matbench" or self.task_type == "classification": task.record(i, predictions) @@ -178,7 +177,7 @@ def run_benchmarking(self, local_rank=None) -> None: """ - task = MatTextTask(task_name=self.task) + task = MatTextClassificationTask(task_name=self.task) for i, (exp_name, test_name) in enumerate( zip(self.exp_names, self.test_exp_names) @@ -219,7 +218,9 @@ def run_benchmarking(self, 
local_rank=None) -> None: inference = InferenceClassification(exp_cfg, fold=fold_name) predictions, prediction_ids = inference.predict() # task.record_fold(fold=i, prediction_ids=prediction_ids, predictions=predictions.values) + print("--------------------") print(len(prediction_ids), len(predictions)) + print("---------------------") task.record_fold( fold=i, prediction_ids=prediction_ids, predictions=predictions.values ) diff --git a/src/mattext/models/predict.py b/src/mattext/models/predict.py index 7b4f0f7..b537299 100644 --- a/src/mattext/models/predict.py +++ b/src/mattext/models/predict.py @@ -105,7 +105,8 @@ def __init__(self, cfg: DictConfig, fold="fold_0"): self.dataset_name = cfg.model.finetune.dataset_name self.cfg = cfg.model.inference self.context_length: int = self.cfg.context_length - self.num_labels = cfg.model.num_labels + #self.num_labels = cfg.model.num_labels + self.num_labels = 2 self.tokenized_test_datasets = self._prepare_datasets(self.cfg.path.test_data) self.prediction_ids = None @@ -169,7 +170,7 @@ def predict(self) -> Tuple[pd.DataFrame, List[str]]: torch.from_numpy(predictions.predictions), dim=-1 ).numpy() - # Create a DataFrame with prediction probabilities + #Create a DataFrame with prediction probabilities prediction_df = pd.DataFrame( probabilities, columns=[f"class_{i}" for i in range(self.num_labels)] ) diff --git a/src/mattext/models/score.py b/src/mattext/models/score.py index 6904ecf..dc4f32b 100644 --- a/src/mattext/models/score.py +++ b/src/mattext/models/score.py @@ -160,20 +160,13 @@ def record_fold( if fold in self.recorded_folds: raise ValueError(f"Fold {fold} has already been recorded.") - true_labels = self.load_true_labels(self.task_name, prediction_ids) + true_labels = load_true_scores(self.task_name, prediction_ids) pred_labels = np.argmax(predictions, axis=1) accuracy = accuracy_score(true_labels, pred_labels) precision, recall, f1, _ = precision_recall_fscore_support(true_labels, pred_labels, average='weighted') roc_auc = roc_auc_score(true_labels, predictions[:, 1]) - # Compute ROC AUC - # if self.num_classes == 2: - # roc_auc = roc_auc_score(true_labels, predictions[:, 1]) - # else: - # true_labels_binarized = label_binarize(true_labels, classes=range(self.num_classes)) - # roc_auc = roc_auc_score(true_labels_binarized, predictions, average='weighted', multi_class='ovr') - self.folds_results[fold] = { "prediction_ids": prediction_ids, "predictions": predictions, @@ -224,23 +217,14 @@ def from_file(file_path: str): task.recorded_folds = data["recorded_folds"] return task - @staticmethod - def _prepare_for_serialization(obj): - if isinstance(obj, dict): - return { - k: MatTextClassificationTask._prepare_for_serialization(v) for k, v in obj.items() - } - elif isinstance(obj, (list, pd.Series, np.ndarray)): - return MatTextClassificationTask._prepare_for_serialization(obj.tolist()) - else: - return obj - @staticmethod def _json_serializable(obj): if isinstance(obj, (np.ndarray, pd.Series)): return obj.tolist() + elif isinstance(obj, np.bool_): + return bool(obj) + elif isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) raise TypeError(f"Type {type(obj)} not serializable") - - # @staticmethod - # def load_true_labels(dataset, mbids): - # raise NotImplementedError("load_true_labels method needs to be implemented") \ No newline at end of file From 0dfe9eeb444399ba5ea4068c5794086bc2326f67 Mon Sep 17 00:00:00 2001 From: Kevin M Jablonka <32935233+kjappelbaum@users.noreply.github.com> Date: 
Wed, 14 Aug 2024 22:57:11 +0200 Subject: [PATCH 08/20] Update src/mattext/main.py Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> --- src/mattext/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mattext/main.py b/src/mattext/main.py index 5412053..2051728 100644 --- a/src/mattext/main.py +++ b/src/mattext/main.py @@ -54,7 +54,7 @@ def run_benchmarking(self, task_cfg: DictConfig, local_rank=None) -> None: matbench_predictor.run_benchmarking(local_rank=local_rank) def run_classification(self, task_cfg: DictConfig, local_rank=None) -> None: - print(f"Finetuning and testing on classification task") + print("Finetuning and testing on classification task") matbench_predictor = MatbenchmarkClassification(task_cfg) matbench_predictor.run_benchmarking(local_rank=local_rank) From 0de6ddae90e5cc1196dd570907f9f51e8a7a3d2c Mon Sep 17 00:00:00 2001 From: Kevin M Jablonka <32935233+kjappelbaum@users.noreply.github.com> Date: Wed, 14 Aug 2024 22:57:24 +0200 Subject: [PATCH 09/20] Update revision-scripts/mp_classification.py Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> --- revision-scripts/mp_classification.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/revision-scripts/mp_classification.py b/revision-scripts/mp_classification.py index 03d5cca..219caec 100644 --- a/revision-scripts/mp_classification.py +++ b/revision-scripts/mp_classification.py @@ -21,8 +21,7 @@ def __len__(self): def get(self, index): id = f"{index}".encode("ascii") - datapoint = pickle.loads(self.txn.get(id)) - return datapoint + return pickle.loads(self.txn.get(id)) def create_json_from_lmdb(lmdb_path, output_dir): dataset = Dataset(lmdb_path) From 07ef612be8f11bc3ba8f19cbaa34df6fbde9949e Mon Sep 17 00:00:00 2001 From: Kevin M Jablonka <32935233+kjappelbaum@users.noreply.github.com> Date: Wed, 14 Aug 2024 22:57:47 +0200 Subject: [PATCH 10/20] Update revision-scripts/matbench_is_metal.py Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> --- revision-scripts/matbench_is_metal.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/revision-scripts/matbench_is_metal.py b/revision-scripts/matbench_is_metal.py index 0949444..76b2c64 100644 --- a/revision-scripts/matbench_is_metal.py +++ b/revision-scripts/matbench_is_metal.py @@ -20,9 +20,7 @@ def convert_structure_to_serializable(pymatgen_structure): - # Assuming Structure has 'data' and 'metadata' attributes - cif_content = pymatgen_structure.to(fmt="cif") - return cif_content + return pymatgen_structure.to(fmt="cif") @hydra.main(version_base=None, config_path="../conf", config_name="config") From d87272233a7321b865987b6337e416052a6ba304 Mon Sep 17 00:00:00 2001 From: n0w0f Date: Tue, 20 Aug 2024 13:11:32 +0200 Subject: [PATCH 11/20] chore: remove deduplication, dictionary mapping task, multifold under run_experiment --- src/mattext/main.py | 188 ++++++++++++++++++++------------------------ 1 file changed, 84 insertions(+), 104 deletions(-) diff --git a/src/mattext/main.py b/src/mattext/main.py index 2051728..e193754 100644 --- a/src/mattext/main.py +++ b/src/mattext/main.py @@ -1,4 +1,5 @@ import os +from typing import Callable, Union import hydra import wandb @@ -18,45 +19,86 @@ class TaskRunner: def __init__(self): self.wandb_api_key = os.environ.get("WANDB_API_KEY") + self.task_map = { + "benchmark": self.run_benchmarking, + "classification": self.run_classification, + "inference": self.run_inference, + 
"finetune": self.run_finetuning, + "pretrain": self.run_pretraining, + "qmof": self.run_qmof, + "llama": self.run_llama, + "llama_sft": self.run_llama_sft, + "potential": self.run_potential, + } def run_task(self, run: list, task_cfg: DictConfig, local_rank=None) -> None: - if "benchmark" in run: - self.run_benchmarking(task_cfg) - - if "classification" in run: - self.run_classification(task_cfg) - - if "inference" in run: - self.run_inference(task_cfg) - - if "finetune" in run: - self.run_finetuning(task_cfg) - - if "pretrain" in run: - self.run_pretraining(task_cfg) - - if "qmof" in run: - self.run_qmof(task_cfg) + for task in run: + if task in self.task_map: + self.task_map[task](task_cfg, local_rank) + else: + print(f"Unknown task: {task}") + + def _run_experiment( + self, + task_cfg: DictConfig, + local_rank: Union[int, None], + model_class: Callable, + experiment_type: str, + use_folds: bool = False, + use_train_data_path: bool = False, + ): + if use_folds: + iterations = range(task_cfg.model.fold) + elif use_train_data_path: + iterations = zip( + task_cfg.model.finetune.exp_name, + task_cfg.model.finetune.path.finetune_traindata, + ) + else: + iterations = [None] + + for item in iterations: + if use_folds: + exp_name = f"{task_cfg.model.finetune.exp_name}_fold_{item}" + fold = f"fold_{item}" + elif use_train_data_path: + exp_name, train_data_path = item + fold = None + else: + exp_name = task_cfg.model[experiment_type].exp_name + fold = None - if "llama" in run: - self.run_llama(task_cfg, local_rank=local_rank) + wandb.init( + config=dict(task_cfg.model[experiment_type]), + project=task_cfg.model.logging.wandb_project, + name=exp_name, + ) - if "llama_sft" in run: - self.run_llama_sft(task_cfg, local_rank=local_rank) + exp_cfg = task_cfg.copy() + exp_cfg.model[experiment_type].exp_name = exp_name + if use_train_data_path: + exp_cfg.model.finetune.path.finetune_traindata = train_data_path - if "potential" in run: - self.run_potential(task_cfg) + if fold: + model = model_class(exp_cfg, local_rank, fold=fold) + else: + model = model_class(exp_cfg, local_rank) + result = ( + model.finetune() if hasattr(model, "finetune") else model.pretrain_mlm() + ) + print(result) + wandb.finish() def run_benchmarking(self, task_cfg: DictConfig, local_rank=None) -> None: - print("Finetuning and testing on matbench dataset") - matbench_predictor = Matbenchmark(task_cfg) - matbench_predictor.run_benchmarking(local_rank=local_rank) + print("Benchmarking") + benchmark = Matbenchmark(task_cfg) + benchmark.run_benchmarking(local_rank=local_rank) def run_classification(self, task_cfg: DictConfig, local_rank=None) -> None: - print("Finetuning and testing on classification task") - matbench_predictor = MatbenchmarkClassification(task_cfg) - matbench_predictor.run_benchmarking(local_rank=local_rank) + print("Benchmarking Classification") + benchmark = MatbenchmarkClassification(task_cfg) + benchmark.run_benchmarking(local_rank=local_rank) def run_qmof(self, task_cfg: DictConfig, local_rank=None) -> None: print("Finetuning on qmof") @@ -69,89 +111,27 @@ def run_inference(self, task_cfg: DictConfig, local_rank=None) -> None: matbench_predictor.run_benchmarking(local_rank=local_rank) def run_llama(self, task_cfg: DictConfig, local_rank=None) -> None: - for exp_name, train_data_path in zip( - task_cfg.model.finetune.exp_name, - task_cfg.model.finetune.path.finetune_traindata, - ): - wandb.init( - config=dict(task_cfg.model.finetune), - project=task_cfg.model.logging.wandb_project, - name=exp_name, - ) - - exp_cfg = 
task_cfg.copy() - exp_cfg.model.finetune.exp_name = exp_name - exp_cfg.model.finetune.path.finetune_traindata = train_data_path - - finetuner = FinetuneLLama(exp_cfg, local_rank) - f = finetuner.finetune() - print(f) - wandb.finish() + self._run_experiment( + task_cfg, local_rank, FinetuneLLama, "finetune", use_train_data_path=True + ) def run_llama_sft(self, task_cfg: DictConfig, local_rank=None) -> None: - for fold in range(task_cfg.model.fold): - exp_name = f"{task_cfg.model.finetune.exp_name}_fold_{fold}" - wandb.init( - config=dict(task_cfg.model.finetune), - project=task_cfg.model.logging.wandb_project, - name=exp_name, - ) - - exp_cfg = task_cfg.copy() - exp_cfg.model.finetune.exp_name = exp_name - - finetuner = FinetuneLLamaSFT(exp_cfg, local_rank, fold=f"fold_{fold}") - f = finetuner.finetune() - print(f) - wandb.finish() + self._run_experiment( + task_cfg, local_rank, FinetuneLLamaSFT, "finetune", use_folds=True + ) def run_finetuning(self, task_cfg: DictConfig, local_rank=None) -> None: - for exp_name, train_data_path in zip( - task_cfg.model.finetune.exp_name, - task_cfg.model.finetune.path.finetune_traindata, - ): - wandb.init( - config=dict(task_cfg.model.finetune), - project=task_cfg.logging.wandb_project, - name=exp_name, - ) - - exp_cfg = task_cfg.copy() - exp_cfg.model.finetune.exp_name = exp_name - exp_cfg.model.finetune.path.finetune_traindata = train_data_path - - finetuner = FinetuneModel(exp_cfg, local_rank) - finetuner.finetune() - wandb.finish() + self._run_experiment( + task_cfg, local_rank, FinetuneModel, "finetune", use_train_data_path=True + ) def run_potential(self, task_cfg: DictConfig, local_rank=None) -> None: - for exp_name, train_data_path in zip( - task_cfg.model.finetune.exp_name, - task_cfg.model.finetune.path.finetune_traindata, - ): - wandb.init( - config=dict(task_cfg.model.finetune), - project=task_cfg.model.logging.wandb_project, - name=exp_name, - ) - - exp_cfg = task_cfg.copy() - exp_cfg.model.finetune.exp_name = exp_name - exp_cfg.model.finetune.path.finetune_traindata = train_data_path - - finetuner = PotentialModel(exp_cfg, local_rank) - finetuner.finetune() - wandb.finish() + self._run_experiment( + task_cfg, local_rank, PotentialModel, "finetune", use_train_data_path=True + ) def run_pretraining(self, task_cfg: DictConfig, local_rank=None) -> None: - wandb.init( - config=dict(task_cfg.model.pretrain), - project=task_cfg.model.logging.wandb_project, - name=task_cfg.model.pretrain.exp_name, - ) - print(task_cfg) - pretrainer = PretrainModel(task_cfg, local_rank) - pretrainer.pretrain_mlm() + self._run_experiment(task_cfg, local_rank, PretrainModel, "pretrain") def initialize_wandb(self): if self.wandb_api_key: From f4e8f4ee7b22d700cc6b89c5db247a9aae8e24ee Mon Sep 17 00:00:00 2001 From: n0w0f Date: Tue, 20 Aug 2024 13:17:10 +0200 Subject: [PATCH 12/20] chore: abstract out benchmarking to a base class for reg and classification --- src/mattext/models/benchmark.py | 259 +++++++++++--------------------- 1 file changed, 91 insertions(+), 168 deletions(-) diff --git a/src/mattext/models/benchmark.py b/src/mattext/models/benchmark.py index 38695b1..73d4585 100644 --- a/src/mattext/models/benchmark.py +++ b/src/mattext/models/benchmark.py @@ -1,5 +1,6 @@ import os import traceback +from abc import ABC, abstractmethod import wandb from matbench.bench import MatbenchBenchmark @@ -8,28 +9,16 @@ from mattext.models.classification import FinetuneClassificationModel from mattext.models.finetune import FinetuneModel from mattext.models.predict import 
Inference, InferenceClassification -from mattext.models.score import MATTEXT_MATBENCH, MatTextTask, MatTextClassificationTask +from mattext.models.score import ( + MATTEXT_MATBENCH, + MatTextClassificationTask, + MatTextTask, +) from mattext.models.utils import fold_key_namer -class Matbenchmark: - """ - Class to perform predictions on Matbench datasets. - - Args: - - task_cfg (DictConfig): Configuration dictionary containing task parameters. - """ - +class BaseBenchmark(ABC): def __init__(self, task_cfg: DictConfig): - """ - Initializes the object with the given task configuration. - - Parameters: - task_cfg (DictConfig): The configuration dictionary containing task parameters. - - Returns: - None - """ self.task_cfg = task_cfg self.representation = self.task_cfg.model.representation self.task = self.task_cfg.model.dataset @@ -40,83 +29,72 @@ def __init__(self, task_cfg: DictConfig): self.train_data = self.task_cfg.model.finetune.dataset_name self.test_data = self.task_cfg.model.inference.benchmark_dataset self.benchmark_save_path = self.task_cfg.model.inference.benchmark_save_file - - # override wandb project name & tokenizer self.wandb_project = self.task_cfg.model.logging.wandb_project + @abstractmethod def run_benchmarking(self, local_rank=None) -> None: - """ - Runs benchmarking on the specified dataset. + pass - Args: - local_rank (int, optional): The local rank for distributed training. Defaults to None. - - Returns: - None - - Raises: - Exception: If an error occurs during inference for a finetuned checkpoint. - - """ - if self.task_type == "matbench" or self.task_type == "classification": + def _initialize_task(self): + if self.task_type == "matbench": mb = MatbenchBenchmark(autoload=False) task = getattr(mb, MATTEXT_MATBENCH[self.task]) task.load() else: task = MatTextTask(task_name=self.task) + return task - for i, (exp_name, test_name) in enumerate( - zip(self.exp_names, self.test_exp_names) - ): - print( - f"Running training on {self.train_data}, and testing on {self.test_data} for fold {i}" - ) - wandb.init( - config=dict(self.task_cfg.model.finetune), - project=self.task_cfg.model.logging.wandb_project, - name=exp_name, - ) - fold_name = fold_key_namer(i) - print("-------------------------") - print(fold_name) - print("-------------------------") - - exp_cfg = self.task_cfg.copy() - exp_cfg.model.finetune.exp_name = exp_name - exp_cfg.model.finetune.path.finetune_traindata = self.train_data + def _run_experiment(self, task, i, exp_name, test_name, local_rank): + fold_name = fold_key_namer(i) + print( + f"Running training on {self.train_data}, and testing on {self.test_data} for fold {i}" + ) + print("-------------------------") + print(fold_name) + print("-------------------------") + + exp_cfg = self.task_cfg.copy() + exp_cfg.model.finetune.exp_name = exp_name + exp_cfg.model.finetune.path.finetune_traindata = self.train_data + + finetuner = self._get_finetuner(exp_cfg, local_rank, fold_name) + ckpt = finetuner.finetune() + print("-------------------------") + print(ckpt) + print("-------------------------") + + wandb.init( + config=dict(self.task_cfg.model.inference), + project=self.task_cfg.model.logging.wandb_project, + name=test_name, + ) - finetuner = FinetuneModel(exp_cfg, local_rank, fold=fold_name) - ckpt = finetuner.finetune() - print("-------------------------") - print(ckpt) - print("-------------------------") + exp_cfg.model.inference.path.test_data = self.test_data + exp_cfg.model.inference.path.pretrained_checkpoint = ckpt - wandb.init( - 
config=dict(self.task_cfg.model.inference), - project=self.task_cfg.model.logging.wandb_project, - name=test_name, + try: + predict = self._get_inference(exp_cfg, fold_name) + predictions, prediction_ids = predict.predict() + self._record_predictions(task, i, predictions, prediction_ids) + except Exception as e: + print( + f"Error occurred during inference for finetuned checkpoint '{exp_name}':" ) + print(traceback.format_exc()) - exp_cfg.model.inference.path.test_data = self.test_data - exp_cfg.model.inference.path.pretrained_checkpoint = ckpt + @abstractmethod + def _get_finetuner(self, exp_cfg, local_rank, fold_name): + pass - try: - predict = Inference(exp_cfg, fold=fold_name) - predictions, prediction_ids = predict.predict() + @abstractmethod + def _get_inference(self, exp_cfg, fold_name): + pass - if self.task_type == "matbench" or self.task_type == "classification": - task.record(i, predictions) - else: - task.record_fold( - fold=i, prediction_ids=prediction_ids, predictions=predictions - ) - - except Exception as e: - print( - f"Error occurred during inference for finetuned checkpoint '{exp_name}':" - ) - print(traceback.format_exc()) + @abstractmethod + def _record_predictions(self, task, fold, predictions, prediction_ids): + pass + def _save_results(self, task): if not os.path.exists(self.benchmark_save_path): os.makedirs(self.benchmark_save_path) @@ -125,117 +103,62 @@ def run_benchmarking(self, local_rank=None) -> None: f"mattext_benchmark_{self.representation}_{self.benchmark}.json", ) task.to_file(file_name) - # Get final results after recording all folds - # final_results = task.get_final_results() - # print(final_results) - -class MatbenchmarkClassification: - """ - Class to perform predictions on Matbench datasets. - Args: - - task_cfg (DictConfig): Configuration dictionary containing task parameters. - """ - - def __init__(self, task_cfg: DictConfig): - """ - Initializes the object with the given task configuration. - - Parameters: - task_cfg (DictConfig): The configuration dictionary containing task parameters. - - Returns: - None - """ - self.task_cfg = task_cfg - self.representation = self.task_cfg.model.representation - self.task = self.task_cfg.model.dataset - self.task_type = self.task_cfg.model.dataset_type - self.benchmark = self.task_cfg.model.inference.benchmark_dataset - self.exp_names = self.task_cfg.model.finetune.exp_name - self.test_exp_names = self.task_cfg.model.inference.exp_name - self.train_data = self.task_cfg.model.finetune.dataset_name - self.test_data = self.task_cfg.model.inference.benchmark_dataset - self.benchmark_save_path = self.task_cfg.model.inference.benchmark_save_file +class Matbenchmark(BaseBenchmark): + def run_benchmarking(self, local_rank=None) -> None: + task = self._initialize_task() - # override wandb project name & tokenizer - self.wandb_project = self.task_cfg.model.logging.wandb_project + for i, (exp_name, test_name) in enumerate( + zip(self.exp_names, self.test_exp_names) + ): + wandb.init( + config=dict(self.task_cfg.model.finetune), + project=self.task_cfg.model.logging.wandb_project, + name=exp_name, + ) + self._run_experiment(task, i, exp_name, test_name, local_rank) - def run_benchmarking(self, local_rank=None) -> None: - """ - Runs benchmarking on the specified dataset. + self._save_results(task) - Args: - local_rank (int, optional): The local rank for distributed training. Defaults to None. 
+ def _get_finetuner(self, exp_cfg, local_rank, fold_name): + return FinetuneModel(exp_cfg, local_rank, fold=fold_name) - Returns: - None + def _get_inference(self, exp_cfg, fold_name): + return Inference(exp_cfg, fold=fold_name) - Raises: - Exception: If an error occurs during inference for a finetuned checkpoint. + def _record_predictions(self, task, fold, predictions, prediction_ids): + if self.task_type == "matbench": + task.record(fold, predictions) + else: + task.record_fold( + fold=fold, prediction_ids=prediction_ids, predictions=predictions + ) - """ +class MatbenchmarkClassification(BaseBenchmark): + def run_benchmarking(self, local_rank=None) -> None: task = MatTextClassificationTask(task_name=self.task) for i, (exp_name, test_name) in enumerate( zip(self.exp_names, self.test_exp_names) ): - print( - f"Running training on {self.train_data}, and testing on {self.test_data} for fold {i}" - ) wandb.init( config=dict(self.task_cfg.model.finetune), project=self.task_cfg.model.logging.wandb_project, name=exp_name, ) - fold_name = fold_key_namer(i) - print("-------------------------") - print(fold_name) - print("-------------------------") - - exp_cfg = self.task_cfg.copy() - exp_cfg.model.finetune.exp_name = exp_name - exp_cfg.model.finetune.path.finetune_traindata = self.train_data + self._run_experiment(task, i, exp_name, test_name, local_rank) - finetuner = FinetuneClassificationModel(exp_cfg, local_rank, fold=fold_name) - ckpt = finetuner.finetune() - print("-------------------------") - print(ckpt) - print("-------------------------") - - wandb.init( - config=dict(self.task_cfg.model.inference), - project=self.task_cfg.model.logging.wandb_project, - name=test_name, - ) + self._save_results(task) - exp_cfg.model.inference.path.test_data = self.test_data - exp_cfg.model.inference.path.pretrained_checkpoint = ckpt - - try: - inference = InferenceClassification(exp_cfg, fold=fold_name) - predictions, prediction_ids = inference.predict() - # task.record_fold(fold=i, prediction_ids=prediction_ids, predictions=predictions.values) - print("--------------------") - print(len(prediction_ids), len(predictions)) - print("---------------------") - task.record_fold( - fold=i, prediction_ids=prediction_ids, predictions=predictions.values - ) - - except Exception as e: - print( - f"Error occurred during inference for finetuned checkpoint '{exp_name}':" - ) - print(traceback.format_exc()) + def _get_finetuner(self, exp_cfg, local_rank, fold_name): + return FinetuneClassificationModel(exp_cfg, local_rank, fold=fold_name) - if not os.path.exists(self.benchmark_save_path): - os.makedirs(self.benchmark_save_path) + def _get_inference(self, exp_cfg, fold_name): + return InferenceClassification(exp_cfg, fold=fold_name) - file_name = os.path.join( - self.benchmark_save_path, - f"mattext_benchmark_{self.representation}_{self.benchmark}.json", + def _record_predictions(self, task, fold, predictions, prediction_ids): + task.record_fold( + fold=fold, prediction_ids=prediction_ids, predictions=predictions.values ) - task.to_file(file_name) From 5e5bdd10e79729c7f5379141b9d5c8934deefb42 Mon Sep 17 00:00:00 2001 From: n0w0f Date: Tue, 20 Aug 2024 13:23:10 +0200 Subject: [PATCH 13/20] chore: abstract out finetuning to base class for reg and classification --- src/mattext/models/benchmark.py | 3 +- src/mattext/models/classification.py | 188 --------------------------- src/mattext/models/finetune.py | 144 +++++++++++--------- 3 files changed, 87 insertions(+), 248 deletions(-) delete mode 100644 
src/mattext/models/classification.py diff --git a/src/mattext/models/benchmark.py b/src/mattext/models/benchmark.py index 73d4585..2daa457 100644 --- a/src/mattext/models/benchmark.py +++ b/src/mattext/models/benchmark.py @@ -6,8 +6,7 @@ from matbench.bench import MatbenchBenchmark from omegaconf import DictConfig -from mattext.models.classification import FinetuneClassificationModel -from mattext.models.finetune import FinetuneModel +from mattext.models.finetune import FinetuneModel, FinetuneClassificationModel from mattext.models.predict import Inference, InferenceClassification from mattext.models.score import ( MATTEXT_MATBENCH, diff --git a/src/mattext/models/classification.py b/src/mattext/models/classification.py deleted file mode 100644 index 630890a..0000000 --- a/src/mattext/models/classification.py +++ /dev/null @@ -1,188 +0,0 @@ -from functools import partial -from typing import Any, Dict, List - -import numpy as np -import wandb -from datasets import DatasetDict, load_dataset -from omegaconf import DictConfig -from sklearn.metrics import ( - accuracy_score, - precision_recall_fscore_support, - roc_auc_score, -) -from sklearn.preprocessing import label_binarize -from torch import nn -from transformers import ( - AutoModelForSequenceClassification, - EarlyStoppingCallback, - Trainer, - TrainerCallback, - TrainingArguments, -) - -from mattext.models.utils import ( - CustomWandbCallback_FineTune, - EvaluateFirstStepCallback, - TokenizerMixin, -) - - -class FinetuneClassificationModel(TokenizerMixin): - """Class to perform finetuning of a language model. - Initialize the FinetuneModel. - - Args: - cfg (DictConfig): Configuration for the fine-tuning. - local_rank (int, optional): Local rank for distributed training. Defaults to None. - """ - - def __init__(self, cfg: DictConfig, local_rank=None, fold="fold_0") -> None: - super().__init__( - cfg=cfg.model.representation, - special_tokens=cfg.model.special_tokens, - special_num_token=cfg.model.special_num_token, - ) - self.fold = fold - self.local_rank = local_rank - self.representation = cfg.model.representation - self.data_repository = cfg.model.data_repository - self.cfg = cfg.model.finetune - self.context_length: int = self.cfg.context_length - self.callbacks = self.cfg.callbacks - self.tokenized_dataset = self._prepare_datasets( - self.cfg.path.finetune_traindata - ) - - def _prepare_datasets(self, subset: str) -> DatasetDict: - """ - Prepare training and validation datasets. - - Args: - train_df (pd.DataFrame): DataFrame containing training data. - - Returns: - DatasetDict: Dictionary containing training and validation datasets. 
- """ - - def replace_none(example, replacement="[PAD]"): - for key, value in example.items(): - if value is None: - example[key] = replacement - return example - - ds = load_dataset(self.data_repository, subset) - dataset = ds[self.fold].train_test_split(shuffle=True, test_size=0.2, seed=42) - dataset = dataset.filter( - lambda example: example[self.representation] is not None - ) - return dataset.map( - partial( - self._tokenize_pad_and_truncate, context_length=self.context_length - ), - batched=True, - ) - - def _callbacks(self) -> List[TrainerCallback]: - """Returns a list of callbacks for early stopping, and custom logging.""" - callbacks = [] - - if self.callbacks.early_stopping: - callbacks.append( - EarlyStoppingCallback( - early_stopping_patience=self.callbacks.early_stopping_patience, - early_stopping_threshold=self.callbacks.early_stopping_threshold, - ) - ) - - if self.callbacks.custom_logger: - callbacks.append(CustomWandbCallback_FineTune()) - - callbacks.append(EvaluateFirstStepCallback) - - return callbacks - - def _compute_metrics(self, p: Any) -> Dict[str, float]: - preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions - preds_argmax = np.argmax(preds, axis=1) - labels = p.label_ids - precision, recall, f1, _ = precision_recall_fscore_support(labels, preds_argmax, average='weighted') - acc = accuracy_score(labels, preds_argmax) - - # Compute ROC AUC - n_classes = preds.shape[1] - if n_classes == 2: - # Binary classification - roc_auc = roc_auc_score(labels, preds[:, 1]) - else: - # Multi-class classification - labels_binarized = label_binarize(labels, classes=range(n_classes)) - roc_auc = roc_auc_score(labels_binarized, preds, average='weighted', multi_class='ovr') - - return { - 'accuracy': acc, - 'f1': f1, - 'precision': precision, - 'recall': recall, - 'roc_auc': roc_auc - } - - def finetune(self) -> None: - """ - Perform fine-tuning of the language model. - """ - - pretrained_ckpt = self.cfg.path.pretrained_checkpoint - - config_train_args = self.cfg.training_arguments - callbacks = self._callbacks() - - training_args = TrainingArguments( - **config_train_args, - metric_for_best_model="f1", # or "accuracy", depending on your preference - greater_is_better=True, - ) - - model = AutoModelForSequenceClassification.from_pretrained( - pretrained_ckpt, num_labels=2, ignore_mismatched_sizes=False - ) - - if self.cfg.freeze_base_model: - for param in model.base_model.parameters(): - param.requires_grad = False - - if self.local_rank is not None: - model = model.to(self.local_rank) - model = nn.parallel.DistributedDataParallel( - model, device_ids=[self.local_rank] - ) - else: - model = model.to("cuda") - - trainer = Trainer( - model=model, - args=training_args, - data_collator=None, - compute_metrics=self._compute_metrics, - tokenizer=self._wrapped_tokenizer, - train_dataset=self.tokenized_dataset["train"], - eval_dataset=self.tokenized_dataset["test"], - callbacks=callbacks, - ) - - wandb.log({"Training Arguments": str(config_train_args)}) - wandb.log({"model_summary": str(model)}) - - trainer.train() - - eval_result = trainer.evaluate(eval_dataset=self.tokenized_dataset["test"]) - wandb.log(eval_result) - - model.save_pretrained(self.cfg.path.finetuned_modelname) - wandb.finish() - return self.cfg.path.finetuned_modelname - - def evaluate(self): - """ - Evaluate the fine-tuned model on the test dataset. 
- """ - ckpt = self.finetune() diff --git a/src/mattext/models/finetune.py b/src/mattext/models/finetune.py index 2a3536d..76ceadd 100644 --- a/src/mattext/models/finetune.py +++ b/src/mattext/models/finetune.py @@ -1,10 +1,18 @@ +from abc import ABC, abstractmethod from functools import partial from typing import Any, Dict, List +import numpy as np import torch import wandb from datasets import DatasetDict, load_dataset from omegaconf import DictConfig +from sklearn.metrics import ( + accuracy_score, + precision_recall_fscore_support, + roc_auc_score, +) +from sklearn.preprocessing import label_binarize from torch import nn from transformers import ( AutoModelForSequenceClassification, @@ -21,15 +29,7 @@ ) -class FinetuneModel(TokenizerMixin): - """Class to perform finetuning of a language model. - Initialize the FinetuneModel. - - Args: - cfg (DictConfig): Configuration for the fine-tuning. - local_rank (int, optional): Local rank for distributed training. Defaults to None. - """ - +class BaseFinetuneModel(TokenizerMixin, ABC): def __init__(self, cfg: DictConfig, local_rank=None, fold="fold_0") -> None: super().__init__( cfg=cfg.model.representation, @@ -48,22 +48,6 @@ def __init__(self, cfg: DictConfig, local_rank=None, fold="fold_0") -> None: ) def _prepare_datasets(self, subset: str) -> DatasetDict: - """ - Prepare training and validation datasets. - - Args: - train_df (pd.DataFrame): DataFrame containing training data. - - Returns: - DatasetDict: Dictionary containing training and validation datasets. - """ - - def replace_none(example, replacement="[PAD]"): - for key, value in example.items(): - if value is None: - example[key] = replacement - return example - ds = load_dataset(self.data_repository, subset) dataset = ds[self.fold].train_test_split(shuffle=True, test_size=0.2, seed=42) dataset = dataset.filter( @@ -77,9 +61,7 @@ def replace_none(example, replacement="[PAD]"): ) def _callbacks(self) -> List[TrainerCallback]: - """Returns a list of callbacks for early stopping, and custom logging.""" callbacks = [] - if self.callbacks.early_stopping: callbacks.append( EarlyStoppingCallback( @@ -87,48 +69,27 @@ def _callbacks(self) -> List[TrainerCallback]: early_stopping_threshold=self.callbacks.early_stopping_threshold, ) ) - if self.callbacks.custom_logger: callbacks.append(CustomWandbCallback_FineTune()) - callbacks.append(EvaluateFirstStepCallback) - return callbacks - def _compute_metrics(self, p: Any, eval=True) -> Dict[str, float]: - preds = torch.tensor( - p.predictions.squeeze() - ) # Convert predictions to PyTorch tensor - label_ids = torch.tensor(p.label_ids) # Convert label_ids to PyTorch tensor - - if eval: - # Calculate RMSE as evaluation metric - eval_rmse = torch.sqrt(((preds - label_ids) ** 2).mean()).item() - return {"eval_rmse": round(eval_rmse, 3)} - else: - # Calculate RMSE as training metric - loss = torch.sqrt(((preds - label_ids) ** 2).mean()).item() - return {"train_rmse": round(loss, 3), "loss": round(loss, 3)} - - def finetune(self) -> None: - """ - Perform fine-tuning of the language model. 
- """ + @abstractmethod + def _compute_metrics(self, p: Any) -> Dict[str, float]: + pass + def finetune(self) -> str: pretrained_ckpt = self.cfg.path.pretrained_checkpoint - config_train_args = self.cfg.training_arguments callbacks = self._callbacks() training_args = TrainingArguments( **config_train_args, - metric_for_best_model="eval_rmse", # Metric to use for determining the best model - greater_is_better=False, # Lower eval_rmse is better + metric_for_best_model=self.get_best_metric(), + greater_is_better=self.is_greater_better(), ) - model = AutoModelForSequenceClassification.from_pretrained( - pretrained_ckpt, num_labels=1, ignore_mismatched_sizes=False - ) + model = self.get_model(pretrained_ckpt) if self.cfg.freeze_base_model: for param in model.base_model.parameters(): @@ -165,8 +126,75 @@ def finetune(self) -> None: wandb.finish() return self.cfg.path.finetuned_modelname + @abstractmethod + def get_best_metric(self) -> str: + pass + + @abstractmethod + def is_greater_better(self) -> bool: + pass + + @abstractmethod + def get_model(self, pretrained_ckpt: str): + pass + def evaluate(self): - """ - Evaluate the fine-tuned model on the test dataset. - """ ckpt = self.finetune() + + +class FinetuneModel(BaseFinetuneModel): + def _compute_metrics(self, p: Any) -> Dict[str, float]: + preds = torch.tensor(p.predictions.squeeze()) + label_ids = torch.tensor(p.label_ids) + eval_rmse = torch.sqrt(((preds - label_ids) ** 2).mean()).item() + return {"eval_rmse": round(eval_rmse, 3)} + + def get_best_metric(self) -> str: + return "eval_rmse" + + def is_greater_better(self) -> bool: + return False + + def get_model(self, pretrained_ckpt: str): + return AutoModelForSequenceClassification.from_pretrained( + pretrained_ckpt, num_labels=1, ignore_mismatched_sizes=False + ) + + +class FinetuneClassificationModel(BaseFinetuneModel): + def _compute_metrics(self, p: Any) -> Dict[str, float]: + preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions + preds_argmax = np.argmax(preds, axis=1) + labels = p.label_ids + precision, recall, f1, _ = precision_recall_fscore_support( + labels, preds_argmax, average="weighted" + ) + acc = accuracy_score(labels, preds_argmax) + + n_classes = preds.shape[1] + if n_classes == 2: + roc_auc = roc_auc_score(labels, preds[:, 1]) + else: + labels_binarized = label_binarize(labels, classes=range(n_classes)) + roc_auc = roc_auc_score( + labels_binarized, preds, average="weighted", multi_class="ovr" + ) + + return { + "accuracy": acc, + "f1": f1, + "precision": precision, + "recall": recall, + "roc_auc": roc_auc, + } + + def get_best_metric(self) -> str: + return "f1" + + def is_greater_better(self) -> bool: + return True + + def get_model(self, pretrained_ckpt: str): + return AutoModelForSequenceClassification.from_pretrained( + pretrained_ckpt, num_labels=2, ignore_mismatched_sizes=False + ) From c6548a8f645d51bc5eed9097c58e378a85ec1ab4 Mon Sep 17 00:00:00 2001 From: n0w0f Date: Tue, 20 Aug 2024 14:33:44 +0200 Subject: [PATCH 14/20] chore: abstract out inference from reg and classification to base --- src/mattext/models/predict.py | 139 ++++++++-------------------------- 1 file changed, 30 insertions(+), 109 deletions(-) diff --git a/src/mattext/models/predict.py b/src/mattext/models/predict.py index b537299..96b4836 100644 --- a/src/mattext/models/predict.py +++ b/src/mattext/models/predict.py @@ -1,5 +1,6 @@ +from abc import ABC, abstractmethod from functools import partial -from typing import List, Tuple +from typing import List, Tuple, Union 
import numpy as np import pandas as pd @@ -17,9 +18,7 @@ from mattext.models.utils import CustomWandbCallback_Inference, TokenizerMixin -class Inference(TokenizerMixin): - """Class to perform inference on a language model with a sequence classification head.""" - +class BaseInference(TokenizerMixin, ABC): def __init__(self, cfg: DictConfig, fold="fold_0"): super().__init__( cfg=cfg.model.representation, @@ -36,20 +35,10 @@ def __init__(self, cfg: DictConfig, fold="fold_0"): self.prediction_ids = None def _prepare_datasets(self, path: str) -> DatasetDict: - """ - Prepare training and validation datasets. - - Args: - train_df (pd.DataFrame): DataFrame containing training data. - - Returns: - DatasetDict: Dictionary containing training and validation datasets. - """ dataset = load_dataset(self.data_repository, path) filtered_dataset = dataset[self.fold].filter( lambda example: example[self.representation] is not None ) - return filtered_dataset.map( partial( self._tokenize_pad_and_truncate, context_length=self.context_length @@ -58,16 +47,21 @@ def _prepare_datasets(self, path: str) -> DatasetDict: ) def _callbacks(self) -> List[TrainerCallback]: - """Returns a list of callbacks for logging.""" return [CustomWandbCallback_Inference()] - def predict(self): + @abstractmethod + def get_model(self, pretrained_ckpt: str): + pass + + @abstractmethod + def process_predictions(self, predictions) -> Union[pd.Series, pd.DataFrame]: + pass + + def predict(self) -> Tuple[Union[pd.Series, pd.DataFrame], List[str]]: pretrained_ckpt = self.cfg.path.pretrained_checkpoint callbacks = self._callbacks() - model = AutoModelForSequenceClassification.from_pretrained( - pretrained_ckpt, num_labels=1, ignore_mismatched_sizes=False - ) + model = self.get_model(pretrained_ckpt) trainer = Trainer( model=model.to("cuda"), data_collator=None, callbacks=callbacks @@ -75,119 +69,46 @@ def predict(self): predictions = trainer.predict(self.tokenized_test_datasets) for callback in callbacks: - callback.on_predict_end( - None, None, None, model, predictions - ) # Manually trigger callback + callback.on_predict_end(None, None, None, model, predictions) torch.cuda.empty_cache() - # TODO: Save predictions to disk optional - # os.makedirs(self.cfg.path.predictions, exist_ok=True) - # predictions_path = os.path.join(self.cfg.path.predictions, 'predictions.npy') - # np.save(predictions_path, predictions.predictions) prediction_ids = self.tokenized_test_datasets["mbid"] self.prediction_ids = prediction_ids - return pd.Series(predictions.predictions.flatten()), prediction_ids + processed_predictions = self.process_predictions(predictions) + return processed_predictions, prediction_ids -class InferenceClassification(TokenizerMixin): - """Class to perform inference on a language model with a sequence classification head for classification tasks.""" - - def __init__(self, cfg: DictConfig, fold="fold_0"): - super().__init__( - cfg=cfg.model.representation, - special_tokens=cfg.model.special_tokens, - special_num_token=cfg.model.special_num_token, - ) - self.fold = fold - self.representation = cfg.model.representation - self.data_repository = cfg.model.data_repository - self.dataset_name = cfg.model.finetune.dataset_name - self.cfg = cfg.model.inference - self.context_length: int = self.cfg.context_length - #self.num_labels = cfg.model.num_labels - self.num_labels = 2 - self.tokenized_test_datasets = self._prepare_datasets(self.cfg.path.test_data) - self.prediction_ids = None - - def _prepare_datasets(self, path: str) -> DatasetDict: - """ 
- Prepare test datasets. - - Args: - path (str): Path to the test data. - - Returns: - DatasetDict: Dictionary containing the test dataset. - """ - dataset = load_dataset(self.data_repository, path) - filtered_dataset = dataset[self.fold].filter( - lambda example: example[self.representation] is not None - ) - return filtered_dataset.map( - partial( - self._tokenize_pad_and_truncate, context_length=self.context_length - ), - batched=True, +class Inference(BaseInference): + def get_model(self, pretrained_ckpt: str): + return AutoModelForSequenceClassification.from_pretrained( + pretrained_ckpt, num_labels=1, ignore_mismatched_sizes=False ) - def _callbacks(self) -> List[TrainerCallback]: - """Returns a list of callbacks for logging.""" - return [CustomWandbCallback_Inference()] + def process_predictions(self, predictions) -> pd.Series: + return pd.Series(predictions.predictions.flatten()) - def predict(self) -> Tuple[pd.DataFrame, List[str]]: - """ - Perform prediction on the test dataset. - Returns: - Tuple[pd.DataFrame, List[str]]: A tuple containing the predictions as a DataFrame - and the prediction IDs as a list. - """ - pretrained_ckpt = self.cfg.path.pretrained_checkpoint - callbacks = self._callbacks() +class InferenceClassification(BaseInference): + def __init__(self, cfg: DictConfig, fold="fold_0"): + super().__init__(cfg, fold) + self.num_labels = 2 # You might want to make this configurable - model = AutoModelForSequenceClassification.from_pretrained( + def get_model(self, pretrained_ckpt: str): + return AutoModelForSequenceClassification.from_pretrained( pretrained_ckpt, num_labels=self.num_labels, ignore_mismatched_sizes=False ) - trainer = Trainer( - model=model.to("cuda"), data_collator=None, callbacks=callbacks - ) - - predictions = trainer.predict(self.tokenized_test_datasets) - for callback in callbacks: - callback.on_predict_end( - None, None, None, model, predictions - ) # Manually trigger callback - torch.cuda.empty_cache() - - prediction_ids = self.tokenized_test_datasets["mbid"] - self.prediction_ids = prediction_ids - - # Convert predictions to probabilities + def process_predictions(self, predictions) -> pd.DataFrame: probabilities = torch.nn.functional.softmax( torch.from_numpy(predictions.predictions), dim=-1 ).numpy() - - #Create a DataFrame with prediction probabilities - prediction_df = pd.DataFrame( + return pd.DataFrame( probabilities, columns=[f"class_{i}" for i in range(self.num_labels)] ) - return prediction_df, prediction_ids - def evaluate(self, true_labels: List[int]) -> dict: - """ - Evaluate the model's predictions against true labels. - - Args: - true_labels (List[int]): The true labels for the test set. - - Returns: - dict: A dictionary containing evaluation metrics. 
- """ - predictions, _ = self.predict() pred_labels = np.argmax(predictions.values, axis=1) From 824471adb1c25fc5903739576146fb60f033bf1a Mon Sep 17 00:00:00 2001 From: n0w0f Date: Tue, 20 Aug 2024 14:42:56 +0200 Subject: [PATCH 15/20] chore: refactor task --- src/mattext/models/benchmark.py | 4 +- src/mattext/models/score.py | 174 ++++++++------------------------ 2 files changed, 46 insertions(+), 132 deletions(-) diff --git a/src/mattext/models/benchmark.py b/src/mattext/models/benchmark.py index 2daa457..23f5ae9 100644 --- a/src/mattext/models/benchmark.py +++ b/src/mattext/models/benchmark.py @@ -77,8 +77,10 @@ def _run_experiment(self, task, i, exp_name, test_name, local_rank): self._record_predictions(task, i, predictions, prediction_ids) except Exception as e: print( - f"Error occurred during inference for finetuned checkpoint '{exp_name}':" + f"Error occurred during inference for finetuned checkpoint '{exp_name}': {str(e)}" ) + if isinstance(e, (ValueError, TypeError)): + raise print(traceback.format_exc()) @abstractmethod diff --git a/src/mattext/models/score.py b/src/mattext/models/score.py index dc4f32b..50a5607 100644 --- a/src/mattext/models/score.py +++ b/src/mattext/models/score.py @@ -1,8 +1,8 @@ -import json import math -from dataclasses import asdict, dataclass, field +from dataclasses import dataclass, field from typing import Any, Dict, List +import jsonpickle import numpy as np import pandas as pd from matbench.data_ops import load @@ -13,13 +13,12 @@ precision_recall_fscore_support, roc_auc_score, ) -from sklearn.preprocessing import label_binarize MATTEXT_MATBENCH = { "kvrh": "matbench_log_kvrh", "gvrh": "matbench_log_gvrh", "perovskites": "matbench_perovskites", - "bandgap" : "matbench_mp_gap", + "bandgap": "matbench_mp_gap", "form_energy": "matbench_mp_e_form", "is-metal": "matbench_mp_is_metal", } @@ -33,37 +32,20 @@ "form_energy": "e_form", } -METRIC_MAP = { - "mae": mean_absolute_error, - "rmse": lambda true, pred: math.sqrt(mean_squared_error(true, pred)), -} - - -def fold_key_namer(fold_key): - return f"fold_{fold_key}" - def load_true_scores(dataset, mbids): data_frame = load(MATTEXT_MATBENCH[dataset]) - print(MATMINER_COLUMNS) scores = [] for mbid in mbids: - # Get the score for the mbid score = data_frame.loc[mbid][MATMINER_COLUMNS[dataset]] scores.append(score) return scores -def mattext_score(prediction_ids, predictions, task_name): - true = load_true_scores(task_name, prediction_ids) - return mean_squared_error(true, predictions) - - @dataclass -class MatTextTask: +class BaseMatTextTask: task_name: str num_folds: int = 5 - # metric: str folds_results: Dict[int, Dict[str, Any]] = field(default_factory=dict) recorded_folds: List[int] = field(default_factory=list) @@ -73,6 +55,35 @@ def record_fold( if fold in self.recorded_folds: raise ValueError(f"Fold {fold} has already been recorded.") true_scores = load_true_scores(self.task_name, prediction_ids) + self._calculate_metrics(fold, prediction_ids, predictions, true_scores) + self.recorded_folds.append(fold) + + def _calculate_metrics(self, fold, prediction_ids, predictions, true_scores): + raise NotImplementedError("Subclasses must implement this method") + + def get_final_results(self): + if len(self.recorded_folds) < self.num_folds: + raise ValueError( + f"All {self.num_folds} folds must be recorded before getting final results." 
+ ) + return self._aggregate_results() + + def _aggregate_results(self): + raise NotImplementedError("Subclasses must implement this method") + + def to_file(self, file_path: str): + with open(file_path, "w") as f: + f.write(jsonpickle.encode(self)) + + @staticmethod + def from_file(file_path: str): + with open(file_path) as f: + return jsonpickle.decode(f.read()) + + +@dataclass +class MatTextTask(BaseMatTextTask): + def _calculate_metrics(self, fold, prediction_ids, predictions, true_scores): mae = mean_absolute_error(true_scores, predictions) rmse = math.sqrt(mean_squared_error(true_scores, predictions)) self.folds_results[fold] = { @@ -82,91 +93,33 @@ def record_fold( "mae": mae, "rmse": rmse, } - self.recorded_folds.append(fold) - def get_final_results(self): - if len(self.recorded_folds) < self.num_folds: - raise ValueError( - f"All {self.num_folds} folds must be recorded before getting final results." - ) + def _aggregate_results(self): final_scores_mae = [ self.folds_results[fold]["mae"] for fold in range(self.num_folds) ] final_scores_rmse = [ self.folds_results[fold]["rmse"] for fold in range(self.num_folds) ] - return { "mean_mae_score": np.mean(final_scores_mae), "std_mae_score": np.std(final_scores_mae), "mean_rmse_score": np.mean(final_scores_rmse), "std_rmse_score": np.std(final_scores_rmse), - "std_score": np.std(final_scores_mae), } - def to_file(self, file_path: str): - final_results = ( - self.get_final_results() - if len(self.recorded_folds) == self.num_folds - else {} - ) - data_to_save = asdict(self) - data_to_save["final_results"] = final_results - with open(file_path, "w") as f: - json.dump(data_to_save, f, default=self._json_serializable) - - @staticmethod - def from_file(file_path: str): - with open(file_path) as f: - data = json.load(f) - task = MatTextTask(task_name=data["task_name"], metric=data["metric"]) - task.folds_results = data["folds_results"] - task.recorded_folds = data["recorded_folds"] - return task - - @staticmethod - def _prepare_for_serialization(obj): - if isinstance(obj, dict): - return { - k: MatTextTask._prepare_for_serialization(v) for k, v in obj.items() - } - elif ( - isinstance(obj, (list, pd.Series, np.ndarray)) - ): - return MatTextTask._prepare_for_serialization(obj.tolist()) - else: - return obj - - @staticmethod - def _json_serializable(obj): - if isinstance(obj, (np.ndarray, pd.Series)): - return obj.tolist() - raise TypeError(f"Type {type(obj)} not serializable") - - - @dataclass -class MatTextClassificationTask: - task_name: str - num_folds: int = 5 +class MatTextClassificationTask(BaseMatTextTask): num_classes: int = 2 - folds_results: Dict[int, Dict[str, Any]] = field(default_factory=dict) - recorded_folds: List[int] = field(default_factory=list) - - def record_fold( - self, fold: int, prediction_ids: List[str], predictions: List[float] - ): - if fold in self.recorded_folds: - raise ValueError(f"Fold {fold} has already been recorded.") - true_labels = load_true_scores(self.task_name, prediction_ids) + def _calculate_metrics(self, fold, prediction_ids, predictions, true_labels): pred_labels = np.argmax(predictions, axis=1) - accuracy = accuracy_score(true_labels, pred_labels) - precision, recall, f1, _ = precision_recall_fscore_support(true_labels, pred_labels, average='weighted') + precision, recall, f1, _ = precision_recall_fscore_support( + true_labels, pred_labels, average="weighted" + ) roc_auc = roc_auc_score(true_labels, predictions[:, 1]) - self.folds_results[fold] = { "prediction_ids": prediction_ids, "predictions": 
predictions, @@ -175,56 +128,15 @@ def record_fold( "precision": precision, "recall": recall, "f1": f1, - "roc_auc": roc_auc + "roc_auc": roc_auc, } - self.recorded_folds.append(fold) - def get_final_results(self): - if len(self.recorded_folds) < self.num_folds: - raise ValueError( - f"All {self.num_folds} folds must be recorded before getting final results." - ) - metrics = ['accuracy', 'precision', 'recall', 'f1', 'roc_auc'] + def _aggregate_results(self): + metrics = ["accuracy", "precision", "recall", "f1", "roc_auc"] final_scores = {metric: [] for metric in metrics} - for fold in range(self.num_folds): for metric in metrics: final_scores[metric].append(self.folds_results[fold][metric]) - return { f"mean_{metric}": np.mean(scores) for metric, scores in final_scores.items() - } | { - f"std_{metric}": np.std(scores) for metric, scores in final_scores.items() - } - - def to_file(self, file_path: str): - final_results = ( - self.get_final_results() - if len(self.recorded_folds) == self.num_folds - else {} - ) - data_to_save = asdict(self) - data_to_save["final_results"] = final_results - with open(file_path, "w") as f: - json.dump(data_to_save, f, default=self._json_serializable) - - @staticmethod - def from_file(file_path: str): - with open(file_path) as f: - data = json.load(f) - task = MatTextClassificationTask(task_name=data["task_name"], num_classes=data["num_classes"]) - task.folds_results = data["folds_results"] - task.recorded_folds = data["recorded_folds"] - return task - - @staticmethod - def _json_serializable(obj): - if isinstance(obj, (np.ndarray, pd.Series)): - return obj.tolist() - elif isinstance(obj, np.bool_): - return bool(obj) - elif isinstance(obj, np.integer): - return int(obj) - elif isinstance(obj, np.floating): - return float(obj) - raise TypeError(f"Type {type(obj)} not serializable") + } | {f"std_{metric}": np.std(scores) for metric, scores in final_scores.items()} From e26424bb1e3541e42dc61ac5a93eb5c1835c1290 Mon Sep 17 00:00:00 2001 From: n0w0f Date: Tue, 20 Aug 2024 14:59:07 +0200 Subject: [PATCH 16/20] chore: improve Mattext Tasks --- src/mattext/models/benchmark.py | 3 +- src/mattext/models/score.py | 129 ++++++++++++++++---------------- 2 files changed, 67 insertions(+), 65 deletions(-) diff --git a/src/mattext/models/benchmark.py b/src/mattext/models/benchmark.py index 23f5ae9..235c325 100644 --- a/src/mattext/models/benchmark.py +++ b/src/mattext/models/benchmark.py @@ -10,7 +10,6 @@ from mattext.models.predict import Inference, InferenceClassification from mattext.models.score import ( MATTEXT_MATBENCH, - MatTextClassificationTask, MatTextTask, ) from mattext.models.utils import fold_key_namer @@ -139,7 +138,7 @@ def _record_predictions(self, task, fold, predictions, prediction_ids): class MatbenchmarkClassification(BaseBenchmark): def run_benchmarking(self, local_rank=None) -> None: - task = MatTextClassificationTask(task_name=self.task) + task = MatTextTask(task_name=self.task, is_classification=True) for i, (exp_name, test_name) in enumerate( zip(self.exp_names, self.test_exp_names) diff --git a/src/mattext/models/score.py b/src/mattext/models/score.py index 50a5607..35343a3 100644 --- a/src/mattext/models/score.py +++ b/src/mattext/models/score.py @@ -2,7 +2,6 @@ from dataclasses import dataclass, field from typing import Any, Dict, List -import jsonpickle import numpy as np import pandas as pd from matbench.data_ops import load @@ -13,6 +12,7 @@ precision_recall_fscore_support, roc_auc_score, ) +import json MATTEXT_MATBENCH = { "kvrh": 
"matbench_log_kvrh", @@ -32,7 +32,6 @@ "form_energy": "e_form", } - def load_true_scores(dataset, mbids): data_frame = load(MATTEXT_MATBENCH[dataset]) scores = [] @@ -41,49 +40,28 @@ def load_true_scores(dataset, mbids): scores.append(score) return scores - @dataclass -class BaseMatTextTask: +class MatTextTask: task_name: str num_folds: int = 5 + is_classification: bool = False + num_classes: int = 2 folds_results: Dict[int, Dict[str, Any]] = field(default_factory=dict) recorded_folds: List[int] = field(default_factory=list) - def record_fold( - self, fold: int, prediction_ids: List[str], predictions: List[float] - ): + def record_fold(self, fold: int, prediction_ids: List[str], predictions: List[float]): if fold in self.recorded_folds: raise ValueError(f"Fold {fold} has already been recorded.") true_scores = load_true_scores(self.task_name, prediction_ids) - self._calculate_metrics(fold, prediction_ids, predictions, true_scores) + + if self.is_classification: + self._calculate_classification_metrics(fold, prediction_ids, predictions, true_scores) + else: + self._calculate_regression_metrics(fold, prediction_ids, predictions, true_scores) + self.recorded_folds.append(fold) - def _calculate_metrics(self, fold, prediction_ids, predictions, true_scores): - raise NotImplementedError("Subclasses must implement this method") - - def get_final_results(self): - if len(self.recorded_folds) < self.num_folds: - raise ValueError( - f"All {self.num_folds} folds must be recorded before getting final results." - ) - return self._aggregate_results() - - def _aggregate_results(self): - raise NotImplementedError("Subclasses must implement this method") - - def to_file(self, file_path: str): - with open(file_path, "w") as f: - f.write(jsonpickle.encode(self)) - - @staticmethod - def from_file(file_path: str): - with open(file_path) as f: - return jsonpickle.decode(f.read()) - - -@dataclass -class MatTextTask(BaseMatTextTask): - def _calculate_metrics(self, fold, prediction_ids, predictions, true_scores): + def _calculate_regression_metrics(self, fold, prediction_ids, predictions, true_scores): mae = mean_absolute_error(true_scores, predictions) rmse = math.sqrt(mean_squared_error(true_scores, predictions)) self.folds_results[fold] = { @@ -94,32 +72,11 @@ def _calculate_metrics(self, fold, prediction_ids, predictions, true_scores): "rmse": rmse, } - def _aggregate_results(self): - final_scores_mae = [ - self.folds_results[fold]["mae"] for fold in range(self.num_folds) - ] - final_scores_rmse = [ - self.folds_results[fold]["rmse"] for fold in range(self.num_folds) - ] - return { - "mean_mae_score": np.mean(final_scores_mae), - "std_mae_score": np.std(final_scores_mae), - "mean_rmse_score": np.mean(final_scores_rmse), - "std_rmse_score": np.std(final_scores_rmse), - } - - -@dataclass -class MatTextClassificationTask(BaseMatTextTask): - num_classes: int = 2 - - def _calculate_metrics(self, fold, prediction_ids, predictions, true_labels): + def _calculate_classification_metrics(self, fold, prediction_ids, predictions, true_labels): pred_labels = np.argmax(predictions, axis=1) accuracy = accuracy_score(true_labels, pred_labels) - precision, recall, f1, _ = precision_recall_fscore_support( - true_labels, pred_labels, average="weighted" - ) - roc_auc = roc_auc_score(true_labels, predictions[:, 1]) + precision, recall, f1, _ = precision_recall_fscore_support(true_labels, pred_labels, average='weighted') + roc_auc = roc_auc_score(true_labels, predictions[:, 1]) if self.num_classes == 2 else None 
self.folds_results[fold] = { "prediction_ids": prediction_ids, "predictions": predictions, @@ -128,15 +85,61 @@ def _calculate_metrics(self, fold, prediction_ids, predictions, true_labels): "precision": precision, "recall": recall, "f1": f1, - "roc_auc": roc_auc, + "roc_auc": roc_auc } + def get_final_results(self): + if len(self.recorded_folds) < self.num_folds: + raise ValueError( + f"All {self.num_folds} folds must be recorded before getting final results." + ) + return self._aggregate_results() + def _aggregate_results(self): - metrics = ["accuracy", "precision", "recall", "f1", "roc_auc"] + if self.is_classification: + metrics = ['accuracy', 'precision', 'recall', 'f1', 'roc_auc'] + else: + metrics = ['mae', 'rmse'] + final_scores = {metric: [] for metric in metrics} for fold in range(self.num_folds): for metric in metrics: - final_scores[metric].append(self.folds_results[fold][metric]) + if metric in self.folds_results[fold]: + final_scores[metric].append(self.folds_results[fold][metric]) + return { - f"mean_{metric}": np.mean(scores) for metric, scores in final_scores.items() - } | {f"std_{metric}": np.std(scores) for metric, scores in final_scores.items()} + f"mean_{metric}": np.mean(scores) for metric, scores in final_scores.items() if scores + } | { + f"std_{metric}": np.std(scores) for metric, scores in final_scores.items() if scores + } + + def to_file(self, file_path: str): + with open(file_path, "w") as f: + json.dump(self, f, default=self._json_serializable) + + @staticmethod + def from_file(file_path: str): + with open(file_path) as f: + data = json.load(f) + task = MatTextTask(task_name=data["task_name"], num_folds=data["num_folds"], + is_classification=data["is_classification"], num_classes=data["num_classes"]) + task.folds_results = data["folds_results"] + task.recorded_folds = data["recorded_folds"] + return task + + @staticmethod + def _json_serializable(obj): + if isinstance(obj, (np.ndarray, pd.Series)): + return obj.tolist() + elif isinstance(obj, (np.bool_, np.integer, np.floating)): + return obj.item() + elif isinstance(obj, MatTextTask): + return { + "task_name": obj.task_name, + "num_folds": obj.num_folds, + "is_classification": obj.is_classification, + "num_classes": obj.num_classes, + "folds_results": obj.folds_results, + "recorded_folds": obj.recorded_folds + } + raise TypeError(f"Type {type(obj)} not serializable") \ No newline at end of file From 624dd4c2d1201b1ebc717b30f98d6ce8f47477d1 Mon Sep 17 00:00:00 2001 From: n0w0f Date: Wed, 21 Aug 2024 10:29:48 +0200 Subject: [PATCH 17/20] chore: improve benchmarking abstraction --- revision-scripts/mp_classification.py | 4 ++-- src/mattext/models/benchmark.py | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/revision-scripts/mp_classification.py b/revision-scripts/mp_classification.py index 219caec..763e5d1 100644 --- a/revision-scripts/mp_classification.py +++ b/revision-scripts/mp_classification.py @@ -20,8 +20,8 @@ def __len__(self): return self.txn.stat()['entries'] def get(self, index): - id = f"{index}".encode("ascii") - return pickle.loads(self.txn.get(id)) + id_ = f"{index}".encode("ascii") + return pickle.loads(self.txn.get(id_)) def create_json_from_lmdb(lmdb_path, output_dir): dataset = Dataset(lmdb_path) diff --git a/src/mattext/models/benchmark.py b/src/mattext/models/benchmark.py index 235c325..b79bfd5 100644 --- a/src/mattext/models/benchmark.py +++ b/src/mattext/models/benchmark.py @@ -138,7 +138,7 @@ def _record_predictions(self, task, fold, predictions, prediction_ids): 
class MatbenchmarkClassification(BaseBenchmark): def run_benchmarking(self, local_rank=None) -> None: - task = MatTextTask(task_name=self.task, is_classification=True) + task = self._initialize_task() for i, (exp_name, test_name) in enumerate( zip(self.exp_names, self.test_exp_names) @@ -152,6 +152,9 @@ def run_benchmarking(self, local_rank=None) -> None: self._save_results(task) + def _initialize_task(self): + return MatTextTask(task_name=self.task, is_classification=True) + def _get_finetuner(self, exp_cfg, local_rank, fold_name): return FinetuneClassificationModel(exp_cfg, local_rank, fold=fold_name) From c45f5682f37657bc6a053aff54629079dc4db939 Mon Sep 17 00:00:00 2001 From: n0w0f Date: Wed, 21 Aug 2024 14:22:30 +0200 Subject: [PATCH 18/20] refactor --- revision-scripts/mp_classification.py | 44 +++++++++++++++------------ 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/revision-scripts/mp_classification.py b/revision-scripts/mp_classification.py index 763e5d1..6230a06 100644 --- a/revision-scripts/mp_classification.py +++ b/revision-scripts/mp_classification.py @@ -1,46 +1,51 @@ -import lmdb -import pickle import json import os -from pymatgen.core import Structure +import pickle + import fire +import lmdb +from pymatgen.core import Structure + class Dataset: def __init__(self, lmdb_path, max_readers=1): - self.env = lmdb.open(lmdb_path, - subdir=False, - readonly=True, - lock=False, - readahead=False, - meminit=False, - max_readers=max_readers) + self.env = lmdb.open( + lmdb_path, + subdir=False, + readonly=True, + lock=False, + readahead=False, + meminit=False, + max_readers=max_readers, + ) self.txn = self.env.begin() def __len__(self): - return self.txn.stat()['entries'] + return self.txn.stat()["entries"] def get(self, index): id_ = f"{index}".encode("ascii") return pickle.loads(self.txn.get(id_)) + def create_json_from_lmdb(lmdb_path, output_dir): dataset = Dataset(lmdb_path) output_data = [] for i in range(len(dataset)): d = dataset.get(i) - + # Convert structure to CIF - structure = d['structure'] + structure = d["structure"] cif = structure.to(fmt="cif") entry = { "structure": cif, - "is_stable": d['is_stable'], - "is_metal": d['is_metal'], - "is_magnetic": d['is_magnetic'] + "is_stable": d["is_stable"], + "is_metal": d["is_metal"], + "is_magnetic": d["is_magnetic"], } - + output_data.append(entry) # Ensure output directory exists @@ -48,10 +53,11 @@ def create_json_from_lmdb(lmdb_path, output_dir): # Write to JSON file output_file = os.path.join(output_dir, "mp_test.json") - with open(output_file, 'w') as f: + with open(output_file, "w") as f: json.dump(output_data, f, indent=2) print(f"JSON file created: {output_file}") + if __name__ == "__main__": - fire.Fire(create_json_from_lmdb) \ No newline at end of file + fire.Fire(create_json_from_lmdb) From bfd872c23eadac7d56cf9abed62233caac45aee4 Mon Sep 17 00:00:00 2001 From: n0w0f Date: Sun, 22 Sep 2024 19:37:02 +0200 Subject: [PATCH 19/20] fix: add logger --- src/mattext/models/benchmark.py | 15 ++++---- src/mattext/models/helper.py | 5 +-- src/mattext/models/inference.py | 15 +++----- src/mattext/models/llama_sft.py | 2 +- src/mattext/models/score.py | 64 +++++++++++++++++++++++---------- src/mattext/models/utils.py | 7 ++-- 6 files changed, 64 insertions(+), 44 deletions(-) diff --git a/src/mattext/models/benchmark.py b/src/mattext/models/benchmark.py index b79bfd5..ef93603 100644 --- a/src/mattext/models/benchmark.py +++ b/src/mattext/models/benchmark.py @@ -13,6 +13,7 @@ MatTextTask, ) from 
mattext.models.utils import fold_key_namer +from loguru import logger class BaseBenchmark(ABC): @@ -44,12 +45,10 @@ def _initialize_task(self): def _run_experiment(self, task, i, exp_name, test_name, local_rank): fold_name = fold_key_namer(i) - print( + logger.info( f"Running training on {self.train_data}, and testing on {self.test_data} for fold {i}" ) - print("-------------------------") - print(fold_name) - print("-------------------------") + logger.info("Fold Name: ",fold_name) exp_cfg = self.task_cfg.copy() exp_cfg.model.finetune.exp_name = exp_name @@ -57,9 +56,7 @@ def _run_experiment(self, task, i, exp_name, test_name, local_rank): finetuner = self._get_finetuner(exp_cfg, local_rank, fold_name) ckpt = finetuner.finetune() - print("-------------------------") - print(ckpt) - print("-------------------------") + logger.info("Checkpoint: ",ckpt) wandb.init( config=dict(self.task_cfg.model.inference), @@ -75,12 +72,12 @@ def _run_experiment(self, task, i, exp_name, test_name, local_rank): predictions, prediction_ids = predict.predict() self._record_predictions(task, i, predictions, prediction_ids) except Exception as e: - print( + logger.error( f"Error occurred during inference for finetuned checkpoint '{exp_name}': {str(e)}" ) if isinstance(e, (ValueError, TypeError)): raise - print(traceback.format_exc()) + logger.error(traceback.format_exc()) @abstractmethod def _get_finetuner(self, exp_cfg, local_rank, fold_name): diff --git a/src/mattext/models/helper.py b/src/mattext/models/helper.py index 4f281b3..f360c67 100644 --- a/src/mattext/models/helper.py +++ b/src/mattext/models/helper.py @@ -2,6 +2,7 @@ import matplotlib.pyplot as plt from datasets import load_dataset from tqdm import tqdm +from loguru import logger from mattext.models.utils import TokenizerMixin @@ -15,8 +16,8 @@ def count_tokens_and_plot( ): tokenizer = TokenizerMixin(representation) ds = load_dataset("json", data_files=dataset_path, split="train") - print(ds) - print(representation) + logger.info("Dataset: ",ds) + logger.info("Representation: "representation) dataset = ds[representation] token_counts = [] diff --git a/src/mattext/models/inference.py b/src/mattext/models/inference.py index 49352e4..f927b65 100644 --- a/src/mattext/models/inference.py +++ b/src/mattext/models/inference.py @@ -2,6 +2,7 @@ import traceback import wandb +from loguru import logger from matbench.bench import MatbenchBenchmark from omegaconf import DictConfig @@ -65,21 +66,15 @@ def run_benchmarking(self, local_rank=None) -> None: for i, (exp_name, test_name, train_data_path, test_data_path) in enumerate( zip(self.exp_names, self.test_exp_names, self.train_data, self.test_data) ): - print( + logger.info( f"Running training on {train_data_path}, and testing on {test_data_path}" ) - # wandb.init( - # config=dict(self.task_cfg.model.finetune), - # project=self.task_cfg.model.logging.wandb_project, name=exp_name) - exp_cfg = self.task_cfg.copy() exp_cfg.model.finetune.exp_name = exp_name exp_cfg.model.finetune.path.finetune_traindata = train_data_path ckpt = exp_cfg.model.finetune.path.finetuned_modelname - print("-------------------------") - print(ckpt) - print("-------------------------") + logger.info("Checkpoint: ", ckpt) wandb.init( config=dict(self.task_cfg.model.inference), @@ -95,10 +90,10 @@ def run_benchmarking(self, local_rank=None) -> None: predictions = predict.predict() benchmark.record(i, predictions) except Exception as e: - print( + logger.error( f"Error occurred during inference for finetuned checkpoint '{exp_name}':" ) - 
print(traceback.format_exc()) + logger.error(traceback.format_exc()) if not os.path.exists(self.benchmark_save_path): os.makedirs(self.benchmark_save_path) diff --git a/src/mattext/models/llama_sft.py b/src/mattext/models/llama_sft.py index 53942a1..8ab47eb 100644 --- a/src/mattext/models/llama_sft.py +++ b/src/mattext/models/llama_sft.py @@ -4,6 +4,7 @@ import torch import wandb from datasets import load_dataset +from loguru import logger from omegaconf import DictConfig from peft import ( LoraConfig, @@ -22,7 +23,6 @@ from mattext.models.utils import ( EvaluateFirstStepCallback, ) -from loguru import logger class FinetuneLLamaSFT: diff --git a/src/mattext/models/score.py b/src/mattext/models/score.py index 35343a3..62248ff 100644 --- a/src/mattext/models/score.py +++ b/src/mattext/models/score.py @@ -1,3 +1,4 @@ +import json import math from dataclasses import dataclass, field from typing import Any, Dict, List @@ -12,7 +13,6 @@ precision_recall_fscore_support, roc_auc_score, ) -import json MATTEXT_MATBENCH = { "kvrh": "matbench_log_kvrh", @@ -32,6 +32,7 @@ "form_energy": "e_form", } + def load_true_scores(dataset, mbids): data_frame = load(MATTEXT_MATBENCH[dataset]) scores = [] @@ -40,6 +41,7 @@ def load_true_scores(dataset, mbids): scores.append(score) return scores + @dataclass class MatTextTask: task_name: str @@ -49,19 +51,27 @@ class MatTextTask: folds_results: Dict[int, Dict[str, Any]] = field(default_factory=dict) recorded_folds: List[int] = field(default_factory=list) - def record_fold(self, fold: int, prediction_ids: List[str], predictions: List[float]): + def record_fold( + self, fold: int, prediction_ids: List[str], predictions: List[float] + ): if fold in self.recorded_folds: raise ValueError(f"Fold {fold} has already been recorded.") true_scores = load_true_scores(self.task_name, prediction_ids) - + if self.is_classification: - self._calculate_classification_metrics(fold, prediction_ids, predictions, true_scores) + self._calculate_classification_metrics( + fold, prediction_ids, predictions, true_scores + ) else: - self._calculate_regression_metrics(fold, prediction_ids, predictions, true_scores) - + self._calculate_regression_metrics( + fold, prediction_ids, predictions, true_scores + ) + self.recorded_folds.append(fold) - def _calculate_regression_metrics(self, fold, prediction_ids, predictions, true_scores): + def _calculate_regression_metrics( + self, fold, prediction_ids, predictions, true_scores + ): mae = mean_absolute_error(true_scores, predictions) rmse = math.sqrt(mean_squared_error(true_scores, predictions)) self.folds_results[fold] = { @@ -72,11 +82,19 @@ def _calculate_regression_metrics(self, fold, prediction_ids, predictions, true_ "rmse": rmse, } - def _calculate_classification_metrics(self, fold, prediction_ids, predictions, true_labels): + def _calculate_classification_metrics( + self, fold, prediction_ids, predictions, true_labels + ): pred_labels = np.argmax(predictions, axis=1) accuracy = accuracy_score(true_labels, pred_labels) - precision, recall, f1, _ = precision_recall_fscore_support(true_labels, pred_labels, average='weighted') - roc_auc = roc_auc_score(true_labels, predictions[:, 1]) if self.num_classes == 2 else None + precision, recall, f1, _ = precision_recall_fscore_support( + true_labels, pred_labels, average="weighted" + ) + roc_auc = ( + roc_auc_score(true_labels, predictions[:, 1]) + if self.num_classes == 2 + else None + ) self.folds_results[fold] = { "prediction_ids": prediction_ids, "predictions": predictions, @@ -85,7 +103,7 @@ def 
_calculate_classification_metrics(self, fold, prediction_ids, predictions, t "precision": precision, "recall": recall, "f1": f1, - "roc_auc": roc_auc + "roc_auc": roc_auc, } def get_final_results(self): @@ -97,9 +115,9 @@ def get_final_results(self): def _aggregate_results(self): if self.is_classification: - metrics = ['accuracy', 'precision', 'recall', 'f1', 'roc_auc'] + metrics = ["accuracy", "precision", "recall", "f1", "roc_auc"] else: - metrics = ['mae', 'rmse'] + metrics = ["mae", "rmse"] final_scores = {metric: [] for metric in metrics} for fold in range(self.num_folds): @@ -108,9 +126,13 @@ def _aggregate_results(self): final_scores[metric].append(self.folds_results[fold][metric]) return { - f"mean_{metric}": np.mean(scores) for metric, scores in final_scores.items() if scores + f"mean_{metric}": np.mean(scores) + for metric, scores in final_scores.items() + if scores } | { - f"std_{metric}": np.std(scores) for metric, scores in final_scores.items() if scores + f"std_{metric}": np.std(scores) + for metric, scores in final_scores.items() + if scores } def to_file(self, file_path: str): @@ -121,8 +143,12 @@ def to_file(self, file_path: str): def from_file(file_path: str): with open(file_path) as f: data = json.load(f) - task = MatTextTask(task_name=data["task_name"], num_folds=data["num_folds"], - is_classification=data["is_classification"], num_classes=data["num_classes"]) + task = MatTextTask( + task_name=data["task_name"], + num_folds=data["num_folds"], + is_classification=data["is_classification"], + num_classes=data["num_classes"], + ) task.folds_results = data["folds_results"] task.recorded_folds = data["recorded_folds"] return task @@ -140,6 +166,6 @@ def _json_serializable(obj): "is_classification": obj.is_classification, "num_classes": obj.num_classes, "folds_results": obj.folds_results, - "recorded_folds": obj.recorded_folds + "recorded_folds": obj.recorded_folds, } - raise TypeError(f"Type {type(obj)} not serializable") \ No newline at end of file + raise TypeError(f"Type {type(obj)} not serializable") diff --git a/src/mattext/models/utils.py b/src/mattext/models/utils.py index 13fd2cc..864df8c 100644 --- a/src/mattext/models/utils.py +++ b/src/mattext/models/utils.py @@ -2,6 +2,7 @@ import torch import wandb +from loguru import logger from tqdm import tqdm from transformers import GenerationConfig, TrainerCallback from transformers.integrations import WandbCallback @@ -117,8 +118,8 @@ def __init__( truncation=False, padding=False, ) - print(f"special_tokens: {special_tokens}") - print(self._wrapped_tokenizer.tokenize("Se2Se3")) + logger.info(f"special_tokens: {special_tokens}") + logger.info(self._wrapped_tokenizer.tokenize("Se2Se3")) # self._wrapped_tokenizer.add_special_tokens(special_tokens=special_tokens) @@ -188,7 +189,7 @@ def on_log( if state.is_world_process_zero: step = state.global_step # Retrieve the current step epoch = state.epoch # Retrieve the current epoch - print(f"Step: {step}, Epoch: {round(epoch,5)}") + logger.info(f"Step: {step}, Epoch: {round(epoch,5)}") if ( "loss" in logs and "eval_loss" in logs From 2d5d9019c111deceee1ee9004279e84b73309157 Mon Sep 17 00:00:00 2001 From: n0w0f Date: Fri, 11 Oct 2024 10:28:16 +0200 Subject: [PATCH 20/20] fix: empty gpu vram after each fold --- src/mattext/models/llama_sft.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/src/mattext/models/llama_sft.py b/src/mattext/models/llama_sft.py index 8ab47eb..5cf015d 100644 --- a/src/mattext/models/llama_sft.py +++ 
b/src/mattext/models/llama_sft.py @@ -209,14 +209,14 @@ def finetune(self) -> None: trainer.save_state() trainer.save_model(self.output_dir_) - # Merge LoRA and base model - merged_model = trainer.model.merge_and_unload() - # Save the merged model - merged_model.save_pretrained( - f"{self.cfg.path.finetuned_modelname}_{self.fold}/llamav3-8b-lora-save-pretrained", - save_config=True, - safe_serialization=True, - ) + # # Merge LoRA and base model + # merged_model = trainer.model.merge_and_unload() + # # Save the merged model + # merged_model.save_pretrained( + # f"{self.cfg.path.finetuned_modelname}_{self.fold}/llamav3-8b-lora-save-pretrained", + # save_config=True, + # safe_serialization=True, + # ) self.tokenizer.save_pretrained( f"{self.cfg.path.finetuned_modelname}_{self.fold}/llamav3-8b-lora-save-pretrained" ) @@ -231,5 +231,15 @@ def finetune(self) -> None: ) as json_file: json.dump(merge_pred, json_file) + # Empty VRAM + del trainer + del collator + del pipe + del self.model + del self.tokenizer + import gc + + gc.collect() + gc.collect() wandb.finish() return self.cfg.path.finetuned_modelname
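
The final patch frees memory between folds with `del` and `gc.collect()` but never releases PyTorch's cached allocator memory, so the GPU can still appear fully occupied to the next fold. Below is a minimal sketch of the fuller cleanup step, assuming the surrounding code has already dropped its references to the trainer and model; the helper name `release_cuda_memory` is illustrative and not part of the repository.

import gc

import torch


def release_cuda_memory() -> None:
    """Collect garbage and hand cached CUDA blocks back to the driver.

    `del` only removes a reference; gc.collect() breaks reference cycles,
    and torch.cuda.empty_cache() releases the allocator's cached memory so
    the next fold starts from a clean slate.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


# Usage after each fold, once the large objects are no longer needed:
#     del trainer, pipe, collator
#     release_cuda_memory()

Calling such a helper right after the `del` statements in `finetune()` would make the commit match its message; whether the extra `torch.cuda.empty_cache()` matters in practice depends on whether a later fold runs in the same process.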