reformat
penguine-ip committed Jul 30, 2024
1 parent bd4b7ed commit d9b379d
Showing 5 changed files with 67 additions and 32 deletions.
19 changes: 13 additions & 6 deletions deepeval/metrics/tool_correctness/tool_correctness.py
@@ -12,11 +12,13 @@
ConversationalTestCase,
)
from deepeval.metrics import BaseMetric

required_params: List[LLMTestCaseParams] = [
LLMTestCaseParams.TOOLS_USED,
LLMTestCaseParams.EXPECTED_TOOLS,
]


class ToolCorrectnessMetric(BaseMetric):
def __init__(
self,
@@ -39,7 +41,9 @@ def measure(

self.tools_used: Set[str] = set(test_case.tools_used)
self.expected_tools: Set[str] = set(test_case.expected_tools)
self.expected_tools_used = self.tools_used.intersection(self.expected_tools)
self.expected_tools_used = self.tools_used.intersection(
self.expected_tools
)
self.score = self._calculate_score()
self.reason = self._generate_reason()
self.success = self.score >= self.threshold
@@ -52,26 +56,30 @@ def measure(
],
)
return self.score

async def a_measure(
self, test_case: Union[LLMTestCase, ConversationalTestCase]
) -> float:
return self.measure(test_case)

def _generate_reason(self):
reason = f"The score is {self.score} because {len(self.expected_tools_used)} out of {len(self.expected_tools)} expected tools were used. "
tools_unused = list(self.expected_tools - self.expected_tools_used)
if len(tools_unused) > 0:
reason += f""
reason += f"Tool {tools_unused} was " if len(tools_unused) == 1 else f"Tools {tools_unused} were "
reason += (
f"Tool {tools_unused} was "
if len(tools_unused) == 1
else f"Tools {tools_unused} were "
)
reason += "expected but not used"

return reason

def _calculate_score(self):
number_of_expected_tools_used = len(self.expected_tools_used)
number_of_expected_tools = len(self.expected_tools)
score = number_of_expected_tools_used / number_of_expected_tools
score = number_of_expected_tools_used / number_of_expected_tools
return 0 if self.strict_mode and score < self.threshold else score

def is_successful(self) -> bool:
@@ -84,4 +92,3 @@ def is_successful(self) -> bool:
@property
def __name__(self):
return "Tool Correctness"

7 changes: 6 additions & 1 deletion deepeval/synthesizer/schema.py
@@ -1,17 +1,22 @@
from typing import List
from pydantic import BaseModel


class SyntheticData(BaseModel):
input: str


class SyntheticDataList(BaseModel):
data: List[SyntheticData]


class SQLData(BaseModel):
sql: str


class ComplianceData(BaseModel):
non_compliant: bool



class Response(BaseModel):
response: str
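These pydantic models are the structured shapes the synthesizer expects back from the LLM. A small illustrative sketch of how a parsed JSON payload validates against them (the payload content is made up):

```python
# Illustrative only: validates a hand-written payload against the schema above.
from deepeval.synthesizer.schema import SyntheticData, SyntheticDataList

payload = {
    "data": [
        {"input": "What is the capital of France?"},
        {"input": "Summarize the attached contract in one sentence."},
    ]
}

data_list = SyntheticDataList(**payload)          # nested dicts are coerced into SyntheticData
inputs = [item.input for item in data_list.data]  # -> list of synthetic input strings
```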
50 changes: 34 additions & 16 deletions deepeval/synthesizer/synthesizer.py
@@ -19,7 +19,13 @@
)
from deepeval.synthesizer.context_generator import ContextGenerator
from deepeval.synthesizer.utils import initialize_embedding_model
from deepeval.synthesizer.schema import SyntheticData, SyntheticDataList, SQLData, ComplianceData, Response
from deepeval.synthesizer.schema import (
SyntheticData,
SyntheticDataList,
SQLData,
ComplianceData,
Response,
)
from deepeval.models import DeepEvalBaseLLM
from deepeval.progress_context import synthesizer_progress_context
from deepeval.metrics.utils import trimAndLoadJson, initialize_model
@@ -61,6 +67,7 @@

##################################################################


class Synthesizer:
def __init__(
self,
@@ -86,7 +93,6 @@ def generate(self, prompt: str) -> Tuple[str, str]:
return res.response, 0
except TypeError:
return self.model.generate(prompt), 0


async def a_generate(self, prompt: str) -> Tuple[str, str]:
if self.using_native_model:
@@ -100,7 +106,7 @@ async def a_generate(self, prompt: str) -> Tuple[str, str]:
return res.response, 0
except TypeError:
return await self.model.a_generate(prompt), 0

def generate_synthetic_inputs(self, prompt: str) -> List[SyntheticData]:
if self.using_native_model:
res, _ = self.model.generate(prompt)
@@ -117,7 +123,9 @@ def generate_synthetic_inputs(self, prompt: str) -> List[SyntheticData]:
data = trimAndLoadJson(res, self)
return [SyntheticData(**item) for item in data["data"]]

async def a_generate_synthetic_inputs(self, prompt: str) -> List[SyntheticData]:
async def a_generate_synthetic_inputs(
self, prompt: str
) -> List[SyntheticData]:
if self.using_native_model:
res, _ = await self.model.a_generate(prompt)
data = trimAndLoadJson(res, self)
@@ -132,7 +140,7 @@ async def a_generate_synthetic_inputs(self, prompt: str) -> List[SyntheticData]:
res = await self.model.a_generate(prompt)
data = trimAndLoadJson(res, self)
return [SyntheticData(**item) for item in data["data"]]

def generate_expected_output_sql(self, prompt: str) -> List[SyntheticData]:
if self.using_native_model:
res, _ = self.model.generate(prompt)
@@ -148,8 +156,10 @@ def generate_expected_output_sql(self, prompt: str) -> List[SyntheticData]:
res = self.model.generate(prompt)
data = trimAndLoadJson(res, self)
return data["sql"]

async def a_generate_sql_expected_output(self, prompt: str) -> List[SyntheticData]:

async def a_generate_sql_expected_output(
self, prompt: str
) -> List[SyntheticData]:
if self.using_native_model:
res, _ = await self.model.a_generate(prompt)
data = trimAndLoadJson(res, self)
@@ -164,7 +174,7 @@ async def a_generate_sql_expected_output(self, prompt: str) -> List[SyntheticData]:
res = await self.model.a_generate(prompt)
data = trimAndLoadJson(res, self)
return data["sql"]

def generate_non_compliance(self, prompt: str) -> List[SyntheticData]:
if self.using_native_model:
res, _ = self.model.generate(prompt)
@@ -180,8 +190,10 @@ def generate_non_compliance(self, prompt: str) -> List[SyntheticData]:
res = self.model.generate(prompt)
data = trimAndLoadJson(res, self)
return data["non_compliant"]

async def a_generate_non_compliance(self, prompt: str) -> List[SyntheticData]:

async def a_generate_non_compliance(
self, prompt: str
) -> List[SyntheticData]:
if self.using_native_model:
res, _ = await self.model.a_generate(prompt)
data = trimAndLoadJson(res, self)
@@ -251,7 +263,7 @@ def _evolve_text(
evolution_method = random.choice(evolution_methods)
prompt = evolution_method(input=evolved_text, context=context)
evolved_text, _ = self.generate(prompt)

return evolved_text

async def _a_evolve_text(
@@ -267,7 +279,7 @@ async def _a_evolve_text(
for _ in range(num_evolutions):
evolution_method = random.choice(evolution_methods)
prompt = evolution_method(input=evolved_text, context=context)
evolved_text, _ = await self.a_generate(prompt)
evolved_text, _ = await self.a_generate(prompt)
return evolved_text

def _evolve_red_teaming_attack(
@@ -366,7 +378,9 @@ async def _a_generate_text_to_sql_from_contexts(
prompt = SynthesizerTemplate.generate_text2sql_expected_output(
input=golden.input, context="\n".join(golden.context)
)
golden.expected_output = await self.a_generate_sql_expected_output(prompt)
golden.expected_output = (
await self.a_generate_sql_expected_output(prompt)
)
goldens.append(golden)

async def _a_generate_red_teaming_from_contexts(
@@ -382,7 +396,7 @@ async def _a_generate_red_teaming_from_contexts(
if context:
prompt = SynthesizerTemplate.generate_synthetic_inputs(
context, max_goldens
)
)
else:
prompt = RedTeamSynthesizerTemplate.generate_synthetic_inputs(
max_goldens
@@ -410,7 +424,9 @@ async def _a_generate_red_teaming_from_contexts(
non_compliance_prompt = RedTeamSynthesizerTemplate.non_compliant(
evolved_attack
)
non_compliant = await self.a_generate_non_compliance(non_compliance_prompt)
non_compliant = await self.a_generate_non_compliance(
non_compliance_prompt
)
if non_compliant == False:
golden = Golden(input=evolved_attack, context=context)
if include_expected_output and context is not None:
@@ -876,7 +892,9 @@ def generate_goldens(
input=golden.input,
context="\n".join(golden.context),
)
golden.expected_output = self.generate_expected_output_sql(prompt)
golden.expected_output = (
self.generate_expected_output_sql(prompt)
)
goldens.append(golden)
self.synthetic_goldens.extend(goldens)
return goldens
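Most of the changes in this file are line-wrapping, but the pattern being wrapped is the same in every `generate_*`/`a_generate_*` method: native models return a `(text, cost)` tuple, custom models are first tried with a pydantic `schema=` argument and fall back to plain-text generation on `TypeError`, and text responses are parsed with `trimAndLoadJson`. A simplified sketch of that fallback (not the actual implementation; `json.loads` stands in for `trimAndLoadJson`, and the `schema=` keyword is inferred from the `try/except TypeError` blocks visible in the diff):

```python
import json
from deepeval.synthesizer.schema import SyntheticData, SyntheticDataList

def generate_inputs(model, prompt: str, using_native_model: bool) -> list[SyntheticData]:
    """Simplified sketch of the fallback pattern in Synthesizer.generate_synthetic_inputs."""
    if using_native_model:
        res, _cost = model.generate(prompt)            # native models return (text, cost)
        data = json.loads(res)
        return [SyntheticData(**item) for item in data["data"]]
    try:
        # Custom models that support structured output receive the schema directly.
        res: SyntheticDataList = model.generate(prompt, schema=SyntheticDataList)
        return res.data
    except TypeError:
        # Custom models without schema support return raw text that is parsed as JSON.
        res = model.generate(prompt)
        data = json.loads(res)
        return [SyntheticData(**item) for item in data["data"]]
```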
17 changes: 11 additions & 6 deletions deepeval/test_case/llm_test_case.py
@@ -14,6 +14,7 @@ class LLMTestCaseParams(Enum):
EXPECTED_TOOLS = "expected_tools"
REASONING = "reasoning"


@dataclass
class LLMTestCase:
input: str
@@ -46,24 +47,28 @@ def __post_init__(self):
raise TypeError(
"'retrieval_context' must be None or a list of strings"
)

# Ensure `tools_used` is None or a list of strings
if self.tools_used is not None:
if not isinstance(self.tools_used, list) or not all(
isinstance(item, str) for item in self.tools_used
):
raise TypeError("'tools_used' must be None or a list of strings")

raise TypeError(
"'tools_used' must be None or a list of strings"
)

# Ensure `expected_tools` is None or a list of strings
if self.expected_tools is not None:
if not isinstance(self.expected_tools, list) or not all(
isinstance(item, str) for item in self.expected_tools
):
raise TypeError("'expected_tools' must be None or a list of strings")

raise TypeError(
"'expected_tools' must be None or a list of strings"
)

# Ensure `reasoning` is None or a list of strings
if self.reasoning is not None:
if not isinstance(self.reasoning, list) or not all(
isinstance(item, str) for item in self.reasoning
):
raise TypeError("'reasoning' must be None or a list of strings")
raise TypeError("'reasoning' must be None or a list of strings")
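The net effect of these `__post_init__` checks is that the tool-related fields fail fast when given anything other than `None` or a list of strings. A small illustrative sketch (`actual_output` is a standard `LLMTestCase` field, though it is not part of this hunk):

```python
# Illustrative only: demonstrates the validation reformatted above.
from deepeval.test_case import LLMTestCase

ok = LLMTestCase(
    input="Make me a coffee",
    actual_output="Here is your coffee",
    tools_used=["mixer"],  # a list of strings passes validation
)

bad = LLMTestCase(
    input="Make me a coffee",
    actual_output="Here is your coffee",
    tools_used="mixer",    # raises TypeError: 'tools_used' must be None or a list of strings
)
```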
6 changes: 3 additions & 3 deletions tests/test_everything.py
@@ -20,7 +20,7 @@
ToxicityMetric,
GEval,
SummarizationMetric,
ToolCorrectnessMetric
ToolCorrectnessMetric,
)
from deepeval.metrics.ragas import RagasMetric
from deepeval import assert_test
@@ -202,7 +202,7 @@ def test_everything_2():
# retrieval_context=["I love coffee"],
context=["I love coffee"],
expected_tools=["mixer", "creamer", "dripper"],
tools_used=["mixer", "creamer", "mixer"]
tools_used=["mixer", "creamer", "mixer"],
)
assert_test(
test_case,
@@ -218,7 +218,7 @@ def test_everything_2():
# metric9,
# metric10,
# metric11,
metric12
metric12,
],
# run_async=False,
)
