Skip to content

Commit

Permalink
DH-5033/ finalized llm finetuning with openai
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadrezaPourreza committed Nov 30, 2023
1 parent 9922110 commit d683a32
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 90 deletions.
25 changes: 17 additions & 8 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
TableDescriptionRepository,
)
from dataherald.eval import Evaluator
from dataherald.finetuning.openai_finetuning import OpenAIFineTuning
from dataherald.repositories.base import ResponseRepository
from dataherald.repositories.database_connections import DatabaseConnectionRepository
from dataherald.repositories.finetunings import FinetuningsRepository
Expand Down Expand Up @@ -56,6 +57,7 @@
TableDescriptionRequest,
UpdateInstruction,
)
from dataherald.utils.models_context_window import OPENAI_CONTEXT_WIDNOW_SIZES
from dataherald.utils.s3 import S3

logger = logging.getLogger(__name__)
Expand All @@ -72,8 +74,10 @@ def async_scanning(scanner, database, scanner_request, storage):
)


def async_fine_tuning():
pass
def async_fine_tuning(storage, model):
    """Background task: build the training dataset, then submit the OpenAI fine-tuning job."""
    tuner = OpenAIFineTuning(storage, model)
    tuner.create_fintuning_dataset()
    tuner.create_fine_tuning_job()


