diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index 54495afb..dc1c50a9 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -25,6 +25,7 @@ TableDescriptionRepository, ) from dataherald.eval import Evaluator +from dataherald.finetuning.openai_finetuning import OpenAIFineTuning from dataherald.repositories.base import ResponseRepository from dataherald.repositories.database_connections import DatabaseConnectionRepository from dataherald.repositories.finetunings import FinetuningsRepository @@ -56,6 +57,7 @@ TableDescriptionRequest, UpdateInstruction, ) +from dataherald.utils.models_context_window import OPENAI_CONTEXT_WIDNOW_SIZES from dataherald.utils.s3 import S3 logger = logging.getLogger(__name__) @@ -72,8 +74,10 @@ def async_scanning(scanner, database, scanner_request, storage): ) -def async_fine_tuning(): - pass +def async_fine_tuning(storage, model): + openai_fine_tuning = OpenAIFineTuning(storage, model) + openai_fine_tuning.create_fintuning_dataset() + openai_fine_tuning.create_fine_tuning_job() def delete_file(file_location: str): @@ -678,6 +682,12 @@ def create_finetuning_job( if not golden_records: raise HTTPException(status_code=404, detail="No golden records found") + if fine_tuning_request.base_llm.model_name not in OPENAI_CONTEXT_WIDNOW_SIZES: + raise HTTPException( + status_code=400, + detail=f"Model {fine_tuning_request.base_llm.model_name} not supported", + ) + model_repository = FinetuningsRepository(self.storage) model = model_repository.insert( Finetuning( @@ -689,9 +699,7 @@ def create_finetuning_job( ) ) - background_tasks.add_task( - async_fine_tuning, fine_tuning_request, self.storage, golden_records - ) + background_tasks.add_task(async_fine_tuning, self.storage, model) return model @@ -717,9 +725,9 @@ def cancel_finetuning_job( status_code=400, detail="Model has already been cancelled." ) - # Todo: Add code to cancel the fine tuning job + openai_fine_tuning = OpenAIFineTuning(self.storage, model) - return model + return openai_fine_tuning.cancel_finetuning_job() @override def get_finetuning_job(self, finetuning_job_id: str) -> Finetuning: @@ -727,4 +735,5 @@ def get_finetuning_job(self, finetuning_job_id: str) -> Finetuning: model = model_repository.find_by_id(finetuning_job_id) if not model: raise HTTPException(status_code=404, detail="Model not found") - return model + openai_fine_tuning = OpenAIFineTuning(self.storage, model) + return openai_fine_tuning.retrieve_finetuning_job() diff --git a/dataherald/finetuning/__init__.py b/dataherald/finetuning/__init__.py new file mode 100644 index 00000000..4521214b --- /dev/null +++ b/dataherald/finetuning/__init__.py @@ -0,0 +1,29 @@ +from abc import ABC, abstractmethod + +from dataherald.config import Component +from dataherald.types import Finetuning + + +class FinetuningModel(Component, ABC): + def __init__(self, storage): + self.storage = storage + + @abstractmethod + def count_tokens(self, messages: dict) -> int: + pass + + @abstractmethod + def create_fintuning_dataset(self): + pass + + @abstractmethod + def create_fine_tuning_job(self): + pass + + @abstractmethod + def retrieve_finetuning_job(self) -> Finetuning: + pass + + @abstractmethod + def cancel_finetuning_job(self) -> Finetuning: + pass diff --git a/dataherald/finetuning/__init__py b/dataherald/finetuning/__init__py deleted file mode 100644 index 4979f15d..00000000 --- a/dataherald/finetuning/__init__py +++ /dev/null @@ -1,2 +0,0 @@ -class Finetuning: - pass diff --git a/dataherald/finetuning/openai_finetuning.py b/dataherald/finetuning/openai_finetuning.py index 5f7f9cb8..e4028ab2 100644 --- a/dataherald/finetuning/openai_finetuning.py +++ b/dataherald/finetuning/openai_finetuning.py @@ -5,25 +5,47 @@ import uuid from typing import Any, List -import openai +import tiktoken from bson.objectid import ObjectId +from openai import OpenAI +from overrides import override +from tiktoken import Encoding from dataherald.db_scanner.models.types import TableDescription, TableDescriptionStatus from dataherald.db_scanner.repository.base import TableDescriptionRepository +from dataherald.finetuning import FinetuningModel from dataherald.repositories.database_connections import DatabaseConnectionRepository from dataherald.repositories.finetunings import FinetuningsRepository from dataherald.repositories.golden_records import GoldenRecordRepository from dataherald.types import Finetuning from dataherald.utils.agent_prompts import FINETUNING_SYSTEM_INFORMATION +from dataherald.utils.models_context_window import OPENAI_CONTEXT_WIDNOW_SIZES FILE_PROCESSING_ATTEMPTS = 20 logger = logging.getLogger(__name__) -class OpenAIFineTuning: - finetuning_dataset_path: str - def format_columns(self, table: TableDescription, top_k: int = 100) -> str: +class OpenAIFineTuning(FinetuningModel): + encoding: Encoding + fine_tuning_model: Finetuning + storage: Any + client: OpenAI + + def __init__(self, storage: Any, fine_tuning_model: Finetuning): + self.storage = storage + self.fine_tuning_model = fine_tuning_model + db_connection_repository = DatabaseConnectionRepository(storage) + db_connection = db_connection_repository.find_by_id( + fine_tuning_model.db_connection_id + ) + self.encoding = tiktoken.encoding_for_model( + fine_tuning_model.base_llm.model_name + ) + self.client = OpenAI(api_key=db_connection.decrypt_api_key()) + + @classmethod + def format_columns(cls, table: TableDescription, top_k: int = 100) -> str: """ format_columns formats the columns. @@ -65,14 +87,14 @@ def format_columns(self, table: TableDescription, top_k: int = 100) -> str: ) return columns_information - @staticmethod - def format_dataset(self, db_scan: List[TableDescription]) -> str: + @classmethod + def format_dataset(cls, db_scan: List[TableDescription]) -> str: schema_of_database = "" for table in db_scan: tables_schema = table.table_schema schema_of_database += f"{tables_schema}\n" schema_of_database += "# Categorical Columns:\n" - columns_information = self.format_columns(table) + columns_information = cls.format_columns(table) schema_of_database += columns_information sample_rows = table.examples schema_of_database += "# Sample rows:\n" @@ -83,27 +105,36 @@ def format_dataset(self, db_scan: List[TableDescription]) -> str: schema_of_database += "\n\n" return schema_of_database - @classmethod - def create_fintuning_dataset(cls, fine_tuning_request: Finetuning, storage: Any): - db_connection_id = fine_tuning_request.db_connection_id - repository = TableDescriptionRepository(storage) + @override + def count_tokens(self, messages: dict) -> int: + prompt = "" + for message in messages["messages"]: + prompt += message["content"] + return len(self.encoding.encode(prompt)) + + @override + def create_fintuning_dataset(self): + db_connection_id = self.fine_tuning_model.db_connection_id + repository = TableDescriptionRepository(self.storage) db_scan = repository.get_all_tables_by_db( { "db_connection_id": ObjectId(db_connection_id), "status": TableDescriptionStatus.SYNCHRONIZED.value, } ) - golden_records_repository = GoldenRecordRepository(storage) - database_schema = cls.format_dataset(db_scan) - cls.finetuning_dataset_path = f"tmp/{str(uuid.uuid4())}.jsonl" - for golden_record_id in fine_tuning_request.golden_records: + golden_records_repository = GoldenRecordRepository(self.storage) + database_schema = self.format_dataset(db_scan) + finetuning_dataset_path = f"tmp/{str(uuid.uuid4())}.jsonl" + model_repository = FinetuningsRepository(self.storage) + model = model_repository.find_by_id(self.fine_tuning_model.id) + for golden_record_id in self.fine_tuning_model.golden_records: golden_record = golden_records_repository.find_by_id(golden_record_id) question = golden_record.question query = golden_record.sql_query system_prompt = FINETUNING_SYSTEM_INFORMATION + database_schema user_prompt = "User Question: " + question + "\n SQL: " assistant_prompt = query + "\n" - with open(cls.finetuning_dataset_path, "a") as outfile: + with open(finetuning_dataset_path, "a") as outfile: messages = { "messages": [ {"role": "system", "content": f"{system_prompt}"}, @@ -111,79 +142,86 @@ def create_fintuning_dataset(cls, fine_tuning_request: Finetuning, storage: Any) {"role": "assistant", "content": f"{assistant_prompt}"}, ] } + number_of_tokens = self.count_tokens(messages) + if ( + number_of_tokens + > OPENAI_CONTEXT_WIDNOW_SIZES[ + self.fine_tuning_model.base_llm.model_name + ] + ): + model.status = "failed" + model.error = "The number of tokens in the prompt is too large" + model_repository.update(model) + os.remove(finetuning_dataset_path) + return json.dump(messages, outfile) outfile.write("\n") - db_connection_repository = DatabaseConnectionRepository(storage) - db_connection = db_connection_repository.find_by_id( - fine_tuning_request.db_connection_id - ) - openai.api_key = db_connection.decrypt_api_key() - model_repository = FinetuningsRepository(storage) - model = model_repository.find_by_id(fine_tuning_request.id) - model.finetuning_file_id = openai.File.create(file=open(cls.finetuning_dataset_path,purpose='fine-tune'))['id'] + model.finetuning_file_id = self.client.files.create( + file=open(finetuning_dataset_path, "rb"), purpose="fine-tune" + ).id model_repository.update(model) - os.remove(cls.finetuning_dataset_path) - + os.remove(finetuning_dataset_path) - @classmethod - def create_fine_tuning_job(cls, fine_tuning_request: Finetuning, storage: Any): - db_connection_repository = DatabaseConnectionRepository(storage) - db_connection = db_connection_repository.find_by_id( - fine_tuning_request.db_connection_id - ) - openai.api_key = db_connection.decrypt_api_key() - model_repository = FinetuningsRepository(storage) - model = model_repository.find_by_id(fine_tuning_request.id) + def check_file_status(self, file_id: str) -> bool: retrieve_file_attempt = 0 while True: - if openai.File.retrieve(id=model.finetuning_file_id)["status"] == "processed": - break + file_info = self.client.files.retrieve(file_id=file_id) + if file_info.status == "processed": + return True time.sleep(5) retrieve_file_attempt += 1 if retrieve_file_attempt == FILE_PROCESSING_ATTEMPTS: - model.status = "failed" - model.error = "File processing failed" - model_repository.update(model) - return - finetuning_request = openai.FineTune.create( - training_file=model.finetuning_file_id, - model=model.base_llm.model_name, - hyperparameters= model.base_llm.model_parameters + return False + + @override + def create_fine_tuning_job(self): + model_repository = FinetuningsRepository(self.storage) + model = model_repository.find_by_id(self.fine_tuning_model.id) + if self.check_file_status(model.finetuning_file_id): + finetuning_request = self.client.fine_tuning.jobs.create( + training_file=model.finetuning_file_id, + model=model.base_llm.model_name, + hyperparameters=model.base_llm.model_parameters + if model.base_llm.model_parameters + else { + "batch_size": 1, + "learning_rate_multiplier": "auto", + "n_epochs": 3, + }, + ) + model.finetuning_job_id = finetuning_request.id + if finetuning_request.status == "failed": + model.error = "Fine tuning failed before starting" + model.status = finetuning_request.status + model_repository.update(model) + else: + model.status = "failed" + model.error = "File processing failed" + model_repository.update(model) + + @override + def retrieve_finetuning_job(self) -> Finetuning: + model_repository = FinetuningsRepository(self.storage) + model = model_repository.find_by_id(self.fine_tuning_model.id) + finetuning_request = self.client.fine_tuning.jobs.retrieve( + fine_tuning_job_id=model.finetuning_job_id ) - model.finetuning_job_id = finetuning_request["id"] - if finetuning_request["status"] == "failed": - model.error = "Fine tuning failed before starting" - model.status = finetuning_request["status"] + if finetuning_request.status == "failed": + model.error = finetuning_request.error.message + model.status = finetuning_request.status + if finetuning_request.fine_tuned_model: + model.model_id = finetuning_request.fine_tuned_model model_repository.update(model) - - @classmethod - def retrieve_finetuning_job(cls, fine_tuning_request: Finetuning, storage: Any): - db_connection_repository = DatabaseConnectionRepository(storage) - db_connection = db_connection_repository.find_by_id( - fine_tuning_request.db_connection_id + return model + + @override + def cancel_finetuning_job(self) -> Finetuning: + model_repository = FinetuningsRepository(self.storage) + model = model_repository.find_by_id(self.fine_tuning_model.id) + finetuning_request = self.client.fine_tuning.jobs.cancel( + fine_tuning_job_id=model.finetuning_job_id ) - openai.api_key = db_connection.decrypt_api_key() - model_repository = FinetuningsRepository(storage) - model = model_repository.find_by_id(fine_tuning_request.id) - finetuning_request = openai.FineTune.retrieve(id=model.finetuning_job_id) - if finetuning_request["status"] == "failed": - model.error = "Fine tuning failed during processing by OpenAI" - model.status = finetuning_request["status"] - model_repository.update(model) - - @classmethod - def cancel_finetuning_job(cls, fine_tuning_request: Finetuning, storage: Any): - db_connection_repository = DatabaseConnectionRepository(storage) - db_connection = db_connection_repository.find_by_id( - fine_tuning_request.db_connection_id - ) - openai.api_key = db_connection.decrypt_api_key() - model_repository = FinetuningsRepository(storage) - model = model_repository.find_by_id(fine_tuning_request.id) - finetuning_request = openai.FineTune.cancel(id=model.finetuning_job_id) - model.status = finetuning_request["status"] + model.status = finetuning_request.status model.error = "Fine tuning cancelled by the user" model_repository.update(model) - - - + return model diff --git a/dataherald/types.py b/dataherald/types.py index 3d69a340..99efe9e4 100644 --- a/dataherald/types.py +++ b/dataherald/types.py @@ -147,6 +147,7 @@ class FineTuningStatus(Enum): SUCCEEDED = "succeeded" FAILED = "failed" CANCELLED = "cancelled" + VALIDATING_FILES = "validating_files" class BaseLLM(BaseModel): diff --git a/requirements.txt b/requirements.txt index d102da95..4e4d9173 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ dnspython==2.3.0 fastapi==0.98.0 httpx==0.24.1 -langchain==0.0.312 +langchain==0.0.335 load-dotenv==0.1.0 mypy-extensions==1.0.0 -openai==0.27.8 +openai==1.3.6 openapi-schema-pydantic==1.2.4 overrides==7.3.1 packaging==23.1