
LCM_MSE eval fails with cnn_dailymail prepared parquet due to missing keys #19

Open
jamesdhope opened this issue Jan 14, 2025 · 15 comments


@jamesdhope

jamesdhope commented Jan 14, 2025

Following the evaluation instructions to evaluate the pre-trained LCM_MSE on the cnn_dailymail parquet data.

Run command

CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --standalone --nnodes=1 --nproc-per-node=4 -m lcm.evaluation  \
  --predictor base_lcm --sample_latent_variable False \
  --model_card checkpoints/mse_lcm/checkpoints/step_2000/model_card.yaml \
  --launcher standalone \
  --dataset.parquet_path examples/evaluation/parquet_dataset/cnn_dailymail/0_ae89e535f2a41f33_0_0.parquet \
  --tasks lcm_generation \
  --task_args '{"max_gen_len": 200}' \
  --data_loading.batch_size 4  --generator_batch_size 4 \
  --dump_dir /mnt/large_concept_model/output

Error:

'Key _source_column not found in batch.'

Investigation

Adding a print(batch.keys()) call to data_utils.py reveals the keys present in the batch; iterate_batches is looking for a _source_column key that is not there:

dict_keys(['split', '__batch_index', '__fragment_index', '__filename', '__row_groups_ids', '__index_in_fragement'])

The columns of the cnn_dailymail parquet generated by the prepare script are:

Index(['prompt', 'split', 'category', 'answer', 'answer_sentences',
       'prompt_sentences', 'answer_sentences_sonar_emb',
       'prompt_sentences_sonar_emb'],
      dtype='object')
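
(As a side note, the column listing above can be reproduced with a minimal pandas sketch, assuming pandas and pyarrow are installed; the path is the one from the run command above:)

import pandas as pd

df = pd.read_parquet("examples/evaluation/parquet_dataset/cnn_dailymail/0_ae89e535f2a41f33_0_0.parquet")
print(df.columns)  # should show the eight columns listed above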

cnn_dailymail.py must also be modified as follows:

if form != "inverse_":
    source_text_column = "prompt"
    target_text_column = "answer"
    dataset.source_prefix_text = "[INST] Summarize the following article: "
    dataset.source_suffix_text = " [/INST]"
else:
    source_text_column = "answer"
    target_text_column = "prompt"
    dataset.source_prefix_text = ("[INST] Write an article from the following summary: ")  # fmt: skip
    dataset.source_suffix_text = " [/INST]"
@jamesdhope
Author

jamesdhope commented Jan 15, 2025

An updated run command with dataset.source_column and dataset.target_column yields a further error:

CUDA_VISIBLE_DEVICES=0,1 torchrun --standalone --nnodes=1 --nproc-per-node=2 -m lcm.evaluation  \
  --predictor base_lcm --sample_latent_variable False \
  --model_card checkpoints/mse_lcm/checkpoints/step_2000/model_card.yaml \
  --launcher standalone \
  --dataset.parquet_path /mnt/large_concept_model/examples/evaluation/parquet_dataset/cnn_dailymail/0_ae89e535f2a41f33_0_0.parquet \
  --dataset.source_column prompt_sentences_sonar_emb \
  --dataset.target_column answer_sentences_sonar_emb \
  --tasks lcm_generation \
  --task_args '{"max_gen_len": 200}' \
  --data_loading.batch_size 4  --generator_batch_size 4 \
  --dump_dir /mnt/large_concept_model/output
'targets'
> /mnt/large_concept_model/lcm/evaluation/utils/common.py(150)<genexpr>()
-> outputs = fn(*(x[k] for k in input_keys), **kwargs)
(Pdb) 

@antoine-tran
Contributor

@jamesdhope : There is a patch #21 (we haven't merged this yet due to some bugs in the third-party lib stopes that failed the CI). Do you want to give it a try?

@jamesdhope
Author

jamesdhope commented Jan 15, 2025

@antoine-tran Reproduced the same issue with the patch from #21 merged locally, after re-running prepare, embed, and eval. Eval fails with the error below.

This appears to be a separate issue affecting LCM eval.

CUDA_VISIBLE_DEVICES=0,1 torchrun --standalone --nnodes=1 --nproc-per-node=2 -m lcm.evaluation  \
  --predictor base_lcm --sample_latent_variable False \
  --model_card checkpoints/mse_lcm/checkpoints/step_2000/model_card.yaml \
  --launcher standalone \
  --dataset.parquet_path /mnt/large_concept_model/examples/evaluation/parquet_dataset/cnn_dailymail/0_3e1f58ddc7724a53_0_0.parquet \
  --dataset.source_column prompt_sentences_sonar_emb \
  --dataset.target_column answer_sentences_sonar_emb \
  --tasks lcm_generation \
  --task_args '{"max_gen_len": 200}' \
  --data_loading.batch_size 4  --generator_batch_size 4 \
  --dump_dir /mnt/large_concept_model/output
'targets'
> /mnt/large_concept_model/lcm/evaluation/utils/common.py(150)<genexpr>()
-> outputs = fn(*(x[k] for k in input_keys), **kwargs)

Please note that the example run command for the two-tower LCM does not include the source_column and target_column flags; however, the script fails earlier if these are not supplied for the base LCM. Also, the original issue may have been resolved by the dataset flags, although I can't be sure.

The parquet columns are:

Index(['prompt', 'split', 'category', 'answer', 'answer_sentences',
       'prompt_sentences', 'answer_sentences_sonar_emb',
       'prompt_sentences_sonar_emb'],
      dtype='object')

@jamesdhope jamesdhope changed the title LCM_MSE eval fails with cnn_dailymail prepared parquet due to missing _source_column key LCM_MSE eval fails with cnn_dailymail prepared parquet due to missing keys Jan 16, 2025
@YujiaHu0819

Hi all! Could you please elaborate on how you acquired the model checkpoints used for evaluation? Did you conduct the entire training process independently, or did Meta provide pre-trained checkpoints for public access?

Thank you in advance!

@jamesdhope
Author

@YujiaHu0819 I completed the pre-training step with the Wikipedia data as per the readme.md instructions to obtain the model checkpoints. Please note that I did not complete the fine-tuning step.

@hasanyazarr

hasanyazarr commented Jan 18, 2025

Hey, I also encounter this error in the LCM evaluation part. My training command is:

!CUDA_VISIBLE_DEVICES=0 torchrun --standalone --nnodes=1 --nproc-per-node=1 -m large_concept_model.lcm.train \
    launcher=standalone +pretrain=mse \
    ++trainer.data_loading_config.max_tokens=512 \
    ++trainer.output_dir="/content/drive/MyDrive/LCM/checkpoints/mse_lcm" \
    +trainer.use_submitit=false

Then I ran the LCM evaluation on the downloaded parquet file with:

# eval for LCM
!uv run torchrun --standalone --nnodes=1 --nproc-per-node=1 -m lcm.evaluation \
  --predictor base_lcm  \
  --model_card /content/drive/MyDrive/LCM/checkpoints/mse_lcm/checkpoints/step_10000/model_card.yaml \
  --generator_batch_size 16 \
  --tasks lcm_generation \
  --task_args '{"max_gen_len": 200}' \
  --dataset.parquet_path /content/parquet_dataset/cnn_dailymail/lcm_eval.parquet \
  --data_loading.batch_size 16 \
  --dump_dir content/output_results_lcm

and I get the following error:

[2025-01-18 21:31:16,942] [rank 0] [INFO] Selected task execution: ['lcm_generation']
[2025-01-18 21:31:16,942] [rank 0] [INFO] Running evaluation on task lcm_generation
[2025-01-18 21:31:17,054] [rank 0] [INFO] Setting 'cuda:0' as the default device of the process.
[2025-01-18 21:31:17,236] [rank 0] [INFO] Card loaded: {'source': 'inproc', 'checkpoint': 'file:///content/drive/MyDrive/LCM/checkpoints/mse_lcm/checkpoints/step_10000/model.pt', 'model_arch': 'base_lcm_1_6B', 'model_family': 'base_lcm', 'name': 'on_the_fly_lcm'}
[2025-01-18 21:31:20,173] [rank 0] [INFO] Building sonar_normalizer = dummy_sonar_normalizer
[2025-01-18 21:31:20,174] [rank 0] [INFO] Using LCMFrontend with embeddings scaler = 1.0
[2025-01-18 21:31:20,177] [rank 0] [INFO] Initializing frontend embeddings (special and positional) ~ N(0, 0.006)
[2025-01-18 21:31:23,386] [rank 0] [WARNING] eos_threshold is set to 0.9, but eos_vec is not provided
[2025-01-18 21:31:23,386] [rank 0] [WARNING] eos_threshold is set to 0.9, but eos_vec is not provided
[2025-01-18 21:31:23,387] [rank 0] [INFO] Downloading the checkpoint of text_sonar_basic_decoder...
[2025-01-18 21:31:46,104] [rank 0] [INFO] Download complete.
[2025-01-18 21:31:57,479] [rank 0] [INFO] Using the cached tokenizer of text_sonar_basic_decoder. Set force to True to download again.
[2025-01-18 21:31:57,980] [rank 0] [INFO] Predictor loaded: LCMPredictor
[2025-01-18 21:31:57,981] [rank 0] [INFO] Using rank=0 among world_size=1 to build self._pipeline
[2025-01-18 21:31:57,983] [rank 0] [INFO] Following columns will be loaded: ['split']
100% 1/1 [00:00<00:00, 5426.01it/s]
[2025-01-18 21:31:58,008] [rank 0] [INFO] Bucketing will require at least: 10000 of tokens (source + target)
[2025-01-18 21:31:58,008] [rank 0] [INFO] Dataset stats: {'min_number_of_fragment': 1, 'mean_fragment_length': 11490.0, 'mean_fragment_number_of_tokens': None}
[2025-01-18 21:31:58,008] [rank 0] [INFO] Dataset Config: ParquetDatasetConfig(columns=['split'], source_text_column=None, target_text_column=None, source_prefix_text=None, source_suffix_text=None, target_prefix_text=None, target_suffix_text=None, source_sequences=None, target_sequences=None, silent_freeze=True, name=None, parquet_path='/content/parquet_dataset/cnn_dailymail/lcm_eval.parquet', weight=1.0, limit=None, source_column=None, target_column=None, source_quality_column=None, source_quality_range=None, partition_filters=None, filters=None, filesystem_expr=None, filesystem=None, split_to_row_groups=True, nb_parallel_fragments=1, sharding_in_memory=False)
[2025-01-18 21:31:58,008] [rank 0] [INFO] Using Loading Config: EvaluationDataLoadingConfig(multiple_dataset_chaining='concat', batch_size=16, order_by_length=True, max_tokens=None, len_to_wrap_long_seq=None, packing=False, wrap_before_affixing=False, max_sentence_len_in_doc=None, min_sentence_len_in_doc=None, max_sentence_len_in_target_doc=None, min_sentence_len_in_target_doc=None, min_length_of_sequences=1, min_length_of_sequences_after_batching=1, min_length_of_target_sequences=1, min_length_of_target_sequences_after_batching=1, output_format=<ParquetBatchFormat.torch: 2>, shuffle=False, drop_null=True, seed=123, nb_epochs=1, min_batch_size=1, nb_prefetch=3.0, num_parallel_calls=1.5, use_threads=False, ignore_checkpointed_pipeline=False, even_sharding=False, max_iteration_steps=None, sharding_in_memory=True, rank=0, world_size=1, max_samples=None)
[2025-01-18 21:31:58,009] [rank 0] [INFO] Activating sharding_in_memory
[2025-01-18 21:31:58,011] [rank 0] [INFO] /content/parquet_dataset/cnn_dailymail : full number of files 1
[2025-01-18 21:31:58,011] [rank 0] [INFO] /content/parquet_dataset/cnn_dailymail : starting split in row groups
/content/large_concept_model/lcm/datasets/parquet_utils.py:162: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
return torch.from_numpy(arr.to_numpy(zero_copy_only=True))
[rank0]: Traceback (most recent call last):
[rank0]: File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
[rank0]: return _run_code(code, main_globals, None,
[rank0]: File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
[rank0]: exec(code, run_globals)
[rank0]: File "/content/large_concept_model/lcm/evaluation/main.py", line 59, in
[rank0]: local.main(cfg, logger=logger)
[rank0]: File "/content/large_concept_model/lcm/evaluation/cli/local.py", line 38, in main
[rank0]: metrics, result_file = run_task(run_config, logger=logger, gang=get_gang())
[rank0]: File "/content/large_concept_model/lcm/evaluation/run.py", line 223, in run_task
[rank0]: result = task.run(
[rank0]: File "/content/large_concept_model/lcm/evaluation/tasks/base.py", line 205, in run
[rank0]: for batch in dataset:
[rank0]: File "/content/large_concept_model/lcm/evaluation/utils/data_utils.py", line 470, in iterate_batches
[rank0]: for x in batch[source_key]:
[rank0]: KeyError: '_source_column'

My LCM training outputs are attached in the logs.txt file, where you can see the training details. I trained MSE_LCM.

Also, my datacards.yaml file is below:

# FIXME
name: "pretraining_data_train"
parquet_path:
  s3: /content/large_concept_model/sample_data/train_data.parquet
source_column: "text_sentences_sonar_emb"
source_text_column: "text_sentences"
---
# FIXME
name: "pretraining_data_val"
parquet_path:
  s3: /content/large_concept_model/sample_data/val_data.parquet
source_column: "text_sentences_sonar_emb"
source_text_column: "text_sentences"
---
# FIXME
name: "finetuning_data"
parquet_path:
  s3: "cosmopedia_sample"
source_column: prompt_sentences_sonar_emb
source_text_column: prompt_sentences
target_column: text_sentences_sonar_emb
target_text_column: text_sentences
# partition columns:
# "split" (train, validation)

logs.txt

@jamesdhope
Author

jamesdhope commented Jan 18, 2025

@hasanyazarr the fix for this one is to specify dataset.source_column and dataset.target_column in the run command, which are missing from the example in the readme.md; however, it will likely then fail with the error above complaining about a missing key ☝️

@antoine-tran
Contributor

@jamesdhope : Checking the code, dataset.source_text_column and dataset.target_text_column are also required in the post-processing of the LCM results before the metric computation: https://github.com/facebookresearch/large_concept_model/blob/main/lcm/evaluation/utils/data_utils.py#L638

So in your case it should be:

CUDA_VISIBLE_DEVICES=0,1 torchrun --standalone --nnodes=1 --nproc-per-node=2 -m lcm.evaluation  \
  --predictor base_lcm --sample_latent_variable False \
  --model_card checkpoints/mse_lcm/checkpoints/step_2000/model_card.yaml \
  --launcher standalone \
  --dataset.parquet_path /mnt/large_concept_model/examples/evaluation/parquet_dataset/cnn_dailymail/0_3e1f58ddc7724a53_0_0.parquet \
  --dataset.source_column prompt_sentences_sonar_emb \
  --dataset.source_text_column prompt_sentences \
  --dataset.target_column answer_sentences_sonar_emb \
  --dataset.target_text_column prompt_sentences \
  --tasks lcm_generation \
  --task_args '{"max_gen_len": 200}' \
  --data_loading.batch_size 4  --generator_batch_size 4 \
  --dump_dir /mnt/large_concept_model/output

The current documentation wrongly states that we only need one type of column (source_column for the LCM, source_text_column for the LLM). I made a patch in #22.

@jamesdhope
Author

That's resolved this issue. @antoine-tran @hasanyazarr do you have a rough idea of how long the eval script takes to run? With a max_gen_len of 10 on an L4 GPU, I have no raw results file after 3 hours.

@hasanyazarr

@antoine-tran Do you know a solution for '[rank 0] [WARNING] filtering table whose nb sentences and nb sonar vectors are aligned, keeping 2 rows out of11490'?

The evaluation command you shared works well, but this warning means only 2 rows out of 11490 are actually evaluated (a quick alignment check is sketched after the log). Full output below:

[2025-01-23 07:35:57,150] [rank 0] [INFO] submitted single job for lcm_generation_base_lcm_a591ec0874_2025-01-23_07-35-57: DEBUG_138828927617776
[2025-01-23 07:35:57,150] [rank 0] [INFO] Logs at: /content/executor_logs/lcm_generation_base_lcm_a591ec0874_2025-01-23_07-35-56/DEBUG_138828927617776_0_log.err
[2025-01-23 07:35:57,152] [rank 0] [WARNING] Logging is written both to stderr/stdout and to /content/executor_logs/lcm_generation_base_lcm_a591ec0874_2025-01-23_07-35-56/DEBUG_138828927617776_0_log.out/err. But call to print will only appear in the console.
[2025-01-23 07:35:57,157] [rank 0] [INFO] Writing configs and metadata to /content/drive/MyDrive/LCM/output_results_lcm/metadata.jsonl
[2025-01-23 07:35:57,163] [rank 0] [INFO] Evals version 0.1.0.dev0 (/content/large_concept_model/lcm/evaluation)
[2025-01-23 07:35:57,163] [rank 0] [INFO] Config: {'timestamp': '2025_01_23_07_35_57', 'command': '/content/large_concept_model/lcm/evaluation/main.py --predictor base_lcm --model_card /content/drive/MyDrive/LCM/checkpoints/mse_lcm/checkpoints/step_10000/model_card.yaml --launcher standalone --dataset.parquet_path /content/drive/MyDrive/LCM/eval_data/0_55ac997a0bfaa427_0_0.parquet --dataset.source_column prompt_sentences_sonar_emb --dataset.source_text_column prompt_sentences --dataset.target_column answer_sentences_sonar_emb --dataset.target_text_column prompt_sentences --tasks lcm_generation --task_args '{"max_gen_len": 200}' --data_loading.batch_size 4096 --generator_batch_size 4096 --dump_dir /content/drive/MyDrive/LCM/output_results_lcm '\'', 'git_info': {'git_repo': '/content/large_concept_model/lcm', 'commit': 'd6402232cb7195530904d565cfe7c66d70c2b2a3', 'branch': 'main', 'user': 'root'}, 'config': {'name': 'lcm_generation', 'task_name': 'lcm_generation', 'dump_dir': '/content/drive/MyDrive/LCM/output_results_lcm', 'predictor': 'base_lcm', 'params': {'dataset': {'columns': None, 'source_text_column': 'prompt_sentences', 'target_text_column': 'prompt_sentences', 'source_prefix_text': None, 'source_suffix_text': None, 'target_prefix_text': None, 'target_suffix_text': None, 'source_sequences': None, 'target_sequences': None, 'silent_freeze': True, 'name': None, 'parquet_path': '/content/drive/MyDrive/LCM/eval_data/0_55ac997a0bfaa427_0_0.parquet', 'weight': 1.0, 'limit': None, 'source_column': 'prompt_sentences_sonar_emb', 'target_column': 'answer_sentences_sonar_emb', 'source_quality_column': None, 'source_quality_range': None, 'partition_filters': None, 'filters': None, 'filesystem_expr': None, 'filesystem': None, 'split_to_row_groups': None, 'nb_parallel_fragments': None, 'sharding_in_memory': False}, 'max_gen_len': 200, 'max_gen_len_ratio': None, 'max_prompt_len': 2048, 'eos_config': None}, 'data_loading': {'multiple_dataset_chaining': 'concat', 'batch_size': 4096, 'order_by_length': True, 'max_tokens': None, 'len_to_wrap_long_seq': None, 'packing': False, 'wrap_before_affixing': False, 'max_sentence_len_in_doc': None, 'min_sentence_len_in_doc': None, 'max_sentence_len_in_target_doc': None, 'min_sentence_len_in_target_doc': None, 'min_length_of_sequences': 1, 'min_length_of_sequences_after_batching': 1, 'min_length_of_target_sequences': 1, 'min_length_of_target_sequences_after_batching': 1, 'output_format': <ParquetBatchFormat.torch: 2>, 'shuffle': False, 'drop_null': True, 'seed': 123, 'nb_epochs': 1, 'min_batch_size': 1, 'nb_prefetch': 3.0, 'num_parallel_calls': 1.5, 'use_threads': False, 'ignore_checkpointed_pipeline': False, 'even_sharding': False, 'max_iteration_steps': None, 'sharding_in_memory': True, 'rank': 0, 'world_size': 1, 'max_samples': None}, 'dataset': {'columns': None, 'source_text_column': 'prompt_sentences', 'target_text_column': 'prompt_sentences', 'source_prefix_text': None, 'source_suffix_text': None, 'target_prefix_text': None, 'target_suffix_text': None, 'source_sequences': None, 'target_sequences': None, 'silent_freeze': True, 'name': None, 'parquet_path': '/content/drive/MyDrive/LCM/eval_data/0_55ac997a0bfaa427_0_0.parquet', 'weight': 1.0, 'limit': None, 'source_column': 'prompt_sentences_sonar_emb', 'target_column': 'answer_sentences_sonar_emb', 'source_quality_column': None, 'source_quality_range': None, 'partition_filters': None, 'filters': None, 'filesystem_expr': None, 'filesystem': None, 'split_to_row_groups': None, 'nb_parallel_fragments': None, 
'sharding_in_memory': False}, 'dtype': 'torch.float32', 'predictor_config': {'max_seq_len': 200, 'min_seq_len': 1, 'eos_threshold': 0.9, 'sample_latent_variable': True, 'stop_on_repetition_cosine_threshold': None, 'include_eos_token': False, 'trim_hypotheses': False, 'seed': 42, 'lcm_temperature': 1.0, 'model_card': '/content/drive/MyDrive/LCM/checkpoints/mse_lcm/checkpoints/step_10000/model_card.yaml', 'decoder_config': {'tokenizer': 'text_sonar_basic_decoder', 'decoder': 'text_sonar_basic_decoder', 'lang': 'eng_Latn', 'max_tokens_in_sentence': 256, 'temperature': 1.0}, 'encoder_config': {'tokenizer': 'text_sonar_basic_encoder', 'encoder': 'text_sonar_basic_encoder', 'lang': 'eng_Latn'}, 'generator_batch_size': 4096}, 'seed': 42, 'confidence_level': None, 'disable_cache': False, 'temperature': 0.0, 'top_k': 0, 'top_p': 0, 'metric_log_dir': '/content/drive/MyDrive/LCM/output_results_lcm', 'tb_log_dir': None, 'no_resume': False, 'metrics_to_report': None, 'show_progress': False, 'log_raw_results': True, 'log_only_text': False, 'requirements': {'nodes': 1, 'mem_gb': None, 'tasks_per_node': 1, 'gpus_per_node': 1, 'cpus_per_task': 4, 'timeout_min': 150, 'constraint': None, 'max_num_timeout': 10}, 'nshards': None, 'os_environs': None}, 'task_configs': {'dataset': ParquetDatasetConfig(columns=None, source_text_column='prompt_sentences', target_text_column='prompt_sentences', source_prefix_text=None, source_suffix_text=None, target_prefix_text=None, target_suffix_text=None, source_sequences=None, target_sequences=None, silent_freeze=True, name=None, parquet_path='/content/drive/MyDrive/LCM/eval_data/0_55ac997a0bfaa427_0_0.parquet', weight=1.0, limit=None, source_column='prompt_sentences_sonar_emb', target_column='answer_sentences_sonar_emb', source_quality_column=None, source_quality_range=None, partition_filters=None, filters=None, filesystem_expr=None, filesystem=None, split_to_row_groups=None, nb_parallel_fragments=None, sharding_in_memory=False), 'max_gen_len': 200, 'max_gen_len_ratio': None, 'max_prompt_len': 2048, 'eos_config': None}}
[2025-01-23 07:35:57,238] [rank 0] [INFO] Running task lcm_generation on cuda:0
[2025-01-23 07:35:57,242] [rank 0] [INFO] Setting 'cuda:0' as the default device of the process.
[2025-01-23 07:35:57,426] [rank 0] [INFO] Card loaded: {'source': 'inproc', 'checkpoint': 'file:///content/drive/MyDrive/LCM/checkpoints/mse_lcm/checkpoints/step_10000/model.pt', 'model_arch': 'base_lcm_1_6B', 'model_family': 'base_lcm', 'name': 'on_the_fly_lcm'}
[2025-01-23 07:36:00,871] [rank 0] [INFO] Building sonar_normalizer = dummy_sonar_normalizer
[2025-01-23 07:36:00,872] [rank 0] [INFO] Using LCMFrontend with embeddings scaler = 1.0
[2025-01-23 07:36:00,873] [rank 0] [INFO] Initializing frontend embeddings (special and positional) ~ N(0, 0.006)
[2025-01-23 07:36:03,788] [rank 0] [WARNING] eos_threshold is set to 0.9, but eos_vec is not provided
[2025-01-23 07:36:03,789] [rank 0] [INFO] Using the cached checkpoint of text_sonar_basic_decoder. Set force to True to download again.
[2025-01-23 07:36:15,290] [rank 0] [INFO] Using the cached tokenizer of text_sonar_basic_decoder. Set force to True to download again.
[2025-01-23 07:36:15,676] [rank 0] [INFO] Predictor loaded: LCMPredictor
[2025-01-23 07:36:15,677] [rank 0] [INFO] Using rank=0 among world_size=1 to build self._pipeline
[2025-01-23 07:36:15,878] [rank 0] [INFO] Following columns will be loaded: ['answer_sentences_sonar_emb', 'prompt_sentences', 'prompt_sentences_sonar_emb', 'split']
0% 0/1 [00:18<?, ?it/s]
100% 1/1 [00:00<00:00, 5269.23it/s]
[2025-01-23 07:36:15,905] [rank 0] [INFO] Bucketing will require at least: 664882 of tokens (source + target)
[2025-01-23 07:36:15,905] [rank 0] [INFO] Dataset stats: {'min_number_of_fragment': 1, 'mean_fragment_length': 11490.0, 'mean_fragment_number_of_tokens': 443255.0}
[2025-01-23 07:36:15,905] [rank 0] [INFO] Dataset Config: ParquetDatasetConfig(columns=['answer_sentences_sonar_emb', 'prompt_sentences', 'prompt_sentences_sonar_emb', 'split'], source_text_column='prompt_sentences', target_text_column='prompt_sentences', source_prefix_text=None, source_suffix_text=None, target_prefix_text=None, target_suffix_text=None, source_sequences=None, target_sequences=None, silent_freeze=True, name=None, parquet_path='/content/drive/MyDrive/LCM/eval_data/0_55ac997a0bfaa427_0_0.parquet', weight=1.0, limit=None, source_column='prompt_sentences_sonar_emb', target_column='answer_sentences_sonar_emb', source_quality_column=None, source_quality_range=None, partition_filters=None, filters=None, filesystem_expr=None, filesystem=None, split_to_row_groups=True, nb_parallel_fragments=1, sharding_in_memory=False)
[2025-01-23 07:36:15,906] [rank 0] [INFO] Using Loading Config: EvaluationDataLoadingConfig(multiple_dataset_chaining='concat', batch_size=4096, order_by_length=True, max_tokens=None, len_to_wrap_long_seq=None, packing=False, wrap_before_affixing=False, max_sentence_len_in_doc=None, min_sentence_len_in_doc=None, max_sentence_len_in_target_doc=None, min_sentence_len_in_target_doc=None, min_length_of_sequences=1, min_length_of_sequences_after_batching=1, min_length_of_target_sequences=1, min_length_of_target_sequences_after_batching=1, output_format=<ParquetBatchFormat.torch: 2>, shuffle=False, drop_null=True, seed=123, nb_epochs=1, min_batch_size=1, nb_prefetch=3.0, num_parallel_calls=1.5, use_threads=False, ignore_checkpointed_pipeline=False, even_sharding=False, max_iteration_steps=None, sharding_in_memory=True, rank=0, world_size=1, max_samples=None)
[2025-01-23 07:36:15,906] [rank 0] [INFO] Activating sharding_in_memory
[2025-01-23 07:36:15,909] [rank 0] [INFO] /content/drive/MyDrive/LCM/eval_data : full number of files 1
[2025-01-23 07:36:15,909] [rank 0] [INFO] /content/drive/MyDrive/LCM/eval_data : starting split in row groups
[2025-01-23 07:36:26,397] [rank 0] [WARNING] filtering table whose nb sentences and nb sonar vectors are aligned, keeping 2 rows out of11490
0% 0/1 [00:29<?, ?it/s]/content/large_concept_model/lcm/datasets/parquet_utils.py:162: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
return torch.from_numpy(arr.to_numpy(zero_copy_only=True))
[2025-01-23 07:40:20,849] [rank 0] [INFO] Using default tokenizer.
[2025-01-23 07:40:21,314] [rank 0] [INFO] Using default tokenizer.
[2025-01-23 07:40:21,755] [rank 0] [INFO] Writing raw results to /content/drive/MyDrive/LCM/output_results_lcm/raw_results/lcm_generation/lcm_generation_0 ( *.json | *.pt)
[2025-01-23 07:40:21,808] [rank 0] [INFO] written cache for lcm_generation_base_lcm_a591ec0874_2025-01-23_07-40-21:0
[2025-01-23 07:40:21,810] [rank 0] [INFO] lcm_generation_base_lcm_a591ec0874_2025-01-23_07-40-21 done after full execution
[2025-01-23 07:40:21,810] [rank 0] [INFO] Writing metric results to /content/drive/MyDrive/LCM/output_results_lcm/results/lcm_generation.json
[2025-01-23 07:40:21,817] [rank 0] [INFO] All evaluation results: rouge2: 0.002397 | rougel: 0.010584 | rougelsum: 0.013138
100% 1/1 [04:24<00:00, 264.83s/it]2025-01-23 07:40:22.180669: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable TF_ENABLE_ONEDNN_OPTS=0.
2025-01-23 07:40:22.197087: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-01-23 07:40:22.215684: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-01-23 07:40:22.221060: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-01-23 07:40:22.235472: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-01-23 07:40:23.335161: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[2025-01-23 07:40:24,285] [rank 0] [INFO] Writing Tensorboard logs to /content/drive/MyDrive/LCM/output_results_lcm/tb
[2025-01-23 07:40:24,292] [rank 0] [INFO] Writing metric logs to /content/drive/MyDrive/LCM/output_results_lcm/metrics.eval.jsonl
[2025-01-23 07:40:24,296] [rank 0] [INFO] Tasks ['lcm_generation'] took 267.35 seconds (including scheduling).
100% 1/1 [04:27<00:00, 267.31s/it]
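
(For anyone hitting this filter: a minimal pandas sketch to check the alignment the warning refers to. It assumes the sentence and embedding columns are list-valued, as produced by the prepare script, and that the filter compares both the prompt and answer columns.)

import pandas as pd

df = pd.read_parquet("/content/drive/MyDrive/LCM/eval_data/0_55ac997a0bfaa427_0_0.parquet")

# A row is kept only if its sentence list and its SONAR embedding list
# have the same length, for both prompt and answer.
aligned = df.apply(
    lambda row: len(row["prompt_sentences"]) == len(row["prompt_sentences_sonar_emb"])
    and len(row["answer_sentences"]) == len(row["answer_sentences_sonar_emb"]),
    axis=1,
)
print(f"{int(aligned.sum())} of {len(df)} rows have aligned sentences and embeddings")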

@antoine-tran
Contributor

@jamesdhope @hasanyazarr We don't support vllm / triton or other optimized inference libraries yet, and the eval lib relies on submitit to parallelize the jobs.

If you specify launcher=standalone, it will run on only one GPU and use it to evaluate the whole dataset. If you have n GPUs, you can specify --launcher slurm and then set --job_args '{"launcher.cache": "null",.., "nshards": '<n value put here>', "requirements": {"gpus_per_node": 1, "timeout_min": '<timeout_in_minutes>'}}', as in the doc.

That said, the eval lib is not very optimized. Internally, our eval run on cnndm with 5 GPUs took about 15-20 minutes to finish.

@antoine-tran
Contributor

antoine-tran commented Jan 23, 2025

@hasanyazarr : I could not reproduce the issue. Could you try running the following script (or something similar; I wrote this directly in the comment box without testing) and tell me how much data you get?

import torch
from fairseq2.gang import FakeGang
from lcm.datasets.configs import ParquetDatasetConfig, EvaluationDataLoadingConfig
from lcm.evaluation.utils.data_utils import ParquetTestDataLoader

dataset = ParquetDatasetConfig(
    parquet_path="YOUR Parquet file",
    source_column="prompt_sentences_sonar_emb",
    source_text_column="prompt_sentences",
    target_column="answer_sentences_sonar_emb",
    target_text_column="answer_sentences",
)

data_loading = EvaluationDataLoadingConfig(
    batch_size=1, seed=23, min_length_of_sequences=1, nb_epochs=1
)

data_loader = ParquetTestDataLoader(
    data_config=data_loading,
    datasets=[dataset],
    gang=FakeGang(device=torch.device("cuda:0")),
    dtype=torch.float32,
)

# Count how many rows survive the data-loading filters
cnt = 0
for batch in data_loader.iterate_batches():
    cnt += len(batch)
print(cnt)

@hasanyazarr

@antoine-tran The code as originally posted produced an error (a doubled from in one of the import lines), but after fixing that, I get cnt = 11490.

@antoine-tran
Contributor

OK, so the data loading should be good. What is your actual command for evaluation again?

@hasanyazarr

!CUDA_VISIBLE_DEVICES=0 torchrun --standalone --nnodes=1 --nproc-per-node=1 -m lcm.evaluation  \
  --predictor base_lcm \
  --model_card /content/drive/MyDrive/LCM/checkpoints/mse_lcm/checkpoints/step_10000/model_card.yaml \
  --launcher standalone \
  --dataset.parquet_path /content/drive/MyDrive/LCM/eval_data/0_55ac997a0bfaa427_0_0.parquet \
  --dataset.source_column prompt_sentences_sonar_emb \
  --dataset.source_text_column prompt_sentences \
  --dataset.target_column answer_sentences_sonar_emb \
  --dataset.target_text_column prompt_sentences \
  --tasks lcm_generation \
  --task_args '{"max_gen_len": 200}' \
  --data_loading.batch_size 16  --generator_batch_size 16 \
  --dump_dir /content/drive/MyDrive/LCM/output_results_lcm

This command produces that warning. I also tried different batch_size values, but it always keeps 2 rows out of 11490. In my model_card path, I have these files: rank_0.pt, metadata.pt, model_card.yaml and model.pt.
