diff --git a/README.md b/README.md
index b5456cf..6187879 100644
--- a/README.md
+++ b/README.md
@@ -506,6 +506,7 @@ The configuration defines how the workflow operates, including model settings, i
 2. **LLM engine configuration (`engine`)**: The `engine` section configures various models for the LLM nodes.
    - LLM processing nodes: `agent`, `checklist_model`, `justification_model`, `summary_model`
      - `model_name`: The name of the LLM model used by the node.
+     - `prompt`: Manually sets the prompt for the specific model, overriding the node's default. The value can be either the prompt text itself or a path to a text file containing the prompt.
      - `service`: Specifies the service for running the LLM inference. (Set to `nvfoundation` if using NIM.)
      - `max_tokens`: Defines the maximum number of tokens that can be generated in one output step.
      - `temperature`: Controls randomness in the output. A lower temperature produces more deterministic results.
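To make the README bullet above concrete, here is a minimal sketch of a `checklist_model` entry using the new `prompt` key. Only the field names come from the README; the model name, token budget, file path, and the exact shape of the `service` block are illustrative assumptions, not values taken from this change.

```python
# Illustrative only -- a checklist_model entry with the new `prompt` key.
checklist_model = {
    "model_name": "example-llm",           # hypothetical model name
    "service": {"type": "nvfoundation"},   # per the README, use `nvfoundation` when running on NIM
    "max_tokens": 2000,
    "temperature": 0.0,
    # Either the prompt text itself ...
    "prompt": "Produce a numbered checklist of exploitability checks for the CVE.",
    # ... or a path to a text file whose contents are used as the prompt:
    # "prompt": "configs/prompts/checklist_prompt.txt",  # hypothetical path
}
```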
diff --git a/configs/schemas/config.schema.json b/configs/schemas/config.schema.json
index f3700f8..ce9a7eb 100644
--- a/configs/schemas/config.schema.json
+++ b/configs/schemas/config.schema.json
@@ -540,6 +540,18 @@
           "title": "Model Name",
           "type": "string"
         },
+        "prompt": {
+          "anyOf": [
+            {
+              "type": "string"
+            },
+            {
+              "type": "null"
+            }
+          ],
+          "default": null,
+          "title": "Prompt"
+        },
         "temperature": {
           "default": 0.0,
           "title": "Temperature",
@@ -628,6 +640,18 @@
           "title": "Model Name",
           "type": "string"
         },
+        "prompt": {
+          "anyOf": [
+            {
+              "type": "string"
+            },
+            {
+              "type": "null"
+            }
+          ],
+          "default": null,
+          "title": "Prompt"
+        },
         "customization_id": {
           "anyOf": [
             {
@@ -793,6 +817,18 @@
           "title": "Model Name",
           "type": "string"
         },
+        "prompt": {
+          "anyOf": [
+            {
+              "type": "string"
+            },
+            {
+              "type": "null"
+            }
+          ],
+          "default": null,
+          "title": "Prompt"
+        },
         "temperature": {
           "default": 0.0,
           "title": "Temperature",
diff --git a/requirements.yaml b/requirements.yaml
index a3035e8..2158dd1 100644
--- a/requirements.yaml
+++ b/requirements.yaml
@@ -29,6 +29,7 @@ dependencies:
   # - faiss-gpu=1.7 # Uncomment this when the package supports CUDA 12. See: https://github.com/conda-forge/faiss-split-feedstock/pull/72
   - faiss=1.7
   - gitpython=3.1
+  - httpx>=0.23,<0.28 # work-around for https://github.com/openai/openai-python/issues/1915
   - json5=0.9
   - morpheus-llm=24.10 # cuda-python=12.6.1 cannot be used because of -
diff --git a/src/cve/data_models/config.py b/src/cve/data_models/config.py
index 1be4c6b..9eceb4c 100644
--- a/src/cve/data_models/config.py
+++ b/src/cve/data_models/config.py
@@ -27,6 +27,7 @@ from pydantic import PositiveInt
 from pydantic import Tag
 from pydantic import field_validator
+from pydantic.functional_validators import AfterValidator
 
 from morpheus.utils.http_utils import HTTPMethod
 
@@ -72,6 +73,18 @@ def get_num_threads(cls, v: int) -> int:
 ########################### LLM Model Configs ###########################
 
 
+def check_prompt_file(prompt: str) -> str:
+    if prompt is not None and os.path.isfile(prompt):
+        with open(prompt, 'r') as prompt_file:
+            read_prompt = prompt_file.read()
+        return read_prompt
+    else:
+        return prompt
+
+
+PromptConfig = typing.Annotated[str, AfterValidator(check_prompt_file)]
+
+
 class NeMoLLMServiceConfig(TypedBaseModel[typing.Literal["nemo"]]):
 
     api_key: str | None = None
@@ -84,6 +97,7 @@ class NeMoLLMModelConfig(BaseModel):
     service: NeMoLLMServiceConfig
 
     model_name: str
+    prompt: PromptConfig | None = None
     customization_id: str | None = None
     temperature: typing.Annotated[float, Field(ge=0.0, le=1.0)] = 0.0
     top_k: NonNegativeInt = 0
@@ -110,6 +124,7 @@ class NVFoundationLLMModelConfig(TypedBaseModel[typing.Literal["nvfoundation"]])
     service: NVFoundationLLMServiceConfig
 
     model_name: str
+    prompt: PromptConfig | None = None
     temperature: float = 0.0
     top_p: float | None = None
     max_tokens: PositiveInt = 300
@@ -129,6 +144,7 @@ class OpenAIModelConfig(BaseModel):
     service: OpenAIServiceConfig
 
     model_name: str
+    prompt: PromptConfig | None = None
     temperature: float = 0.0
     top_p: float = 1.0
     seed: int | None = None
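As a quick reference for reviewers, here is a standalone sketch of the resolution behavior that the new `check_prompt_file` validator above implements: a configured value that points at an existing file is replaced by that file's contents, and any other string passes through unchanged. The helper name `resolve_prompt` and the example path are ours, used only for illustration.

```python
# Standalone sketch mirroring check_prompt_file() in src/cve/data_models/config.py.
import os


def resolve_prompt(prompt: str | None) -> str | None:
    # A value that is a path to an existing file is replaced by the file's contents.
    if prompt is not None and os.path.isfile(prompt):
        with open(prompt, 'r') as prompt_file:
            return prompt_file.read()
    # Anything else (including None) is returned as-is.
    return prompt


assert resolve_prompt("Summarize the CVE findings.") == "Summarize the CVE findings."
# resolve_prompt("prompts/summary.txt") would instead return that file's text,
# assuming the (hypothetical) path exists on disk.
```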
diff --git a/src/cve/nodes/cve_checklist_node.py b/src/cve/nodes/cve_checklist_node.py
index 306f364..0976a83 100644
--- a/src/cve/nodes/cve_checklist_node.py
+++ b/src/cve/nodes/cve_checklist_node.py
@@ -21,9 +21,8 @@ from morpheus_llm.llm import LLMNode
 from morpheus_llm.llm.nodes.llm_generate_node import LLMGenerateNode
 from morpheus_llm.llm.nodes.prompt_template_node import PromptTemplateNode
-from morpheus_llm.llm.services.llm_service import LLMService
+from morpheus_llm.llm.services.llm_service import LLMClient
 
-from ..data_models.config import LLMModelConfig
 from ..utils.prompting import MOD_FEW_SHOT
 from ..utils.prompting import additional_intel_prompting
 from ..utils.prompting import get_mod_examples
 
@@ -31,12 +30,7 @@
 
 logger = logging.getLogger(__name__)
 
-cve_prompt1 = (
-    MOD_FEW_SHOT.format(examples=get_mod_examples())
-    + additional_intel_prompting
-    + "\n\nIf a vulnerable function or method is mentioned in the CVE description, ensure the first checklist item verifies whether this function or method is being called from the code or used by the code."
-    + "\nThe vulnerable version of the vulnerable package is already verified to be installed within the container. Check only the other factors that affect exploitability, no need to verify version again."
-)
+DEFAULT_CHECKLIST_PROMPT = MOD_FEW_SHOT.format(examples=get_mod_examples())
 
 cve_prompt2 = """Parse the following numbered checklist into a python list in the format ["x", "y", "z"], a comma separated list surrounded by square braces: {{template}}"""
 
@@ -109,7 +103,9 @@ class CVEChecklistNode(LLMNode):
     It integrates various nodes that handle CVE lookup, prompting, generation, and parsing to produce an actionable
     checklist.
     """
 
-    def __init__(self, *, checklist_model_config: LLMModelConfig, enable_llm_list_parsing: bool = False):
+    def __init__(self, *, prompt: str,
+                 llm_client: LLMClient,
+                 enable_llm_list_parsing: bool = False):
         """
         Initialize the CVEChecklistNode with optional caching and a vulnerability endpoint retriever.
 
@@ -124,8 +120,19 @@ def __init__(self, *, checklist_model_config: LLMModelConfig, enable_llm_list_pa
         """
         super().__init__()
 
-        chat_service = LLMService.create(checklist_model_config.service.type,
-                                         **checklist_model_config.service.model_dump(exclude={"type"}, by_alias=True))
+        if not prompt:
+            prompt = DEFAULT_CHECKLIST_PROMPT
+
+        intel = (
+            additional_intel_prompting +
+            "\n\nIf a vulnerable function or method is mentioned in the CVE description, ensure the first checklist item verifies whether this function or method is being called from the code or used by the code."
+            "\nThe vulnerable version of the vulnerable package is already verified to be installed within the container. Check only the other factors that affect exploitability, no need to verify version again."
+        )
+
+        cve_prompt1 = (
+            prompt +
+            intel
+        )
 
         # Add a node to create a prompt for CVE checklist generation based on the CVE details obtained from the lookup
         # node
@@ -133,11 +140,7 @@ def __init__(self, *, checklist_model_config: LLMModelConfig, enable_llm_list_pa
                       inputs=[("*", "*")],
                       node=PromptTemplateNode(template=cve_prompt1, template_format="jinja"))
 
-        # Instantiate a chat service and configure a client for generating responses to the checklist prompt
-        llm_client_1 = chat_service.get_client(
-            **checklist_model_config.model_dump(exclude={"service", "type"}, by_alias=True)
-        )
-        gen_node_1 = LLMGenerateNode(llm_client=llm_client_1)
+        gen_node_1 = LLMGenerateNode(llm_client=llm_client)
         self.add_node("chat1", inputs=["/checklist_prompt"], node=gen_node_1)
 
         if enable_llm_list_parsing:
@@ -146,11 +149,8 @@ def __init__(self, *, checklist_model_config: LLMModelConfig, enable_llm_list_pa
                           inputs=["/chat1"],
                           node=PromptTemplateNode(template=cve_prompt2, template_format="jinja"))
 
-            # Configure a second client for generating a follow-up response based on the parsed checklist prompt
-            llm_client_2 = chat_service.get_client(
-                **checklist_model_config.model_dump(exclude={"service", "type"}, by_alias=True)
-            )
-            gen_node_2 = LLMGenerateNode(llm_client=llm_client_2)
+            # Configure a second node for generating a follow-up response based on the parsed checklist prompt
+            gen_node_2 = LLMGenerateNode(llm_client=llm_client)
             self.add_node("chat2", inputs=[("/parse_checklist_prompt", "prompt")], node=gen_node_2)
 
         checklist_prompts = ["/chat2"] if enable_llm_list_parsing else ["/chat1"]
diff --git a/src/cve/nodes/cve_justification_node.py b/src/cve/nodes/cve_justification_node.py
index 16de3d2..88e79b7 100644
--- a/src/cve/nodes/cve_justification_node.py
+++ b/src/cve/nodes/cve_justification_node.py
@@ -35,7 +35,7 @@ class CVEJustifyNode(LLMNode):
     JUSTIFICATION_REASON_COL_NAME = "justification"
     AFFECTED_STATUS_COL_NAME = "affected_status"
 
-    JUSTIFICATION_PROMPT = dedent("""
+    DEFAULT_JUSTIFICATION_PROMPT = dedent("""
dedent(""" The summary provided below (delimited with XML tags), generated by the software agent named Vulnerability Analysis for Container Security, evaluates a specific CVE (Common Vulnerabilities and Exposures) against the backdrop of a software package or environment information. This information may include a Software Bill of @@ -97,7 +97,7 @@ class CVEJustifyNode(LLMNode): "vulnerable": "TRUE" } - def __init__(self, *, llm_client: LLMClient): + def __init__(self, *, prompt: str, llm_client: LLMClient): """ Initialize the CVEJustificationNode with a selected model. @@ -113,9 +113,11 @@ async def _strip_summaries(summaries: list[str]) -> list[str]: self.add_node('stripped_summaries', inputs=['summaries'], node=LLMLambdaNode(_strip_summaries)) + justification_prompt = prompt or self.DEFAULT_JUSTIFICATION_PROMPT + self.add_node("justification_prompt", inputs=[("/stripped_summaries", "summary")], - node=PromptTemplateNode(template=self.JUSTIFICATION_PROMPT, template_format='jinja')) + node=PromptTemplateNode(template=justification_prompt, template_format='jinja')) self.add_node("justify", inputs=["/justification_prompt"], node=LLMGenerateNode(llm_client=llm_client)) diff --git a/src/cve/nodes/cve_summary_node.py b/src/cve/nodes/cve_summary_node.py index 1dcedfd..4ba0d96 100644 --- a/src/cve/nodes/cve_summary_node.py +++ b/src/cve/nodes/cve_summary_node.py @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) -SUMMARY_PROMPT = """Summarize the exploitability investigation results of a Common Vulnerabilities and Exposures (CVE) \ +DEFAULT_SUMMARY_PROMPT = """Summarize the exploitability investigation results of a Common Vulnerabilities and Exposures (CVE) \ based on the provided Checklist and Findings. Write a concise paragraph focusing only on checklist items with \ definitive answers. Begin your response by clearly stating whether the CVE is exploitable. Disregard any ambiguous \ checklist items. @@ -39,7 +39,7 @@ class CVESummaryNode(LLMNode): A node to summarize the results of the checklist responses. """ - def __init__(self, *, llm_client: LLMClient): + def __init__(self, *, prompt: str, llm_client: LLMClient): """ Initialize the CVESummaryNode with a selected model. 
@@ -80,9 +80,11 @@ async def concat_checklist_responses(agent_q_and_a: list[list[dict]]) -> list[st
 
         self.add_node('results', inputs=['/checklist'], node=LLMLambdaNode(concat_checklist_responses))
 
+        summary_prompt = prompt or DEFAULT_SUMMARY_PROMPT
+
         self.add_node('summary_prompt',
                       inputs=[('/results', 'response')],
-                      node=PromptTemplateNode(template=SUMMARY_PROMPT, template_format='jinja'))
+                      node=PromptTemplateNode(template=summary_prompt, template_format='jinja'))
 
         # Generate a summary from the combined checklist responses
         self.add_node("summary",
diff --git a/src/cve/pipeline/engine.py b/src/cve/pipeline/engine.py
index 3e97b5e..b607600 100644
--- a/src/cve/pipeline/engine.py
+++ b/src/cve/pipeline/engine.py
@@ -49,7 +49,7 @@ def _build_dynamic_agent_fn(run_config: RunConfig, embeddings: Embeddings):
     chat_service = LLMService.create(run_config.engine.agent.model.service.type,
                                      **run_config.engine.agent.model.service.model_dump(exclude={"type"}, by_alias=True))
 
-    chat_client = chat_service.get_client(**run_config.engine.agent.model.model_dump(exclude={"service", "type"},
+    chat_client = chat_service.get_client(**run_config.engine.agent.model.model_dump(exclude={"service", "type", "prompt"},
                                                                                      by_alias=True))
 
     langchain_llm = LangchainLLMClientWrapper(client=chat_client)
@@ -146,11 +146,13 @@ def run_retrieval_qa_tool(retrieval_qa_tool: RetrievalQA, query: str) -> str | d
 
     # Define a system prompt that sets the context for the language model's task. This prompt positions the assistant
     # as a powerful entity capable of investigating CVE impacts on container images.
-    sys_prompt = (
+    DEFAULT_SYS_PROMPT = (
         "You are a very powerful assistant who helps investigate the impact of reported Common Vulnerabilities and "
         "Exposures (CVE) on container images. Information about the container image under investigation is "
         "stored in vector databases available to you via tools.")
 
+    sys_prompt = run_config.engine.agent.model.prompt or DEFAULT_SYS_PROMPT
+
     # Initialize an agent with the tools and settings defined above.
     # This agent is designed to handle zero-shot reaction descriptions and parse errors.
     agent = initialize_agent(
@@ -191,9 +193,13 @@ def build_engine(*, run_config: RunConfig, embeddings: Embeddings):
         run_config.engine.justification_model.service.type,
         **run_config.engine.justification_model.service.model_dump(exclude={"type"}))
 
+    checklist_service = LLMService.create(run_config.engine.checklist_model.service.type,
+                                          **run_config.engine.checklist_model.service.model_dump(exclude={"type"}, by_alias=True))
+
     engine = LLMEngine()
 
-    checklist_node = CVEChecklistNode(checklist_model_config=run_config.engine.checklist_model,
+    checklist_node = CVEChecklistNode(prompt=run_config.engine.checklist_model.prompt,
+                                      llm_client=checklist_service.get_client(
+                                          **run_config.engine.checklist_model.model_dump(exclude={"service", "type", "prompt"}, by_alias=True)),
                                       enable_llm_list_parsing=run_config.general.enable_llm_list_parsing)
 
     engine.add_node("extract_prompt", node=ManualExtracterNode(input_names=checklist_node.get_input_names()))
@@ -210,13 +216,21 @@ def build_engine(*, run_config: RunConfig, embeddings: Embeddings):
     engine.add_node('summary',
                     inputs=[("/checklist", "checklist_inputs"), ("/agent/outputs", "checklist_outputs"),
                             "/agent/intermediate_steps"],
-                    node=CVESummaryNode(llm_client=summary_service.get_client(
-                        **run_config.engine.summary_model.model_dump(exclude={"service", "type"}))))
+                    node=CVESummaryNode(prompt=run_config.engine.summary_model.prompt,
+                                        llm_client=summary_service.get_client(
+                                            **run_config.engine.summary_model.model_dump(exclude={"service", "type", "prompt"}, by_alias=True)
+                                        )
+                    )
+    )
 
     engine.add_node('justification',
                     inputs=[("/summary/summary", "summaries")],
-                    node=CVEJustifyNode(llm_client=justification_service.get_client(
-                        **run_config.engine.justification_model.model_dump(exclude={"service", "type"}))))
+                    node=CVEJustifyNode(prompt=run_config.engine.justification_model.prompt,
+                                        llm_client=justification_service.get_client(
+                                            **run_config.engine.justification_model.model_dump(exclude={"service", "type", "prompt"}, by_alias=True)
+                                        )
+                    )
+    )
 
     handler_inputs = [
         "/summary/checklist",
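A closing note on the wiring above: the new `prompt` field lives on the model config but is consumed by the nodes, not by the LLM service client, which is why every `get_client(...)` call in this change excludes it from the dumped keyword arguments. A minimal sketch of that exclusion pattern, using a stand-in Pydantic model rather than the project's own classes:

```python
# Sketch only: demonstrates the model_dump(exclude=...) pattern used in engine.py.
from pydantic import BaseModel


class DemoModelConfig(BaseModel):  # stand-in for a model config such as NVFoundationLLMModelConfig
    model_name: str
    prompt: str | None = None
    temperature: float = 0.0
    max_tokens: int = 300


cfg = DemoModelConfig(model_name="example-model", prompt="Use a terse, factual tone.")

# The prompt is handled by the node; only generation parameters go to the client.
client_kwargs = cfg.model_dump(exclude={"prompt"})
assert "prompt" not in client_kwargs  # model_name, temperature, and max_tokens remain
```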