diff --git a/pyproject.toml b/pyproject.toml
index 9a4d3a3ce..2c3a76f5a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -82,6 +82,7 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
+litellm = ["litellm", "diskcache"]
 tgi = ["text-generation==0.6.0"]
 optimum = ["optimum==1.12.0"]
 quantization = ["bitsandbytes>=0.41.0", "auto-gptq>=0.4.2"]
diff --git a/src/lighteval/main_endpoint.py b/src/lighteval/main_endpoint.py
index 208cc8386..19eb2a0e9 100644
--- a/src/lighteval/main_endpoint.py
+++ b/src/lighteval/main_endpoint.py
@@ -369,3 +369,112 @@ def tgi(
     pipeline.save_and_push_results()
 
     return results
+
+
+@app.command(rich_help_panel="Evaluation Backends")
+def litellm(
+    # === general ===
+    model_name: Annotated[
+        str, Argument(help="The model name to evaluate (has to be available through the litellm API).")
+    ],
+    tasks: Annotated[str, Argument(help="Comma-separated list of tasks to evaluate on.")],
+    # === Common parameters ===
+    use_chat_template: Annotated[
+        bool, Option(help="Use chat template for evaluation.", rich_help_panel=HELP_PANEL_NAME_4)
+    ] = False,
+    system_prompt: Annotated[
+        Optional[str], Option(help="Use system prompt for evaluation.", rich_help_panel=HELP_PANEL_NAME_4)
+    ] = None,
+    dataset_loading_processes: Annotated[
+        int, Option(help="Number of processes to use for dataset loading.", rich_help_panel=HELP_PANEL_NAME_1)
+    ] = 1,
+    custom_tasks: Annotated[
+        Optional[str], Option(help="Path to custom tasks directory.", rich_help_panel=HELP_PANEL_NAME_1)
+    ] = None,
+    cache_dir: Annotated[
+        str, Option(help="Cache directory for datasets and models.", rich_help_panel=HELP_PANEL_NAME_1)
+    ] = CACHE_DIR,
+    num_fewshot_seeds: Annotated[
+        int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
+    ] = 1,
+    # === saving ===
+    output_dir: Annotated[
+        str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
+    ] = "results",
+    push_to_hub: Annotated[
+        bool, Option(help="Push results to the huggingface hub.", rich_help_panel=HELP_PANEL_NAME_2)
+    ] = False,
+    push_to_tensorboard: Annotated[
+        bool, Option(help="Push results to tensorboard.", rich_help_panel=HELP_PANEL_NAME_2)
+    ] = False,
+    public_run: Annotated[
+        bool, Option(help="Push results and details to a public repo.", rich_help_panel=HELP_PANEL_NAME_2)
+    ] = False,
+    results_org: Annotated[
+        Optional[str], Option(help="Organization to push results to.", rich_help_panel=HELP_PANEL_NAME_2)
+    ] = None,
+    save_details: Annotated[
+        bool, Option(help="Save detailed, sample per sample, results.", rich_help_panel=HELP_PANEL_NAME_2)
+    ] = False,
+    # === debug ===
+    max_samples: Annotated[
+        Optional[int], Option(help="Maximum number of samples to evaluate on.", rich_help_panel=HELP_PANEL_NAME_3)
+    ] = None,
+    override_batch_size: Annotated[
+        int, Option(help="Override batch size for evaluation.", rich_help_panel=HELP_PANEL_NAME_3)
+    ] = -1,
+    job_id: Annotated[
+        int, Option(help="Optional job id for future reference.", rich_help_panel=HELP_PANEL_NAME_3)
+    ] = 0,
+):
+    """
+    Evaluate models using LiteLLM as backend.
+ """ + + from lighteval.logging.evaluation_tracker import EvaluationTracker + from lighteval.models.litellm_model import LiteLLMModelConfig + from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters + + env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir) + evaluation_tracker = EvaluationTracker( + output_dir=output_dir, + save_details=save_details, + push_to_hub=push_to_hub, + push_to_tensorboard=push_to_tensorboard, + public=public_run, + hub_results_org=results_org, + ) + + # TODO (nathan): better handling of model_args + parallelism_manager = ParallelismManager.NONE + + model_config = LiteLLMModelConfig(model=model_name) + + pipeline_params = PipelineParameters( + launcher_type=parallelism_manager, + env_config=env_config, + job_id=job_id, + dataset_loading_processes=dataset_loading_processes, + custom_tasks_directory=custom_tasks, + override_batch_size=override_batch_size, + num_fewshot_seeds=num_fewshot_seeds, + max_samples=max_samples, + use_chat_template=use_chat_template, + system_prompt=system_prompt, + ) + pipeline = Pipeline( + tasks=tasks, + pipeline_parameters=pipeline_params, + evaluation_tracker=evaluation_tracker, + model_config=model_config, + ) + + pipeline.evaluate() + + pipeline.show_results() + + results = pipeline.get_results() + + pipeline.save_and_push_results() + + return results diff --git a/src/lighteval/models/endpoints/openai_model.py b/src/lighteval/models/endpoints/openai_model.py index b2ca25285..8733474d0 100644 --- a/src/lighteval/models/endpoints/openai_model.py +++ b/src/lighteval/models/endpoints/openai_model.py @@ -145,7 +145,6 @@ def greedy_until( Args: requests (list[Request]): list of requests containing the context and ending conditions. - disable_tqdm (bool, optional): Whether to disable the progress bar. Defaults to False. override_bs (int, optional): Override the batch size for generation. Defaults to None. Returns: diff --git a/src/lighteval/models/litellm_model.py b/src/lighteval/models/litellm_model.py new file mode 100644 index 000000000..66f25ec16 --- /dev/null +++ b/src/lighteval/models/litellm_model.py @@ -0,0 +1,268 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+
+import logging
+import os
+import time
+from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass
+from typing import Optional
+
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+from lighteval.data import GenerativeTaskDataset
+from lighteval.models.abstract_model import LightevalModel
+from lighteval.models.endpoints.endpoint_model import ModelInfo
+from lighteval.models.model_output import (
+    GenerativeResponse,
+    LoglikelihoodResponse,
+    LoglikelihoodSingleTokenResponse,
+)
+from lighteval.tasks.requests import (
+    GreedyUntilRequest,
+    LoglikelihoodRequest,
+    LoglikelihoodRollingRequest,
+    LoglikelihoodSingleTokenRequest,
+)
+from lighteval.utils.imports import is_litellm_available
+
+
+logger = logging.getLogger(__name__)
+
+if is_litellm_available():
+    import litellm
+    from litellm.caching.caching import Cache
+
+    logging.getLogger("LiteLLM").setLevel(logging.WARNING)
+    logging.getLogger("LiteLLM").handlers.clear()
+
+    litellm.cache = Cache(type="disk")
+
+
+@dataclass
+class LiteLLMModelConfig:
+    model: str
+
+
+class LiteLLMClient(LightevalModel):
+    _DEFAULT_MAX_LENGTH: int = 4096
+
+    def __init__(self, config, env_config) -> None:
+        """
+        IMPORTANT: Your API keys should be set in the environment variables.
+        If a base_url is not set, it will default to the public API.
+        """
+        self.model_info = ModelInfo(
+            model_name=config.model,
+            model_sha="",
+            model_dtype=None,
+            model_size="",
+        )
+        self.provider = config.model.split("/")[0]
+        self.base_url = os.getenv(f"{self.provider.upper()}_BASE_URL", None)
+        self.API_MAX_RETRY = 5
+        self.API_RETRY_SLEEP = 3
+        self.API_RETRY_MULTIPLIER = 2
+        self.CONCURENT_CALLS = 20  # 100 leads to hitting Anthropic rate limits
+        self.TEMPERATURE = 0.7
+        self.TOP_P = 0.95
+        self.model = config.model
+        self._tokenizer = AutoTokenizer.from_pretrained("gpt2")  # Use a dummy tokenizer for compatibility
+        self.pairwise_tokenization = False
+        litellm.drop_params = True
+        litellm.verbose = True
+
+    def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_sequence, system_prompt):
+        for attempt in range(self.API_MAX_RETRY):
+            try:
+                if self.provider == "anthropic":
+                    # Filter out whitespace-only stop sequences
+                    if stop_sequence:
+                        stop_sequence = [s for s in stop_sequence if s.strip()]
+                        if not stop_sequence:  # If empty after filtering
+                            stop_sequence = ["\n"]
+
+                # Handle max_new_tokens
+                completion_tokens = None
+                if max_new_tokens and max_new_tokens > 0:
+                    completion_tokens = max_new_tokens
+                    if "o1" in self.model:
+                        # We need to allow more tokens to include reasoning tokens
+                        completion_tokens = min(max_new_tokens * 10, 32000)
+
+                response = litellm.completion(
+                    model=self.model,
+                    messages=prompt,
+                    max_completion_tokens=completion_tokens,
+                    logprobs=return_logits if self.provider == "openai" else None,
+                    stop=stop_sequence,
+                    base_url=self.base_url,
+                    n=num_samples,
+                    temperature=self.TEMPERATURE,
+                    top_p=self.TOP_P,
+                    caching=True,
+                )
+                return response
+            except Exception as e:
+                wait_time = min(64, self.API_RETRY_SLEEP * (2**attempt))  # Exponential backoff with max 64s
+                logger.warning(
+                    f"Error in API call: {e}, waiting {wait_time} seconds before retry {attempt + 1}/{self.API_MAX_RETRY}"
+                )
+                time.sleep(wait_time)
+
+        logger.error(f"API call failed after {self.API_MAX_RETRY} attempts, skipping entry.")
+
+    def __call_api_parallel(
+        self,
+        prompts,
+        return_logits: bool | list[bool],
+        max_new_tokens: int | list[int],
+        num_samples: int | list[int],
+        stop_sequence: list[str] | None = None,
+        system_prompt: str | list[str] = None,
+    ):
+        results = []
+
+        return_logitss = [return_logits for _ in prompts] if not isinstance(return_logits, list) else return_logits
+        max_new_tokenss = [max_new_tokens for _ in prompts] if not isinstance(max_new_tokens, list) else max_new_tokens
+        num_sampless = [num_samples for _ in prompts] if not isinstance(num_samples, list) else num_samples
+        stop_sequencess = [stop_sequence for _ in prompts]
+        system_prompts = [system_prompt for _ in prompts] if not isinstance(system_prompt, list) else system_prompt
+        assert (
+            len(prompts)
+            == len(return_logitss)
+            == len(max_new_tokenss)
+            == len(num_sampless)
+            == len(stop_sequencess)
+            == len(system_prompts)
+        ), f"Length of prompts, return_logitss, max_new_tokenss, num_sampless, stop_sequences, system_prompts should be the same but are {len(prompts)}, {len(return_logitss)}, {len(max_new_tokenss)}, {len(num_sampless)}, {len(stop_sequencess)}, {len(system_prompts)}"
+
+        with ThreadPoolExecutor(self.CONCURENT_CALLS) as executor:
+            for entry in tqdm(
+                executor.map(
+                    self.__call_api,
+                    prompts,
+                    return_logitss,
+                    max_new_tokenss,
+                    num_sampless,
+                    stop_sequencess,
+                    system_prompts,
+                ),
+                total=len(prompts),
+            ):
+                results.append(entry)
+
+        if None in results:
+            raise ValueError("Some entries are missing because the API call failed after all retries, please inspect the logs and retry.")
+
+        return results
+
+    def greedy_until(
+        self,
+        requests: list[GreedyUntilRequest],
+        override_bs: Optional[int] = None,
+    ) -> list[GenerativeResponse]:
+        """
+        Generates responses using a greedy decoding strategy until certain ending conditions are met.
+
+        Args:
+            requests (list[Request]): list of requests containing the context and ending conditions.
+            override_bs (int, optional): Override the batch size for generation. Defaults to None.
+
+        Returns:
+            list[GenerativeResponse]: list of generated responses.
+        """
+        for request in requests:
+            request.tokenized_context = self.tok_encode(request.context)
+
+        dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
+        results = []
+
+        for _ in tqdm(
+            dataset.splits_start_end_iterator(),
+            total=dataset.num_dataset_splits,
+            desc="Splits",
+            position=0,
+            disable=False,  # self.disable_tqdm,
+        ):
+            contexts = [c.context for c in dataset]
+            max_new_tokens = dataset[0].generation_size  # could be None
+            return_logits = dataset[0].use_logits
+            num_samples = dataset[0].num_samples
+            stop_sequence = requests[0].stop_sequence
+            system_prompt = requests[0].system_prompt
+
+            responses = self.__call_api_parallel(
+                contexts, return_logits, max_new_tokens, num_samples, stop_sequence, system_prompt
+            )
+
+            for response in responses:
+                result: list[str] = [choice.message.content for choice in response.choices]
+
+                cur_response = GenerativeResponse(
+                    result=result,
+                    logits=None,
+                    generated_tokens=[],
+                    input_tokens=[],
+                )
+                results.append(cur_response)
+
+        return dataset.get_original_order(results)
+
+    @property
+    def tokenizer(self):
+        return self._tokenizer
+
+    def tok_encode(self, text: str):
+        return text
+
+    @property
+    def add_special_tokens(self) -> bool:
+        return False
+
+    @property
+    def max_length(self) -> int:
+        """Return the maximum sequence length of the model."""
+        return 4096
+
+    def loglikelihood(
+        self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None
+    ) -> list[LoglikelihoodResponse]:
+        """Tokenize the context and continuation and compute the log likelihood of those
+        tokenized sequences.
+ """ + raise NotImplementedError + + def loglikelihood_rolling( + self, requests: list[LoglikelihoodRollingRequest], override_bs: Optional[int] = None + ) -> list[LoglikelihoodResponse]: + """This function is used to compute the log likelihood of the context for perplexity metrics.""" + raise NotImplementedError + + def loglikelihood_single_token( + self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: Optional[int] = None + ) -> list[LoglikelihoodSingleTokenResponse]: + """Tokenize the context and continuation and compute the log likelihood of those + tokenized sequences. + """ + raise NotImplementedError diff --git a/src/lighteval/models/model_loader.py b/src/lighteval/models/model_loader.py index b0817be4a..30aec21c0 100644 --- a/src/lighteval/models/model_loader.py +++ b/src/lighteval/models/model_loader.py @@ -31,13 +31,16 @@ ) from lighteval.models.endpoints.openai_model import OpenAIClient, OpenAIModelConfig from lighteval.models.endpoints.tgi_model import ModelClient, TGIModelConfig +from lighteval.models.litellm_model import LiteLLMClient, LiteLLMModelConfig from lighteval.models.transformers.adapter_model import AdapterModel, AdapterModelConfig from lighteval.models.transformers.base_model import BaseModel, BaseModelConfig from lighteval.models.transformers.delta_model import DeltaModel, DeltaModelConfig from lighteval.models.vllm.vllm_model import VLLMModel, VLLMModelConfig from lighteval.utils.imports import ( + NO_LITELLM_ERROR_MSG, NO_TGI_ERROR_MSG, NO_VLLM_ERROR_MSG, + is_litellm_available, is_openai_available, is_tgi_available, is_vllm_available, @@ -58,6 +61,7 @@ def load_model( # noqa: C901 DummyModelConfig, VLLMModelConfig, OpenAIModelConfig, + LiteLLMModelConfig, ], env_config: EnvConfig, ) -> Union[BaseModel, AdapterModel, DeltaModel, ModelClient, DummyModel]: @@ -95,6 +99,9 @@ def load_model( # noqa: C901 if isinstance(config, OpenAIModelConfig): return load_openai_model(config=config, env_config=env_config) + if isinstance(config, LiteLLMModelConfig): + return load_litellm_model(config=config, env_config=env_config) + def load_model_with_tgi(config: TGIModelConfig): if not is_tgi_available(): @@ -107,6 +114,14 @@ def load_model_with_tgi(config: TGIModelConfig): return model +def load_litellm_model(config: LiteLLMModelConfig, env_config: EnvConfig): + if not is_litellm_available(): + raise ImportError(NO_LITELLM_ERROR_MSG) + + model = LiteLLMClient(config, env_config) + return model + + def load_openai_model(config: OpenAIModelConfig, env_config: EnvConfig): if not is_openai_available(): raise ImportError() diff --git a/src/lighteval/models/vllm/vllm_model.py b/src/lighteval/models/vllm/vllm_model.py index 2d413807d..206fd3a55 100644 --- a/src/lighteval/models/vllm/vllm_model.py +++ b/src/lighteval/models/vllm/vllm_model.py @@ -54,6 +54,12 @@ from vllm import LLM, SamplingParams from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel from vllm.transformers_utils.tokenizer import get_tokenizer + + logging.getLogger("vllm").propagate = True + logging.getLogger("vllm").handlers.clear() + + logging.getLogger("ray").propagate = True + logging.getLogger("ray").handlers.clear() else: LLM = None SamplingParams = None diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py index ea01f81e4..be14d7445 100644 --- a/src/lighteval/tasks/lighteval_task.py +++ b/src/lighteval/tasks/lighteval_task.py @@ -339,7 +339,7 @@ def eval_docs(self) -> list[Doc]: return self._docs def 
-        self, formatted_doc: Doc, context: str, document_id_seed: str, current_task_name: str
+        self, formatted_doc: Doc, context: str, document_id_seed: str, current_task_name: str, system_prompt: str
     ) -> Dict[RequestType, List[Request]]:
         """
         Constructs a list of requests from the task based on the given parameters.
@@ -349,7 +349,7 @@ def construct_requests(
             ctx (str): Context, which is the few shot examples + the query.
             document_id_seed (str): Index of the document in the task appended with the seed used for the few shot sampling.
             current_task_name (str): Name of the current task.
-
+            system_prompt (str): System prompt to use for the request.
         Returns:
             dict[RequestType, List[Request]]: List of requests.
         """
@@ -365,6 +365,7 @@ def construct_requests(
                     context=context,
                     choice=gold,
                     metric_categories=[MetricCategory.TARGET_PERPLEXITY],
+                    system_prompt=system_prompt,
                 )
                 for i, gold in enumerate(golds)
             ]
@@ -376,6 +377,7 @@ def construct_requests(
                     request_index=0,
                     context=context,
                     metric_categories=[MetricCategory.PERPLEXITY],
+                    system_prompt=system_prompt,
                 )
             ]
         if self.has_metric_category[MetricCategory.GENERATIVE_SAMPLING]:
@@ -395,6 +397,7 @@ def construct_requests(
                     do_sample=True,
                    use_logits=False,
                     metric_categories=[MetricCategory.GENERATIVE_SAMPLING],
+                    system_prompt=system_prompt,
                 )
             ]
         if (
@@ -421,6 +424,7 @@ def construct_requests(
                         ]
                         if self.has_metric_category[c]
                     ],
+                    system_prompt=system_prompt,
                 )
             ]
         if (
@@ -439,6 +443,7 @@ def construct_requests(
                         for c in [MetricCategory.MULTICHOICE, MetricCategory.MULTICHOICE_PMI]
                         if self.has_metric_category[c]
                     ],
+                    system_prompt=system_prompt,
                 )
                 for i, choice in enumerate(formatted_doc.choices)
             ]
@@ -455,6 +460,7 @@ def construct_requests(
                     context=formatted_doc.unconditioned_query,
                     choice=choice,
                     metric_categories=[MetricCategory.MULTICHOICE_PMI],
+                    system_prompt=system_prompt,
                 )
                 for i, choice in enumerate(formatted_doc.choices)
             ]
@@ -467,6 +473,7 @@ def construct_requests(
                     context=context,
                     choices=formatted_doc.choices,
                     metric_categories=[MetricCategory.MULTICHOICE_ONE_TOKEN],
+                    system_prompt=system_prompt,
                 )
             ]
         if self.has_metric_category[MetricCategory.LLM_AS_JUDGE_MULTI_TURN]:
@@ -479,6 +486,7 @@ def construct_requests(
                     stop_sequence=self.stop_sequence,
                     generation_size=self.generation_size,
                     metric_categories=[MetricCategory.LLM_AS_JUDGE_MULTI_TURN],
+                    system_prompt=system_prompt,
                 )
             ]
         if self.has_metric_category[MetricCategory.LLM_AS_JUDGE]:
@@ -493,6 +501,7 @@ def construct_requests(
                     generation_grammar=self.generation_grammar,
                     num_samples=1,
                     metric_categories=[MetricCategory.LLM_AS_JUDGE],
+                    system_prompt=system_prompt,
                 )
             ]
 
@@ -652,7 +661,9 @@ def create_requests_from_tasks(  # noqa: C901
             # Constructing the requests
             cur_task_name = f"{task_name}|{num_fewshot}"
             docs[SampleUid(cur_task_name, doc_id_seed)] = doc
-            req_type_reqs_dict = task.construct_requests(doc, doc.ctx, doc_id_seed, cur_task_name)
+            req_type_reqs_dict = task.construct_requests(
+                doc, doc.ctx, doc_id_seed, cur_task_name, system_prompt
+            )
 
             for req_type, reqs in req_type_reqs_dict.items():
                 requests[req_type].extend(reqs)
diff --git a/src/lighteval/tasks/prompt_manager.py b/src/lighteval/tasks/prompt_manager.py
index cb9f94d04..c8c842223 100644
--- a/src/lighteval/tasks/prompt_manager.py
+++ b/src/lighteval/tasks/prompt_manager.py
@@ -29,6 +29,7 @@
 from typing import TYPE_CHECKING, Optional, Tuple, Union
 
 from lighteval.models.abstract_model import LightevalModel
+from lighteval.models.litellm_model import LiteLLMClient
 from lighteval.tasks.requests import Doc
 from lighteval.utils.utils import as_list
 
@@ -205,7 +206,10 @@ def _single_turn_context(
             system_prompt=system_prompt,
             use_chat_template=use_chat_template,
         )
-        toks = self.model.tok_encode(output)
+        if not use_chat_template:
+            toks = self.model.tok_encode(output)
+        else:
+            toks = "".join([msg["content"] for msg in output])
 
         # If we need to truncate few-shots to fit in the context
         if truncate_few_shots and self.model.max_length is not None and self.model.tokenizer is not None:
@@ -223,9 +227,17 @@ def _single_turn_context(
                 system_prompt=system_prompt,
                 use_chat_template=use_chat_template,
             )
-            toks = self.model.tokenizer(output)["input_ids"]
+            if not use_chat_template:
+                toks = self.model.tok_encode(output)
+            else:
+                toks = "".join([msg["content"] for msg in output])
+
+        if isinstance(self.model, LiteLLMClient):
+            return output, num_effective_fewshots
 
-        return output, num_effective_fewshots
+        return self.model.tokenizer.apply_chat_template(
+            output, tokenize=False, add_generation_prompt=True
+        ), num_effective_fewshots
 
     def get_examples(
         self,
@@ -256,7 +268,7 @@ def get_examples(
                 examples.insert(0, {"role": "system", "content": system_prompt + instruction})
             else:  # Else we add the instruction to the first example
                 examples[0]["content"] = instruction + examples[0]["content"]
-            return self.model.tokenizer.apply_chat_template(examples, tokenize=False, add_generation_prompt=True)
+            return examples
         else:
             if system_prompt is not None:
                 output = system_prompt + instruction + "\n\n".join(examples)
diff --git a/src/lighteval/tasks/requests.py b/src/lighteval/tasks/requests.py
index cd75ad402..e184f807f 100644
--- a/src/lighteval/tasks/requests.py
+++ b/src/lighteval/tasks/requests.py
@@ -52,6 +52,7 @@ class Request:
         request_index (int): The index of the request.
         context (str): The context for the request.
         metric_categories (list[MetricCategory]): All the metric categories which concern this request
+        system_prompt (str): System prompt to use for the request.
     """
 
     task_name: str
@@ -59,6 +60,7 @@ class Request:
     request_index: int
     context: str
     metric_categories: list["MetricCategory"]  # noqa F821
+    system_prompt: Optional[str]
 
 
 @dataclass
diff --git a/src/lighteval/utils/imports.py b/src/lighteval/utils/imports.py
index d36c1acb4..c8fb2ce73 100644
--- a/src/lighteval/utils/imports.py
+++ b/src/lighteval/utils/imports.py
@@ -77,6 +77,13 @@ def is_openai_available() -> bool:
 NO_OPENAI_ERROR_MSG = "You are trying to use an Open AI LLM as a judge, for which you need `openai`, which is not available in your environment. Please install it using pip."
 
 
+def is_litellm_available() -> bool:
+    return importlib.util.find_spec("litellm") is not None
+
+
+NO_LITELLM_ERROR_MSG = "You are trying to use a LiteLLM model, for which you need `litellm`, which is not available in your environment. Please install it using pip."
+
+
 def is_vllm_available() -> bool:
     return importlib.util.find_spec("vllm") is not None and importlib.util.find_spec("ray") is not None
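
Note on usage: the snippet below is a minimal sketch of driving the new LiteLLM backend programmatically, mirroring the body of the `litellm` CLI command added in `src/lighteval/main_endpoint.py` (from the CLI it would be invoked roughly as `lighteval endpoint litellm "<provider>/<model>" "<tasks>"`, assuming the command group is mounted like the other endpoint backends). The task string, model name and output directory are placeholder examples, the constructors are shown with only a few arguments on the assumption that the remaining `EvaluationTracker`/`PipelineParameters`/`EnvConfig` parameters keep their defaults, and API keys are expected in the environment (e.g. `OPENAI_API_KEY`), as noted in `LiteLLMClient.__init__`.

# Sketch only: placeholder task/model/output values, remaining kwargs assumed to default.
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.models.litellm_model import LiteLLMModelConfig
from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters

evaluation_tracker = EvaluationTracker(output_dir="./results", save_details=False)
pipeline_params = PipelineParameters(
    launcher_type=ParallelismManager.NONE,  # LiteLLM requests are parallelized with threads, not launchers
    env_config=EnvConfig(cache_dir="~/.cache/huggingface"),
    use_chat_template=True,
    max_samples=10,  # placeholder: keep the smoke test small
)
pipeline = Pipeline(
    tasks="leaderboard|truthfulqa:mc|0|0",  # placeholder task string
    pipeline_parameters=pipeline_params,
    evaluation_tracker=evaluation_tracker,
    model_config=LiteLLMModelConfig(model="openai/gpt-4o-mini"),  # "<provider>/<model>", resolved by LiteLLM
)
pipeline.evaluate()
pipeline.show_results()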
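
Note on configuration: `LiteLLMClient.__init__` derives the provider from the part of the model name before the first "/" and looks up an optional `<PROVIDER>_BASE_URL` environment variable, falling back to the public API when it is unset, while `__call_api` retries failed calls with exponential backoff capped at 64 seconds. The snippet below just replays those expressions with example values to make the behaviour concrete (the backoff uses a hard-coded factor of 2, so `API_RETRY_MULTIPLIER` is currently unused).

# Sketch of the provider/base-URL resolution and the retry schedule used above; example values only.
import os

model = "openai/gpt-4o-mini"                          # "<provider>/<model>" as passed to the backend
provider = model.split("/")[0]                        # -> "openai"
base_url = os.getenv(f"{provider.upper()}_BASE_URL")  # e.g. OPENAI_BASE_URL; None -> public API

API_RETRY_SLEEP = 3                                   # as set in LiteLLMClient.__init__
for attempt in range(5):                              # API_MAX_RETRY = 5
    wait_time = min(64, API_RETRY_SLEEP * (2**attempt))
    print(f"retry {attempt + 1}: wait {wait_time}s")  # 3, 6, 12, 24, 48 seconds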