Skip to content

Commit

Permalink
RELEASE: 0.0.2a02
Browse files Browse the repository at this point in the history
  • Loading branch information
brucewlee committed Aug 3, 2024
1 parent 18d6190 commit d7b74e6
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 32 deletions.
3 changes: 2 additions & 1 deletion nutcracker/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from nutcracker.data.task import Task
from nutcracker.data.pile import Pile
from nutcracker.data.pile import Pile
from nutcracker.data.instance_collection import InstanceCollection
4 changes: 4 additions & 0 deletions nutcracker/data/instance.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List
import copy
#
from nutcracker.utils import number_to_letter
#
Expand Down Expand Up @@ -45,6 +46,9 @@ def create_instance(

else:
raise ValueError("Invalid instance construction")

def copy(self):
return copy.deepcopy(self)



Expand Down
3 changes: 1 addition & 2 deletions nutcracker/evaluator/reporter/mfq_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,13 @@ def mfq_30_generate_report(data: List[MCQSurveyInstance], save_path: str = None,

options = ["Not at all relevant", "Not very relevant", "Slightly relevant", "Somewhat relevant", "Very relevant", "Extremely relevant"]
letter_to_number = {chr(ord('A') + i): i for i in range(len(options))}
print(letter_to_number)

scores = defaultdict(list)
interpretations = defaultdict(list)

for instance in data:
question_number = instance.question_number
judge_interpretation = instance.judge_interpretation.pop()
judge_interpretation = instance.judge_interpretation.pop()
if judge_interpretation in ['A', 'B', 'C', 'D', 'E', 'F']:
for foundation, questions in foundations.items():
if question_number in questions:
Expand Down
63 changes: 36 additions & 27 deletions nutcracker/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@


class OpenAI_ChatGPT():
def __init__(self, api_key):
def __init__(self, api_key, max_retries = 5):
self.model = "gpt-3.5-turbo-0125"
self.client_openai = OpenAI(
api_key=api_key,
)
self.max_retries = max_retries

def respond(self, user_prompt, max_retries=5):
def respond(self, user_prompt):
retry_count = 0
while retry_count < max_retries:
while retry_count < self.max_retries:
try:
completion = self.client_openai.chat.completions.create(
model=self.model,
Expand All @@ -39,15 +40,16 @@ def respond(self, user_prompt, max_retries=5):


class OpenAI_ChatGPT4():
def __init__(self, api_key):
def __init__(self, api_key, max_retries = 5):
self.model = "gpt-4-turbo-2024-04-09"
self.client_openai = OpenAI(
api_key=api_key,
)
self.max_retries = max_retries

def respond(self, user_prompt, max_retries=5):
def respond(self, user_prompt):
retry_count = 0
while retry_count < max_retries:
while retry_count < self.max_retries:
try:
completion = self.client_openai.chat.completions.create(
model=self.model,
Expand All @@ -68,15 +70,16 @@ def respond(self, user_prompt, max_retries=5):


class OpenAI_ChatGPT4o():
def __init__(self, api_key):
def __init__(self, api_key, max_retries = 5):
self.model = "gpt-4o-2024-05-13"
self.client_openai = OpenAI(
api_key=api_key,
)
self.max_retries = max_retries

def respond(self, user_prompt, max_retries=5):
def respond(self, user_prompt):
retry_count = 0
while retry_count < max_retries:
while retry_count < self.max_retries:
try:
completion = self.client_openai.chat.completions.create(
model=self.model,
Expand All @@ -97,17 +100,18 @@ def respond(self, user_prompt, max_retries=5):


class Bedrock_Claude3_Opus():
def __init__(self, aws_access_key_id, aws_secret_access_key, region_name):
def __init__(self, aws_access_key_id, aws_secret_access_key, region_name, max_retries = 5):
self.model = "anthropic.claude-3-opus-20240229-v1:0"
self.client_anthropic = AnthropicBedrock(
aws_access_key=os.environ['AWS_ACCESS_KEY'],
aws_secret_key=os.environ['AWS_SECRET_KEY'],
aws_region="us-west-2",
)
self.max_retries = max_retries

def respond(self, user_prompt, max_retries=5):
def respond(self, user_prompt):
retry_count = 0
while retry_count < max_retries:
while retry_count < self.max_retries:
try:
completion = self.client_anthropic.messages.create(
model=self.model,
Expand All @@ -129,17 +133,18 @@ def respond(self, user_prompt, max_retries=5):


class Bedrock_Claude3_Sonnet():
def __init__(self, aws_access_key_id, aws_secret_access_key, region_name):
def __init__(self, aws_access_key_id, aws_secret_access_key, region_name, max_retries = 5):
self.model = "anthropic.claude-3-sonnet-20240229-v1:0"
self.client_anthropic = AnthropicBedrock(
aws_access_key=os.environ['AWS_ACCESS_KEY'],
aws_secret_key=os.environ['AWS_SECRET_KEY'],
aws_region="us-east-1",
)
self.max_retries = max_retries

def respond(self, user_prompt, max_retries=5):
def respond(self, user_prompt):
retry_count = 0
while retry_count < max_retries:
while retry_count < self.max_retries:
try:
completion = self.client_anthropic.messages.create(
model=self.model,
Expand All @@ -161,17 +166,18 @@ def respond(self, user_prompt, max_retries=5):


class Bedrock_Claude3_Haiku():
def __init__(self, aws_access_key_id, aws_secret_access_key, region_name):
def __init__(self, aws_access_key_id, aws_secret_access_key, region_name, max_retries = 5):
self.model = "anthropic.claude-3-haiku-20240307-v1:0"
self.client_anthropic = AnthropicBedrock(
aws_access_key=os.environ['AWS_ACCESS_KEY'],
aws_secret_key=os.environ['AWS_SECRET_KEY'],
aws_region="us-east-1",
)
self.max_retries = max_retries

def respond(self, user_prompt, max_retries=5):
def respond(self, user_prompt):
retry_count = 0
while retry_count < max_retries:
while retry_count < self.max_retries:
try:
completion = self.client_anthropic.messages.create(
model=self.model,
Expand All @@ -193,15 +199,16 @@ def respond(self, user_prompt, max_retries=5):


class Cohere_CommandRPlus():
def __init__(self, api_key):
def __init__(self, api_key, max_retries = 5):
self.model = "command-r-plus"
self.client_cohere = cohere.Client(
api_key=api_key
)
self.max_retries = max_retries

def respond(self, user_prompt, max_retries=5):
def respond(self, user_prompt):
retry_count = 0
while retry_count < max_retries:
while retry_count < self.max_retries:
try:
completion = self.client_cohere.chat(
model=self.model,
Expand All @@ -220,24 +227,25 @@ def respond(self, user_prompt, max_retries=5):


class Bedrock_LLaMA3_70B_Inst():
def __init__(self, aws_access_key_id, aws_secret_access_key, region_name):
def __init__(self, aws_access_key_id, aws_secret_access_key, region_name, max_retries = 5):
self.model = "meta.llama3-70b-instruct-v1:0"
self.client_bedrock = boto3.client(
'bedrock-runtime',
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=region_name,
)
self.max_retries = max_retries

def respond(self, user_prompt, max_retries=5):
def respond(self, user_prompt):
prompt = self._format_prompt(user_prompt)
body = {
"prompt": prompt,
"max_gen_len": 1024,
}

retry_count = 0
while retry_count < max_retries:
while retry_count < self.max_retries:
try:
results = self.client_bedrock.invoke_model(
modelId=self.model,
Expand All @@ -263,24 +271,25 @@ def _format_prompt(self, user_prompt):


class Bedrock_LLaMA3_8B_Inst():
def __init__(self, aws_access_key_id, aws_secret_access_key, region_name):
def __init__(self, aws_access_key_id, aws_secret_access_key, region_name, max_retries = 5):
self.model = "meta.llama3-8b-instruct-v1:0"
self.client_bedrock = boto3.client(
'bedrock-runtime',
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=region_name,
)
self.max_retries = max_retries

def respond(self, user_prompt, max_retries=5):
def respond(self, user_prompt):
prompt = self._format_prompt(user_prompt)
body = {
"prompt": prompt,
"max_gen_len": 1024,
}

retry_count = 0
while retry_count < max_retries:
while retry_count < self.max_retries:
try:
results = self.client_bedrock.invoke_model(
modelId=self.model,
Expand Down
5 changes: 4 additions & 1 deletion nutcracker/runs/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from nutcracker.data.instance import Instance
from nutcracker.data.task import Task
from nutcracker.data.pile import Pile
from nutcracker.data.instance_collection import InstanceCollection
from nutcracker.utils import TqdmLoggingHandler
#
#
class Schema:
def __init__(
self,
model: object,
data: Union[Instance, Task, Pile, List[Instance]],
data: Union[Instance, InstanceCollection, Task, Pile, List[Instance]],
other_params: Optional[Dict] = None
) -> None:
"""Initialize a Schema object.
Expand Down Expand Up @@ -47,6 +48,8 @@ def _extract_instances(
instances.extend(data.instances)
elif isinstance(data, Pile):
instances.extend(data.instances)
elif isinstance(data, InstanceCollection) and all(isinstance(item, Instance) for item in data):
instances.extend(data.instances)
elif isinstance(data, list) and all(isinstance(item, Instance) for item in data):
instances.extend(data)
else:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from distutils.core import setup
from setuptools import find_packages

this_version='0.0.2a01'
this_version='0.0.2a02'

# python setup.py sdist
# python -m twine upload dist/*
Expand Down

0 comments on commit d7b74e6

Please sign in to comment.