
Commit

Merge remote-tracking branch 'upstream/main'
zvigrinberg committed Dec 30, 2024
2 parents 18c3d6a + f464f9c commit 8534f50
Showing 8 changed files with 106 additions and 34 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -506,6 +506,7 @@ The configuration defines how the workflow operates, including model settings, i
2. **LLM engine configuration (`engine`)**: The `engine` section configures various models for the LLM nodes.
- LLM processing nodes: `agent`, `checklist_model`, `justification_model`, `summary_model`
- `model_name`: The name of the LLM model used by the node.
- `prompt`: Manually sets the prompt for the specific model in the configuration. The prompt can be passed either as a string of text or as a path to a text file containing the desired prompt.
- `service`: Specifies the service for running the LLM inference. (Set to `nvfoundation` if using NIM.)
- `max_tokens`: Defines the maximum number of tokens that can be generated in one output step.
- `temperature`: Controls randomness in the output. A lower temperature produces more deterministic results.
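For illustration, a minimal sketch of how the new `prompt` option could be set in the `engine` section, expressed as a Python dict. The node names and the `model_name`, `service`, `prompt`, `max_tokens`, and `temperature` keys follow the README above; the contents of the service block, the model names, and the file path are assumptions, not values taken from the repository's example configs.

engine_config = {
    "checklist_model": {
        "service": {"type": "nvfoundation"},  # assumed shape of the service block
        "model_name": "example-llm",          # hypothetical model name
        # Option 1: pass the prompt inline as a literal string.
        "prompt": "Generate a numbered exploitability checklist for the CVE below.",
        "max_tokens": 2000,
        "temperature": 0.0,
    },
    "summary_model": {
        "service": {"type": "nvfoundation"},
        "model_name": "example-llm",
        # Option 2: pass a path to a text file; the config loader reads the file
        # and uses its contents as the prompt (see check_prompt_file below).
        "prompt": "configs/prompts/summary_prompt.txt",  # hypothetical path
        "max_tokens": 1000,
        "temperature": 0.0,
    },
}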
36 changes: 36 additions & 0 deletions configs/schemas/config.schema.json
@@ -540,6 +540,18 @@
"title": "Model Name",
"type": "string"
},
"prompt": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null,
"title": "Prompt"
},
"temperature": {
"default": 0.0,
"title": "Temperature",
@@ -628,6 +640,18 @@
"title": "Model Name",
"type": "string"
},
"prompt": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null,
"title": "Prompt"
},
"customization_id": {
"anyOf": [
{
@@ -793,6 +817,18 @@
"title": "Model Name",
"type": "string"
},
"prompt": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null,
"title": "Prompt"
},
"temperature": {
"default": 0.0,
"title": "Temperature",
1 change: 1 addition & 0 deletions requirements.yaml
@@ -29,6 +29,7 @@ dependencies:
# - faiss-gpu=1.7 # Uncomment this when the package supports CUDA 12. See: https://github.com/conda-forge/faiss-split-feedstock/pull/72
- faiss=1.7
- gitpython=3.1
- httpx>=0.23,<0.28 # work-around for https://github.com/openai/openai-python/issues/1915
- json5=0.9
- morpheus-llm=24.10
# cuda-python=12.6.1 cannot be used because of -
16 changes: 16 additions & 0 deletions src/cve/data_models/config.py
@@ -27,6 +27,7 @@
from pydantic import PositiveInt
from pydantic import Tag
from pydantic import field_validator
from pydantic.functional_validators import AfterValidator

from morpheus.utils.http_utils import HTTPMethod

@@ -72,6 +73,18 @@ def get_num_threads(cls, v: int) -> int:


########################### LLM Model Configs ###########################
def check_prompt_file(prompt: str) -> str:
    # If the configured prompt is a path to an existing file, return the file's
    # contents; otherwise return the prompt string unchanged.
    if prompt is not None and os.path.isfile(prompt):
        with open(prompt, 'r') as prompt_file:
            read_prompt = prompt_file.read()
            return read_prompt
    else:
        return prompt


PromptConfig = typing.Annotated[str, AfterValidator(check_prompt_file)]


class NeMoLLMServiceConfig(TypedBaseModel[typing.Literal["nemo"]]):

api_key: str | None = None
@@ -84,6 +97,7 @@ class NeMoLLMModelConfig(BaseModel):
service: NeMoLLMServiceConfig

model_name: str
prompt: PromptConfig | None = None
customization_id: str | None = None
temperature: typing.Annotated[float, Field(ge=0.0, le=1.0)] = 0.0
top_k: NonNegativeInt = 0
@@ -110,6 +124,7 @@ class NVFoundationLLMModelConfig(TypedBaseModel[typing.Literal["nvfoundation"]])
service: NVFoundationLLMServiceConfig

model_name: str
prompt: PromptConfig | None = None
temperature: float = 0.0
top_p: float | None = None
max_tokens: PositiveInt = 300
@@ -129,6 +144,7 @@ class OpenAIModelConfig(BaseModel):
service: OpenAIServiceConfig

model_name: str
prompt: PromptConfig | None = None
temperature: float = 0.0
top_p: float = 1.0
seed: int | None = None
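To make the new PromptConfig behavior concrete, a small self-contained sketch of the pattern added above: a pydantic AfterValidator that swaps a file path for the file's contents. ExampleModelConfig, the model name, and the temporary file are illustrative only and not part of the repository.

import os
import tempfile
import typing

from pydantic import BaseModel
from pydantic.functional_validators import AfterValidator


def check_prompt_file(prompt: str) -> str:
    # Mirrors the validator in src/cve/data_models/config.py: load the prompt
    # text from a file when the configured value is an existing path.
    if prompt is not None and os.path.isfile(prompt):
        with open(prompt, 'r') as prompt_file:
            return prompt_file.read()
    return prompt


PromptConfig = typing.Annotated[str, AfterValidator(check_prompt_file)]


class ExampleModelConfig(BaseModel):
    model_name: str
    prompt: PromptConfig | None = None


# Inline prompt: the string is used verbatim.
inline = ExampleModelConfig(model_name="demo", prompt="Summarize the findings.")
assert inline.prompt == "Summarize the findings."

# File-based prompt: the validator replaces the path with the file contents.
with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as tmp:
    tmp.write("Summarize the exploitability findings in one paragraph.")
file_based = ExampleModelConfig(model_name="demo", prompt=tmp.name)
assert file_based.prompt == "Summarize the exploitability findings in one paragraph."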
42 changes: 21 additions & 21 deletions src/cve/nodes/cve_checklist_node.py
@@ -21,22 +21,16 @@
from morpheus_llm.llm import LLMNode
from morpheus_llm.llm.nodes.llm_generate_node import LLMGenerateNode
from morpheus_llm.llm.nodes.prompt_template_node import PromptTemplateNode
from morpheus_llm.llm.services.llm_service import LLMService
from morpheus_llm.llm.services.llm_service import LLMClient

from ..data_models.config import LLMModelConfig
from ..utils.prompting import MOD_FEW_SHOT
from ..utils.prompting import additional_intel_prompting
from ..utils.prompting import get_mod_examples
from ..utils.string_utils import attempt_fix_list_string

logger = logging.getLogger(__name__)

cve_prompt1 = (
MOD_FEW_SHOT.format(examples=get_mod_examples())
+ additional_intel_prompting
+ "\n\nIf a vulnerable function or method is mentioned in the CVE description, ensure the first checklist item verifies whether this function or method is being called from the code or used by the code."
+ "\nThe vulnerable version of the vulnerable package is already verified to be installed within the container. Check only the other factors that affect exploitability, no need to verify version again."
)
DEFAULT_CHECKLIST_PROMPT = MOD_FEW_SHOT.format(examples=get_mod_examples())

cve_prompt2 = """Parse the following numbered checklist into a python list in the format ["x", "y", "z"], a comma separated list surrounded by square braces: {{template}}"""

@@ -109,7 +103,9 @@ class CVEChecklistNode(LLMNode):
It integrates various nodes that handle CVE lookup, prompting, generation, and parsing to produce an actionable checklist.
"""

def __init__(self, *, checklist_model_config: LLMModelConfig, enable_llm_list_parsing: bool = False):
def __init__(self, *, prompt: str,
llm_client: LLMClient,
enable_llm_list_parsing: bool = False):
"""
Initialize the CVEChecklistNode with optional caching and a vulnerability endpoint retriever.
@@ -124,20 +120,27 @@ def __init__(self, *, checklist_model_config: LLMModelConfig, enable_llm_list_pa
"""
super().__init__()

chat_service = LLMService.create(checklist_model_config.service.type,
**checklist_model_config.service.model_dump(exclude={"type"}, by_alias=True))
if not prompt:
prompt = DEFAULT_CHECKLIST_PROMPT

intel = (
additional_intel_prompting +
"\n\nIf a vulnerable function or method is mentioned in the CVE description, ensure the first checklist item verifies whether this function or method is being called from the code or used by the code."
"\nThe vulnerable version of the vulnerable package is already verified to be installed within the container. Check only the other factors that affect exploitability, no need to verify version again."
)

cve_prompt1 = (
prompt
+ intel
)

# Add a node to create a prompt for CVE checklist generation based on the CVE details obtained from the lookup
# node
self.add_node("checklist_prompt",
inputs=[("*", "*")],
node=PromptTemplateNode(template=cve_prompt1, template_format="jinja"))

# Instantiate a chat service and configure a client for generating responses to the checklist prompt
llm_client_1 = chat_service.get_client(
**checklist_model_config.model_dump(exclude={"service", "type"}, by_alias=True)
)
gen_node_1 = LLMGenerateNode(llm_client=llm_client_1)
gen_node_1 = LLMGenerateNode(llm_client=llm_client)
self.add_node("chat1", inputs=["/checklist_prompt"], node=gen_node_1)

if enable_llm_list_parsing:
@@ -146,11 +149,8 @@ def __init__(self, *, checklist_model_config: LLMModelConfig, enable_llm_list_pa
inputs=["/chat1"],
node=PromptTemplateNode(template=cve_prompt2, template_format="jinja"))

# Configure a second client for generating a follow-up response based on the parsed checklist prompt
llm_client_2 = chat_service.get_client(
**checklist_model_config.model_dump(exclude={"service", "type"}, by_alias=True)
)
gen_node_2 = LLMGenerateNode(llm_client=llm_client_2)
# Configure a second node for generating a follow-up response based on the parsed checklist prompt
gen_node_2 = LLMGenerateNode(llm_client=llm_client)
self.add_node("chat2", inputs=[("/parse_checklist_prompt", "prompt")], node=gen_node_2)

checklist_prompts = ["/chat2"] if enable_llm_list_parsing else ["/chat1"]
8 changes: 5 additions & 3 deletions src/cve/nodes/cve_justification_node.py
@@ -35,7 +35,7 @@ class CVEJustifyNode(LLMNode):
JUSTIFICATION_REASON_COL_NAME = "justification"
AFFECTED_STATUS_COL_NAME = "affected_status"

JUSTIFICATION_PROMPT = dedent("""
DEFAULT_JUSTIFICATION_PROMPT = dedent("""
The summary provided below (delimited with XML tags), generated by the software agent named Vulnerability
Analysis for Container Security, evaluates a specific CVE (Common Vulnerabilities and Exposures) against the
backdrop of a software package or environment information. This information may include a Software Bill of
@@ -97,7 +97,7 @@ class CVEJustifyNode(LLMNode):
"vulnerable": "TRUE"
}

def __init__(self, *, llm_client: LLMClient):
def __init__(self, *, prompt: str, llm_client: LLMClient):
"""
Initialize the CVEJustificationNode with a selected model.
@@ -113,9 +113,11 @@ async def _strip_summaries(summaries: list[str]) -> list[str]:

self.add_node('stripped_summaries', inputs=['summaries'], node=LLMLambdaNode(_strip_summaries))

justification_prompt = prompt or self.DEFAULT_JUSTIFICATION_PROMPT

self.add_node("justification_prompt",
inputs=[("/stripped_summaries", "summary")],
node=PromptTemplateNode(template=self.JUSTIFICATION_PROMPT, template_format='jinja'))
node=PromptTemplateNode(template=justification_prompt, template_format='jinja'))

self.add_node("justify", inputs=["/justification_prompt"], node=LLMGenerateNode(llm_client=llm_client))

8 changes: 5 additions & 3 deletions src/cve/nodes/cve_summary_node.py
@@ -26,7 +26,7 @@

logger = logging.getLogger(__name__)

SUMMARY_PROMPT = """Summarize the exploitability investigation results of a Common Vulnerabilities and Exposures (CVE) \
DEFAULT_SUMMARY_PROMPT = """Summarize the exploitability investigation results of a Common Vulnerabilities and Exposures (CVE) \
based on the provided Checklist and Findings. Write a concise paragraph focusing only on checklist items with \
definitive answers. Begin your response by clearly stating whether the CVE is exploitable. Disregard any ambiguous \
checklist items.
@@ -39,7 +39,7 @@ class CVESummaryNode(LLMNode):
A node to summarize the results of the checklist responses.
"""

def __init__(self, *, llm_client: LLMClient):
def __init__(self, *, prompt: str, llm_client: LLMClient):
"""
Initialize the CVESummaryNode with a selected model.
@@ -80,9 +80,11 @@ async def concat_checklist_responses(agent_q_and_a: list[list[dict]]) -> list[st

self.add_node('results', inputs=['/checklist'], node=LLMLambdaNode(concat_checklist_responses))

summary_prompt = prompt or DEFAULT_SUMMARY_PROMPT

self.add_node('summary_prompt',
inputs=[('/results', 'response')],
node=PromptTemplateNode(template=SUMMARY_PROMPT, template_format='jinja'))
node=PromptTemplateNode(template=summary_prompt, template_format='jinja'))

# Generate a summary from the combined checklist responses
self.add_node("summary",
28 changes: 21 additions & 7 deletions src/cve/pipeline/engine.py
@@ -49,7 +49,7 @@ def _build_dynamic_agent_fn(run_config: RunConfig, embeddings: Embeddings):
chat_service = LLMService.create(run_config.engine.agent.model.service.type,
**run_config.engine.agent.model.service.model_dump(exclude={"type"},
by_alias=True))
chat_client = chat_service.get_client(**run_config.engine.agent.model.model_dump(exclude={"service", "type"},
chat_client = chat_service.get_client(**run_config.engine.agent.model.model_dump(exclude={"service", "type", "prompt"},
by_alias=True))
langchain_llm = LangchainLLMClientWrapper(client=chat_client)

@@ -146,11 +146,13 @@ def run_retrieval_qa_tool(retrieval_qa_tool: RetrievalQA, query: str) -> str | d

# Define a system prompt that sets the context for the language model's task. This prompt positions the assistant
# as a powerful entity capable of investigating CVE impacts on container images.
sys_prompt = (
DEFAULT_SYS_PROMPT = (
"You are a very powerful assistant who helps investigate the impact of reported Common Vulnerabilities and "
"Exposures (CVE) on container images. Information about the container image under investigation is "
"stored in vector databases available to you via tools.")

sys_prompt = run_config.engine.agent.model.prompt or DEFAULT_SYS_PROMPT

# Initialize an agent with the tools and settings defined above.
# This agent is designed to handle zero-shot reaction descriptions and parse errors.
agent = initialize_agent(
@@ -191,9 +193,13 @@ def build_engine(*, run_config: RunConfig, embeddings: Embeddings):
run_config.engine.justification_model.service.type,
**run_config.engine.justification_model.service.model_dump(exclude={"type"}))

checklist_service = LLMService.create(run_config.engine.checklist_model.service.type,
**run_config.engine.checklist_model.service.model_dump(exclude={"type"}, by_alias=True))
engine = LLMEngine()

checklist_node = CVEChecklistNode(checklist_model_config=run_config.engine.checklist_model,
checklist_node = CVEChecklistNode(prompt=run_config.engine.checklist_model.prompt,
llm_client=checklist_service.get_client(
**run_config.engine.checklist_model.model_dump(exclude={"service", "type", "prompt"}, by_alias=True)),
enable_llm_list_parsing=run_config.general.enable_llm_list_parsing)

engine.add_node("extract_prompt", node=ManualExtracterNode(input_names=checklist_node.get_input_names()))
@@ -210,13 +216,21 @@ def build_engine(*, run_config: RunConfig, embeddings: Embeddings):
engine.add_node('summary',
inputs=[("/checklist", "checklist_inputs"), ("/agent/outputs", "checklist_outputs"),
"/agent/intermediate_steps"],
node=CVESummaryNode(llm_client=summary_service.get_client(
**run_config.engine.summary_model.model_dump(exclude={"service", "type"}))))
node=CVESummaryNode(prompt=run_config.engine.summary_model.prompt,
llm_client=summary_service.get_client(
**run_config.engine.summary_model.model_dump(exclude={"service", "type", "prompt"}, by_alias=True)
)
)
)

engine.add_node('justification',
inputs=[("/summary/summary", "summaries")],
node=CVEJustifyNode(llm_client=justification_service.get_client(
**run_config.engine.justification_model.model_dump(exclude={"service", "type"}))))
node=CVEJustifyNode(prompt=run_config.engine.justification_model.prompt,
llm_client=justification_service.get_client(
**run_config.engine.justification_model.model_dump(exclude={"service", "type", "prompt"}, by_alias=True)
)
)
)

handler_inputs = [
"/summary/checklist",
