Starting simpler programmatic interface #236

Closed
wants to merge 12 commits into from
2 changes: 2 additions & 0 deletions src/lighteval/logging/evaluation_tracker.py
@@ -164,6 +164,8 @@ def save(self) -> None:

config_general = copy.deepcopy(self.general_config_logger)
config_general = asdict(config_general)
# We remove the config from logging, which contains context/accelerator objects
config_general.pop("config")

to_dump = {
"config_general": config_general,
2 changes: 1 addition & 1 deletion src/lighteval/logging/info_loggers.py
@@ -33,7 +33,7 @@
from lighteval.logging.hierarchical_logger import hlog_warn
from lighteval.metrics import MetricCategory
from lighteval.metrics.stderr import get_stderr_function
from lighteval.models.model_loader import ModelInfo
from lighteval.models.abstract_model import ModelInfo
from lighteval.models.model_output import ModelReturn
from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig
from lighteval.tasks.requests import Doc
124 changes: 30 additions & 94 deletions src/lighteval/main_accelerate.py
@@ -21,22 +21,13 @@
# SOFTWARE.

import os
import random
import shutil
from contextlib import nullcontext
from datetime import timedelta

import numpy as np

from lighteval.evaluator import evaluate, make_results_table
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.logging.hierarchical_logger import hlog, hlog_warn, htrack, htrack_block
from lighteval.models.model_config import EnvConfig, create_model_config
from lighteval.models.model_loader import load_model
from lighteval.tasks.lighteval_task import LightevalTask, create_requests_from_tasks
from lighteval.tasks.registry import Registry, taskinfo_selector
from lighteval.logging.hierarchical_logger import hlog_warn, htrack
from lighteval.models.model_config import create_model_config
from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters
from lighteval.utils import is_accelerate_available, is_tgi_available
from lighteval.utils_parallelism import test_all_gather


if not is_accelerate_available() and not is_tgi_available():
@@ -64,87 +55,32 @@ def main(args):
public=args.public_run,
token=TOKEN,
)
evaluation_tracker.general_config_logger.log_args_info(
args.num_fewshot_seeds, args.override_batch_size, args.max_samples, args.job_id
pipeline_params = PipelineParameters(
launcher_type=ParallelismManager.ACCELERATE,
env_config=env_config,
job_id=args.job_id,
dataset_loading_processes=args.dataset_loading_processes,
custom_tasks_directory=args.custom_tasks,
override_batch_size=args.override_batch_size,
num_fewshot_seeds=args.num_fewshot_seeds,
max_samples=args.max_samples,
use_chat_template=args.use_chat_template,
system_prompt=args.system_prompt,
)

model_config = create_model_config(args=args, accelerator=accelerator)

pipeline = Pipeline(
tasks=args.tasks,
pipeline_parameters=pipeline_params,
evaluation_tracker=evaluation_tracker,
model_config=model_config,
)

if args.max_samples:
hlog(
"WARNING: --max_samples WAS SET. THESE NUMBERS ARE ONLY PARTIAL AND SHOULD NOT BE USED FOR COMPARISON UNLESS YOU KNOW WHAT YOU ARE DOING."
)

with htrack_block("Test all gather"):
test_all_gather(accelerator)

with htrack_block("Creating model configuration"):
model_config = create_model_config(args=args, accelerator=accelerator)

with htrack_block("Model loading"):
with accelerator.main_process_first() if accelerator is not None else nullcontext():
model, model_info = load_model(config=model_config, env_config=env_config)
evaluation_tracker.general_config_logger.log_model_info(model_info)

with htrack_block("Tasks loading"):
with accelerator.main_process_first() if accelerator is not None else nullcontext():
task_names_list, few_shots_dict = taskinfo_selector(args.tasks)
task_dict = Registry(cache_dir=env_config.cache_dir).get_task_dict(
task_names_list, custom_tasks=args.custom_tasks
)
LightevalTask.load_datasets(task_dict.values(), args.dataset_loading_processes)

evaluation_tracker.task_config_logger.log(task_dict)

hlog("Loading documents, and requests")
requests, docs = create_requests_from_tasks(
task_dict=task_dict,
fewshot_dict=few_shots_dict,
num_fewshot_seeds=args.num_fewshot_seeds,
lm=model,
max_samples=args.max_samples,
evaluation_tracker=evaluation_tracker,
use_chat_template=args.use_chat_template,
system_prompt=args.system_prompt,
)

with htrack_block("Setting seeds and waiting for all processes"):
hlog(f"setting seed to {1234} for random and numpy")
random.seed(1234)
np.random.seed(1234)
if accelerator is not None:
accelerator.wait_for_everyone()

with htrack_block("Evaluation"):
hlog(f"Evaluate on {len(task_names_list)} tasks.")
evaluation_tracker = evaluate(
lm=model,
requests_dict=requests,
docs=docs,
task_dict=task_dict,
override_bs=args.override_batch_size,
evaluation_tracker=evaluation_tracker,
)

if accelerator.is_main_process if accelerator is not None else nullcontext():
with htrack_block("Compiling and saving results"):
evaluation_tracker.general_config_logger.log_end_time()
evaluation_tracker.metrics_logger.aggregate(task_dict=task_dict, bootstrap_iters=1000)
evaluation_tracker.details_logger.aggregate()

if args.output_dir:
evaluation_tracker.save()

final_dict = evaluation_tracker.generate_final_dict()

with htrack_block("Cleaninp up"):
for weights in ["delta", "adapter"]:
try:
tmp_weights_dir = f"{evaluation_tracker.general_config_logger.model_name}-{weights}-applied"
hlog(f"Removing {tmp_weights_dir}")
shutil.rmtree(tmp_weights_dir)
except OSError:
pass

print(make_results_table(final_dict))

model.cleanup()
return final_dict
pipeline.evaluate()

results = pipeline.show_results()

pipeline.save_and_push_results()

return results
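
With this refactor, `main_accelerate.py` reduces to building an `EvaluationTracker`, a `PipelineParameters`, and a `Pipeline`, then calling `evaluate()`, `show_results()`, and `save_and_push_results()`. As a rough sketch of the simpler programmatic interface this enables (not part of the diff: the task string, output directory, and placeholder parameter values are assumptions, and `model_config` is taken as a function argument instead of being built via `create_model_config(args=args, accelerator=accelerator)` as the CLI path does):

```python
# Minimal sketch of driving the new interface from Python instead of the CLI.
# Placeholder values are marked below; model_config is a parameter so no
# model-configuration API has to be guessed here.
import os

from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters


def run_eval(model_config, tasks: str = "leaderboard|arc:challenge|0|0"):  # task string is a placeholder
    env_config = EnvConfig(
        token=os.getenv("HF_TOKEN"),
        cache_dir=os.getenv("HF_HOME", "/scratch"),
    )

    evaluation_tracker = EvaluationTracker(
        output_dir="./evals",  # placeholder; mirrors args.output_dir in the CLI path
        token=os.getenv("HF_TOKEN"),
        # remaining tracker arguments are assumed to keep their defaults
    )

    pipeline_params = PipelineParameters(
        launcher_type=ParallelismManager.ACCELERATE,
        env_config=env_config,
        job_id=0,  # placeholder; the CLI passes args.job_id
        dataset_loading_processes=1,
        custom_tasks_directory=None,
        override_batch_size=None,
        num_fewshot_seeds=1,
        max_samples=None,
        use_chat_template=False,
        system_prompt=None,
    )

    pipeline = Pipeline(
        tasks=tasks,
        pipeline_parameters=pipeline_params,
        evaluation_tracker=evaluation_tracker,
        model_config=model_config,
    )

    pipeline.evaluate()
    results = pipeline.show_results()
    pipeline.save_and_push_results()
    return results
```

Taking `model_config` as an argument keeps the sketch independent of any particular model-config class; any object accepted by `Pipeline(model_config=...)` would do.
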
197 changes: 49 additions & 148 deletions src/lighteval/main_nanotron.py
@@ -22,179 +22,80 @@

# flake8: noqa: C901
import os
import random
from typing import Optional, Type
from typing import Optional

import numpy as np

from lighteval.evaluator import evaluate, make_results_table
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.logging.hierarchical_logger import hlog, htrack, htrack_block
from lighteval.models.model_config import EnvConfig
from lighteval.models.model_loader import ModelInfo
from lighteval.models.nanotron_model import NanotronLightevalModel
from lighteval.tasks.lighteval_task import LightevalTask, create_requests_from_tasks
from lighteval.tasks.registry import Registry, get_custom_tasks, taskinfo_selector
from lighteval.utils import NO_NANOTRON_ERROR_MSG, is_nanotron_available
from lighteval.utils_parallelism import test_all_gather
from lighteval.logging.hierarchical_logger import htrack, htrack_block
from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters
from lighteval.utils import NO_NANOTRON_ERROR_MSG, EnvConfig, is_nanotron_available


if not is_nanotron_available():
raise ImportError(NO_NANOTRON_ERROR_MSG)

from nanotron import distributed as dist
from nanotron.config import Config, LightEvalConfig, get_config_from_file
from nanotron.logging import get_logger
from nanotron.parallel.context import ParallelContext
from nanotron.utils import local_ranks_zero_first


logger = get_logger(__name__)

SEED = 1234
TOKEN = os.getenv("HF_TOKEN")
CACHE_DIR = os.getenv("HF_HOME", "/scratch")


@htrack()
def main(
checkpoint_config_path: str,
lighteval_config_path: Optional[str] = None,
cache_dir: Optional[str] = None,
config_cls: Type = Config,
model_config_cls: Optional[Type] = None,
model_cls: Optional[Type] = None,
cache_dir: Optional[str] = os.getenv("HF_HOME", "/scratch"),
):
if cache_dir is None:
cache_dir = CACHE_DIR

env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir)
env_config = EnvConfig(token=os.getenv("HF_TOKEN"), cache_dir=cache_dir)

dist.initialize_torch_distributed()

with htrack_block("get config"):
with htrack_block("Load nanotron config"):
# Create nanotron config
if not checkpoint_config_path.endswith(".yaml"):
raise ValueError("The checkpoint path should point to a YAML file")

nanotron_config: config_cls = get_config_from_file(
model_config = get_config_from_file(
checkpoint_config_path,
config_class=config_cls,
model_config_class=model_config_cls,
config_class=Config,
model_config_class=None,
skip_unused_config_keys=True,
skip_null_keys=True,
)

if lighteval_config_path:
lighteval_config: config_cls = get_config_from_file(lighteval_config_path, config_class=LightEvalConfig)
nanotron_config.lighteval = lighteval_config
lighteval_config = get_config_from_file(lighteval_config_path, config_class=LightEvalConfig)
model_config.lighteval = lighteval_config
else:
lighteval_config = nanotron_config.lighteval

parallel_context = ParallelContext(
tensor_parallel_size=lighteval_config.parallelism.tp,
pipeline_parallel_size=lighteval_config.parallelism.pp,
data_parallel_size=lighteval_config.parallelism.dp,
)

evaluation_tracker = EvaluationTracker(
token=TOKEN,
output_dir=lighteval_config.logging.local_output_path,
hub_results_org=lighteval_config.logging.hub_repo_tensorboard,
tensorboard_metric_prefix=lighteval_config.logging.tensorboard_metric_prefix,
nanotron_run_info=nanotron_config.general,
)
evaluation_tracker.general_config_logger.log_args_info(
num_fewshot_seeds=1,
override_batch_size=None,
max_samples=lighteval_config.tasks.max_samples,
job_id=os.environ.get("SLURM_JOB_ID", None),
config=nanotron_config,
)

with htrack_block("Test all gather"):
test_all_gather(parallel_context=parallel_context)

with htrack_block("Model loading"):
# We need to load the model in the main process first to avoid downloading the model multiple times
model = NanotronLightevalModel(
checkpoint_path=os.path.dirname(checkpoint_config_path),
model_args=nanotron_config.model,
tokenizer=nanotron_config.tokenizer,
parallel_context=parallel_context,
parallel_config=lighteval_config.parallelism,
lighteval_config=lighteval_config,
batch_size=lighteval_config.batch_size,
debug_one_layer_model=False,
model_class=model_cls,
env_config=env_config,
)
model_info = ModelInfo(model_name=f"{nanotron_config.general.run}/{nanotron_config.general.step}")
evaluation_tracker.general_config_logger.log_model_info(model_info)

with htrack_block("Tasks loading"):
with local_ranks_zero_first():
tasks_selection = lighteval_config.tasks.tasks
if lighteval_config.tasks.custom_tasks:
_, tasks_groups_dict = get_custom_tasks(lighteval_config.tasks.custom_tasks)
if tasks_groups_dict and lighteval_config.tasks.tasks in tasks_groups_dict:
tasks_selection = tasks_groups_dict[lighteval_config.tasks.tasks]

task_names_list, few_shots_dict = taskinfo_selector(tasks_selection)
task_dict = Registry(cache_dir=cache_dir).get_task_dict(
task_names_list,
custom_tasks=lighteval_config.tasks.custom_tasks,
)
# Loading all the dataset in a distributed manner
LightevalTask.load_datasets(task_dict.values(), lighteval_config.tasks.dataset_loading_processes)

evaluation_tracker.task_config_logger.log(task_dict)

hlog("Loading documents, and requests")
requests, docs = create_requests_from_tasks(
task_dict=task_dict,
fewshot_dict=few_shots_dict,
num_fewshot_seeds=lighteval_config.tasks.num_fewshot_seeds or 1,
lm=model,
max_samples=lighteval_config.tasks.max_samples,
evaluation_tracker=evaluation_tracker,
use_chat_template=False,
system_prompt=None,
)

with htrack_block("Setting seeds and waiting for all processes"):
hlog(f"setting seed to {SEED} for random and numpy")
random.seed(SEED)
np.random.seed(SEED)
dist.barrier()

with htrack_block("Evaluation"):
hlog(f"Evaluate on {len(task_names_list)} tasks.")
evaluation_tracker = evaluate(
lm=model,
requests_dict=requests,
docs=docs,
task_dict=task_dict,
override_bs=lighteval_config.batch_size,
evaluation_tracker=evaluation_tracker,
)

if dist.get_rank(parallel_context.world_pg) == 0:
with htrack_block("Compiling and saving results"):
evaluation_tracker.general_config_logger.log_end_time()
evaluation_tracker.metrics_logger.aggregate(task_dict=task_dict, bootstrap_iters=1000)
evaluation_tracker.details_logger.aggregate()

if lighteval_config.logging.local_output_path:
evaluation_tracker.save(
output_dir=lighteval_config.logging.local_output_path,
push_results_to_hub=lighteval_config.logging.push_results_to_hub,
push_details_to_hub=lighteval_config.logging.push_details_to_hub,
public=False,
push_results_to_tensorboard=lighteval_config.logging.push_results_to_tensorboard,
)

final_dict = evaluation_tracker.generate_final_dict()

hlog(make_results_table(final_dict))

return final_dict
lighteval_config = model_config.lighteval

evaluation_tracker = EvaluationTracker(
token=os.getenv("HF_TOKEN"),
output_dir=lighteval_config.logging.local_output_path,
hub_results_org=lighteval_config.logging.hub_repo_tensorboard,
tensorboard_metric_prefix=lighteval_config.logging.tensorboard_metric_prefix,
nanotron_run_info=model_config.general,
)

pipeline_parameters = PipelineParameters(
launcher_type=ParallelismManager.NANOTRON,
env_config=env_config,
job_id=os.environ.get("SLURM_JOB_ID", None),
nanotron_checkpoint_path=checkpoint_config_path,
dataset_loading_processes=lighteval_config.tasks.dataset_loading_processes,
custom_tasks_directory=lighteval_config.tasks.custom_tasks,
override_batch_size=None,
num_fewshot_seeds=1,
max_samples=lighteval_config.tasks.max_samples,
use_chat_template=False,
system_prompt=None,
)

pipeline = Pipeline(
tasks=lighteval_config.tasks.tasks,
pipeline_parameters=pipeline_parameters,
evaluation_tracker=evaluation_tracker,
model_config=model_config,
)

pipeline.evaluate()

pipeline.show_results()

pipeline.save_and_push_results()
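
The nanotron entry point follows the same `Pipeline` pattern, driven entirely by the checkpoint's YAML config plus an optional `LightEvalConfig`. A hedged usage sketch of the new `main()` signature (the file paths are placeholders, and this assumes nanotron is installed since the module raises on import otherwise):

```python
from lighteval.main_nanotron import main

# Placeholder paths for illustration: main() loads the nanotron Config from the
# checkpoint's YAML and, optionally, a separate LightEvalConfig YAML, then builds
# and runs the Pipeline with launcher_type=ParallelismManager.NANOTRON.
main(
    checkpoint_config_path="path/to/checkpoint/config.yaml",
    lighteval_config_path="path/to/lighteval_config.yaml",
)
```
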