def delete_file(file_location: str):
Expand Down Expand Up @@ -678,6 +682,12 @@ def create_finetuning_job(
if not golden_records:
raise HTTPException(status_code=404, detail="No golden records found")

if fine_tuning_request.base_llm.model_name not in OPENAI_CONTEXT_WIDNOW_SIZES:
raise HTTPException(
status_code=400,
detail=f"Model {fine_tuning_request.base_llm.model_name} not supported",
)

model_repository = FinetuningsRepository(self.storage)
model = model_repository.insert(
Finetuning(
Expand All @@ -689,9 +699,7 @@ def create_finetuning_job(
)
)

background_tasks.add_task(
async_fine_tuning, fine_tuning_request, self.storage, golden_records
)
background_tasks.add_task(async_fine_tuning, self.storage, model)

return model

Expand All @@ -717,14 +725,15 @@ def cancel_finetuning_job(
status_code=400, detail="Model has already been cancelled."
)

# Todo: Add code to cancel the fine tuning job
openai_fine_tuning = OpenAIFineTuning(self.storage, model)

return model
return openai_fine_tuning.cancel_finetuning_job()

@override
def get_finetuning_job(self, finetuning_job_id: str) -> Finetuning:
model_repository = FinetuningsRepository(self.storage)
model = model_repository.find_by_id(finetuning_job_id)
if not model:
raise HTTPException(status_code=404, detail="Model not found")
return model
openai_fine_tuning = OpenAIFineTuning(self.storage, model)
return openai_fine_tuning.retrieve_finetuning_job()
29 changes: 29 additions & 0 deletions dataherald/finetuning/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from abc import ABC, abstractmethod

from dataherald.config import Component
from dataherald.types import Finetuning


class FinetuningModel(Component, ABC):
    """Abstract interface for provider-specific LLM fine-tuning backends.

    A concrete implementation (e.g. ``OpenAIFineTuning``) drives the full
    fine-tuning lifecycle: token accounting, dataset creation, job
    submission, status retrieval, and cancellation.
    """

    def __init__(self, storage):
        # Persistence handle; implementations use it to load/update
        # fine-tuning records and related repositories.
        self.storage = storage

    @abstractmethod
    def count_tokens(self, messages: dict) -> int:
        """Return the token count for a chat-format ``messages`` payload."""
        pass

    # NOTE(review): method name has a typo ("fintuning"); it is part of the
    # public interface implemented by subclasses, so it is kept as-is here.
    @abstractmethod
    def create_fintuning_dataset(self):
        """Build and persist the training dataset for the fine-tuning job."""
        pass

    @abstractmethod
    def create_fine_tuning_job(self):
        """Submit the fine-tuning job to the provider."""
        pass

    @abstractmethod
    def retrieve_finetuning_job(self) -> Finetuning:
        """Fetch the latest job status and return the updated model record."""
        pass

    @abstractmethod
    def cancel_finetuning_job(self) -> Finetuning:
        """Cancel the running job and return the updated model record."""
        pass
2 changes: 0 additions & 2 deletions dataherald/finetuning/__init__py

This file was deleted.

194 changes: 116 additions & 78 deletions dataherald/finetuning/openai_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,47 @@
import uuid
from typing import Any, List

import openai
import tiktoken
from bson.objectid import ObjectId
from openai import OpenAI
from overrides import override
from tiktoken import Encoding

from dataherald.db_scanner.models.types import TableDescription, TableDescriptionStatus
from dataherald.db_scanner.repository.base import TableDescriptionRepository
from dataherald.finetuning import FinetuningModel
from dataherald.repositories.database_connections import DatabaseConnectionRepository
from dataherald.repositories.finetunings import FinetuningsRepository
from dataherald.repositories.golden_records import GoldenRecordRepository
from dataherald.types import Finetuning
from dataherald.utils.agent_prompts import FINETUNING_SYSTEM_INFORMATION
from dataherald.utils.models_context_window import OPENAI_CONTEXT_WIDNOW_SIZES

FILE_PROCESSING_ATTEMPTS = 20

logger = logging.getLogger(__name__)

class OpenAIFineTuning:
finetuning_dataset_path: str

def format_columns(self, table: TableDescription, top_k: int = 100) -> str:
class OpenAIFineTuning(FinetuningModel):
encoding: Encoding
fine_tuning_model: Finetuning
storage: Any
client: OpenAI

def __init__(self, storage: Any, fine_tuning_model: Finetuning):
self.storage = storage
self.fine_tuning_model = fine_tuning_model
db_connection_repository = DatabaseConnectionRepository(storage)
db_connection = db_connection_repository.find_by_id(
fine_tuning_model.db_connection_id
)
self.encoding = tiktoken.encoding_for_model(
fine_tuning_model.base_llm.model_name
)
self.client = OpenAI(api_key=db_connection.decrypt_api_key())

@classmethod
def format_columns(cls, table: TableDescription, top_k: int = 100) -> str:
"""
format_columns formats the columns.
Expand Down Expand Up @@ -65,14 +87,14 @@ def format_columns(self, table: TableDescription, top_k: int = 100) -> str:
)
return columns_information

@staticmethod
def format_dataset(self, db_scan: List[TableDescription]) -> str:
@classmethod
def format_dataset(cls, db_scan: List[TableDescription]) -> str:
schema_of_database = ""
for table in db_scan:
tables_schema = table.table_schema
schema_of_database += f"{tables_schema}\n"
schema_of_database += "# Categorical Columns:\n"
columns_information = self.format_columns(table)
columns_information = cls.format_columns(table)
schema_of_database += columns_information
sample_rows = table.examples
schema_of_database += "# Sample rows:\n"
Expand All @@ -83,107 +105,123 @@ def format_dataset(self, db_scan: List[TableDescription]) -> str:
schema_of_database += "\n\n"
return schema_of_database

@classmethod
def create_fintuning_dataset(cls, fine_tuning_request: Finetuning, storage: Any):
db_connection_id = fine_tuning_request.db_connection_id
repository = TableDescriptionRepository(storage)
@override
def count_tokens(self, messages: dict) -> int:
    """Count the tokens across all message contents of a chat payload.

    Args:
        messages: Chat-format dict with a ``"messages"`` list, each item
            holding a ``"content"`` string (OpenAI fine-tuning format).

    Returns:
        Total token count of the concatenated contents under this
        instance's tiktoken encoding.
    """
    # Join once instead of string += in a loop — avoids quadratic copying
    # on large prompts (full database schema + sample rows).
    prompt = "".join(message["content"] for message in messages["messages"])
    return len(self.encoding.encode(prompt))

@override
def create_fintuning_dataset(self):
db_connection_id = self.fine_tuning_model.db_connection_id
repository = TableDescriptionRepository(self.storage)
db_scan = repository.get_all_tables_by_db(
{
"db_connection_id": ObjectId(db_connection_id),
"status": TableDescriptionStatus.SYNCHRONIZED.value,
}
)
golden_records_repository = GoldenRecordRepository(storage)
database_schema = cls.format_dataset(db_scan)
cls.finetuning_dataset_path = f"tmp/{str(uuid.uuid4())}.jsonl"
for golden_record_id in fine_tuning_request.golden_records:
golden_records_repository = GoldenRecordRepository(self.storage)
database_schema = self.format_dataset(db_scan)
finetuning_dataset_path = f"tmp/{str(uuid.uuid4())}.jsonl"
model_repository = FinetuningsRepository(self.storage)
model = model_repository.find_by_id(self.fine_tuning_model.id)
for golden_record_id in self.fine_tuning_model.golden_records:
golden_record = golden_records_repository.find_by_id(golden_record_id)
question = golden_record.question
query = golden_record.sql_query
system_prompt = FINETUNING_SYSTEM_INFORMATION + database_schema
user_prompt = "User Question: " + question + "\n SQL: "
assistant_prompt = query + "\n"
with open(cls.finetuning_dataset_path, "a") as outfile:
with open(finetuning_dataset_path, "a") as outfile:
messages = {
"messages": [
{"role": "system", "content": f"{system_prompt}"},
{"role": "user", "content": f"Question : {user_prompt}"},
{"role": "assistant", "content": f"{assistant_prompt}"},
]
}
number_of_tokens = self.count_tokens(messages)
if (
number_of_tokens
> OPENAI_CONTEXT_WIDNOW_SIZES[
self.fine_tuning_model.base_llm.model_name
]
):
model.status = "failed"
model.error = "The number of tokens in the prompt is too large"
model_repository.update(model)
os.remove(finetuning_dataset_path)
return
json.dump(messages, outfile)
outfile.write("\n")
db_connection_repository = DatabaseConnectionRepository(storage)
db_connection = db_connection_repository.find_by_id(
fine_tuning_request.db_connection_id
)
openai.api_key = db_connection.decrypt_api_key()
model_repository = FinetuningsRepository(storage)
model = model_repository.find_by_id(fine_tuning_request.id)
model.finetuning_file_id = openai.File.create(file=open(cls.finetuning_dataset_path,purpose='fine-tune'))['id']
model.finetuning_file_id = self.client.files.create(
file=open(finetuning_dataset_path, "rb"), purpose="fine-tune"
).id
model_repository.update(model)
os.remove(cls.finetuning_dataset_path)

os.remove(finetuning_dataset_path)

@classmethod
def create_fine_tuning_job(cls, fine_tuning_request: Finetuning, storage: Any):
db_connection_repository = DatabaseConnectionRepository(storage)
db_connection = db_connection_repository.find_by_id(
fine_tuning_request.db_connection_id
)
openai.api_key = db_connection.decrypt_api_key()
model_repository = FinetuningsRepository(storage)
model = model_repository.find_by_id(fine_tuning_request.id)
def check_file_status(self, file_id: str) -> bool:
retrieve_file_attempt = 0
while True:
if openai.File.retrieve(id=model.finetuning_file_id)["status"] == "processed":
break
file_info = self.client.files.retrieve(file_id=file_id)
if file_info.status == "processed":
return True
time.sleep(5)
retrieve_file_attempt += 1
if retrieve_file_attempt == FILE_PROCESSING_ATTEMPTS:
model.status = "failed"
model.error = "File processing failed"
model_repository.update(model)
return
finetuning_request = openai.FineTune.create(
training_file=model.finetuning_file_id,
model=model.base_llm.model_name,
hyperparameters= model.base_llm.model_parameters
return False

@override
def create_fine_tuning_job(self):
model_repository = FinetuningsRepository(self.storage)
model = model_repository.find_by_id(self.fine_tuning_model.id)
if self.check_file_status(model.finetuning_file_id):
finetuning_request = self.client.fine_tuning.jobs.create(
training_file=model.finetuning_file_id,
model=model.base_llm.model_name,
hyperparameters=model.base_llm.model_parameters
if model.base_llm.model_parameters
else {
"batch_size": 1,
"learning_rate_multiplier": "auto",
"n_epochs": 3,
},
)
model.finetuning_job_id = finetuning_request.id
if finetuning_request.status == "failed":
model.error = "Fine tuning failed before starting"
model.status = finetuning_request.status
model_repository.update(model)
else:
model.status = "failed"
model.error = "File processing failed"
model_repository.update(model)

@override
def retrieve_finetuning_job(self) -> Finetuning:
model_repository = FinetuningsRepository(self.storage)
model = model_repository.find_by_id(self.fine_tuning_model.id)
finetuning_request = self.client.fine_tuning.jobs.retrieve(
fine_tuning_job_id=model.finetuning_job_id
)
model.finetuning_job_id = finetuning_request["id"]
if finetuning_request["status"] == "failed":
model.error = "Fine tuning failed before starting"
model.status = finetuning_request["status"]
if finetuning_request.status == "failed":
model.error = finetuning_request.error.message
model.status = finetuning_request.status
if finetuning_request.fine_tuned_model:
model.model_id = finetuning_request.fine_tuned_model
model_repository.update(model)

@classmethod
def retrieve_finetuning_job(cls, fine_tuning_request: Finetuning, storage: Any):
db_connection_repository = DatabaseConnectionRepository(storage)
db_connection = db_connection_repository.find_by_id(
fine_tuning_request.db_connection_id
return model

@override
def cancel_finetuning_job(self) -> Finetuning:
model_repository = FinetuningsRepository(self.storage)
model = model_repository.find_by_id(self.fine_tuning_model.id)
finetuning_request = self.client.fine_tuning.jobs.cancel(
fine_tuning_job_id=model.finetuning_job_id
)
openai.api_key = db_connection.decrypt_api_key()
model_repository = FinetuningsRepository(storage)
model = model_repository.find_by_id(fine_tuning_request.id)
finetuning_request = openai.FineTune.retrieve(id=model.finetuning_job_id)
if finetuning_request["status"] == "failed":
model.error = "Fine tuning failed during processing by OpenAI"
model.status = finetuning_request["status"]
model_repository.update(model)

@classmethod
def cancel_finetuning_job(cls, fine_tuning_request: Finetuning, storage: Any):
db_connection_repository = DatabaseConnectionRepository(storage)
db_connection = db_connection_repository.find_by_id(
fine_tuning_request.db_connection_id
)
openai.api_key = db_connection.decrypt_api_key()
model_repository = FinetuningsRepository(storage)
model = model_repository.find_by_id(fine_tuning_request.id)
finetuning_request = openai.FineTune.cancel(id=model.finetuning_job_id)
model.status = finetuning_request["status"]
model.status = finetuning_request.status
model.error = "Fine tuning cancelled by the user"
model_repository.update(model)



return model
1 change: 1 addition & 0 deletions dataherald/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ class FineTuningStatus(Enum):
SUCCEEDED = "succeeded"
FAILED = "failed"
CANCELLED = "cancelled"
VALIDATING_FILES = "validating_files"


class BaseLLM(BaseModel):
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
dnspython==2.3.0
fastapi==0.98.0
httpx==0.24.1
langchain==0.0.312
langchain==0.0.335
load-dotenv==0.1.0
mypy-extensions==1.0.0
openai==0.27.8
openai==1.3.6
openapi-schema-pydantic==1.2.4
overrides==7.3.1
packaging==23.1
Expand Down

0 comments on commit d683a32

Please sign in to comment.