diff --git a/dataherald/api/__init__.py b/dataherald/api/__init__.py index 84140c46..fdbc8f96 100644 --- a/dataherald/api/__init__.py +++ b/dataherald/api/__init__.py @@ -9,8 +9,11 @@ from dataherald.db_scanner.models.types import TableDescription from dataherald.sql_database.models.types import DatabaseConnection, SSHSettings from dataherald.types import ( + CancelFineTuningRequest, CreateResponseRequest, DatabaseConnectionRequest, + Finetuning, + FineTuningRequest, GoldenRecord, GoldenRecordRequest, Instruction, @@ -167,3 +170,19 @@ def update_instruction( instruction_request: UpdateInstruction, ) -> Instruction: pass + + @abstractmethod + def create_finetuning_job( + self, fine_tuning_request: FineTuningRequest, background_tasks: BackgroundTasks + ) -> Finetuning: + pass + + @abstractmethod + def cancel_finetuning_job( + self, cancel_fine_tuning_request: CancelFineTuningRequest + ) -> Finetuning: + pass + + @abstractmethod + def get_finetuning_job(self, finetuning_job_id: str) -> Finetuning: + pass diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index 26409845..54495afb 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -27,6 +27,7 @@ from dataherald.eval import Evaluator from dataherald.repositories.base import ResponseRepository from dataherald.repositories.database_connections import DatabaseConnectionRepository +from dataherald.repositories.finetunings import FinetuningsRepository from dataherald.repositories.golden_records import GoldenRecordRepository from dataherald.repositories.instructions import InstructionRepository from dataherald.repositories.question import QuestionRepository @@ -39,8 +40,11 @@ from dataherald.sql_generator import SQLGenerator from dataherald.sql_generator.generates_nl_answer import GeneratesNlAnswer from dataherald.types import ( + CancelFineTuningRequest, CreateResponseRequest, DatabaseConnectionRequest, + Finetuning, + FineTuningRequest, GoldenRecord, GoldenRecordRequest, Instruction, @@ -68,6 +72,10 @@ def async_scanning(scanner, database, scanner_request, storage): ) +def async_fine_tuning(): + pass + + def delete_file(file_location: str): os.remove(file_location) @@ -638,3 +646,85 @@ def update_instruction( ) instruction_repository.update(updated_instruction) return json.loads(json_util.dumps(updated_instruction)) + + @override + def create_finetuning_job( + self, fine_tuning_request: FineTuningRequest, background_tasks: BackgroundTasks + ) -> Finetuning: + db_connection_repository = DatabaseConnectionRepository(self.storage) + + db_connection = db_connection_repository.find_by_id( + fine_tuning_request.db_connection_id + ) + if not db_connection: + raise HTTPException(status_code=404, detail="Database connection not found") + + golden_records_repository = GoldenRecordRepository(self.storage) + golden_records = [] + if fine_tuning_request.golden_records: + for golden_record_id in fine_tuning_request.golden_records: + golden_record = golden_records_repository.find_by_id(golden_record_id) + if not golden_record: + raise HTTPException( + status_code=404, detail="Golden record not found" + ) + golden_records.append(golden_record) + else: + golden_records = golden_records_repository.find_by( + {"db_connection_id": ObjectId(fine_tuning_request.db_connection_id)}, + page=0, + limit=0, + ) + if not golden_records: + raise HTTPException(status_code=404, detail="No golden records found") + + model_repository = FinetuningsRepository(self.storage) + model = model_repository.insert( + Finetuning( + db_connection_id=fine_tuning_request.db_connection_id, + base_llm=fine_tuning_request.base_llm, + golden_records=[ + str(golden_record.id) for golden_record in golden_records + ], + ) + ) + + background_tasks.add_task( + async_fine_tuning, fine_tuning_request, self.storage, golden_records + ) + + return model + + @override + def cancel_finetuning_job( + self, cancel_fine_tuning_request: CancelFineTuningRequest + ) -> Finetuning: + model_repository = FinetuningsRepository(self.storage) + model = model_repository.find_by_id(cancel_fine_tuning_request.finetuning_id) + if not model: + raise HTTPException(status_code=404, detail="Model not found") + + if model.status == "succeeded": + raise HTTPException( + status_code=400, detail="Model has already succeeded. Cannot cancel." + ) + if model.status == "failed": + raise HTTPException( + status_code=400, detail="Model has already failed. Cannot cancel." + ) + if model.status == "cancelled": + raise HTTPException( + status_code=400, detail="Model has already been cancelled." + ) + + # Todo: Add code to cancel the fine tuning job + + return model + + @override + def get_finetuning_job(self, finetuning_job_id: str) -> Finetuning: + model_repository = FinetuningsRepository(self.storage) + model = model_repository.find_by_id(finetuning_job_id) + if not model: + raise HTTPException(status_code=404, detail="Model not found") + return model diff --git a/dataherald/finetuning/__init__py b/dataherald/finetuning/__init__py new file mode 100644 index 00000000..4979f15d --- /dev/null +++ b/dataherald/finetuning/__init__py @@ -0,0 +1,2 @@ +class Finetuning: + pass diff --git a/dataherald/finetuning/openai_finetuning.py b/dataherald/finetuning/openai_finetuning.py index 838814bb..5f7f9cb8 100644 --- a/dataherald/finetuning/openai_finetuning.py +++ b/dataherald/finetuning/openai_finetuning.py @@ -1,4 +1,5 @@ import json +import logging import os import time import uuid @@ -17,6 +18,7 @@ FILE_PROCESSING_ATTEMPTS = 20 +logger = logging.getLogger(__name__) class OpenAIFineTuning: finetuning_dataset_path: str @@ -118,12 +120,11 @@ def create_fintuning_dataset(cls, fine_tuning_request: Finetuning, storage: Any) openai.api_key = db_connection.decrypt_api_key() model_repository = FinetuningsRepository(storage) model = model_repository.find_by_id(fine_tuning_request.id) - model.finetuning_file_id = openai.File.create( - file=open(cls.finetuning_dataset_path, purpose="fine-tune") - )["id"] + model.finetuning_file_id = openai.File.create(file=open(cls.finetuning_dataset_path,purpose='fine-tune'))['id'] model_repository.update(model) os.remove(cls.finetuning_dataset_path) + @classmethod def create_fine_tuning_job(cls, fine_tuning_request: Finetuning, storage: Any): db_connection_repository = DatabaseConnectionRepository(storage) @@ -135,10 +136,7 @@ def create_fine_tuning_job(cls, fine_tuning_request: Finetuning, storage: Any): model = model_repository.find_by_id(fine_tuning_request.id) retrieve_file_attempt = 0 while True: - if ( - openai.File.retrieve(id=model.finetuning_file_id)["status"] - == "processed" - ): + if openai.File.retrieve(id=model.finetuning_file_id)["status"] == "processed": break time.sleep(5) retrieve_file_attempt += 1 @@ -150,7 +148,7 @@ def create_fine_tuning_job(cls, fine_tuning_request: Finetuning, storage: Any): finetuning_request = openai.FineTune.create( training_file=model.finetuning_file_id, model=model.base_llm.model_name, - hyperparameters=model.base_llm.model_parameters, + hyperparameters= model.base_llm.model_parameters ) model.finetuning_job_id = finetuning_request["id"] if finetuning_request["status"] == "failed": @@ -186,3 +184,6 @@ def cancel_finetuning_job(cls, fine_tuning_request: Finetuning, storage: Any): model.status = finetuning_request["status"] model.error = "Fine tuning cancelled by the user" model_repository.update(model) + + + diff --git a/dataherald/repositories/finetunings.py b/dataherald/repositories/finetunings.py new file mode 100644 index 00000000..baf3aef9 --- /dev/null +++ b/dataherald/repositories/finetunings.py @@ -0,0 +1,61 @@ +from bson.objectid import ObjectId + +from dataherald.types import Finetuning + +DB_COLLECTION = "finetunings" + + +class FinetuningsRepository: + def __init__(self, storage): + self.storage = storage + + def insert(self, model: Finetuning) -> Finetuning: + model.id = str( + self.storage.insert_one(DB_COLLECTION, model.dict(exclude={"id"})) + ) + return model + + def find_one(self, query: dict) -> Finetuning | None: + row = self.storage.find_one(DB_COLLECTION, query) + if not row: + return None + obj = Finetuning(**row) + obj.id = str(row["_id"]) + return obj + + def update(self, model: Finetuning) -> Finetuning: + self.storage.update_or_create( + DB_COLLECTION, + {"_id": ObjectId(model.id)}, + model.dict(exclude={"id"}), + ) + return model + + def find_by_id(self, id: str) -> Finetuning | None: + row = self.storage.find_one(DB_COLLECTION, {"_id": ObjectId(id)}) + if not row: + return None + obj = Finetuning(**row) + obj.id = str(row["_id"]) + return obj + + def find_by(self, query: dict, page: int = 1, limit: int = 10) -> list[Finetuning]: + rows = self.storage.find(DB_COLLECTION, query, page=page, limit=limit) + result = [] + for row in rows: + obj = Finetuning(**row) + obj.id = str(row["_id"]) + result.append(obj) + return result + + def find_all(self, page: int = 0, limit: int = 0) -> list[Finetuning]: + rows = self.storage.find_all(DB_COLLECTION, page=page, limit=limit) + result = [] + for row in rows: + obj = Finetuning(**row) + obj.id = str(row["_id"]) + result.append(obj) + return result + + def delete_by_id(self, id: str) -> int: + return self.storage.delete_by_id(DB_COLLECTION, id) diff --git a/dataherald/server/fastapi/__init__.py b/dataherald/server/fastapi/__init__.py index 82be2092..fe180f56 100644 --- a/dataherald/server/fastapi/__init__.py +++ b/dataherald/server/fastapi/__init__.py @@ -13,8 +13,11 @@ from dataherald.db_scanner.models.types import TableDescription from dataherald.sql_database.models.types import DatabaseConnection, SSHSettings from dataherald.types import ( + CancelFineTuningRequest, CreateResponseRequest, DatabaseConnectionRequest, + Finetuning, + FineTuningRequest, GoldenRecord, GoldenRecordRequest, Instruction, @@ -215,6 +218,28 @@ def __init__(self, settings: Settings): tags=["Instructions"], ) + self.router.add_api_route( + "/api/v1/finetunings", + self.create_finetuning_job, + methods=["POST"], + status_code=201, + tags=["Finetunings"], + ) + + self.router.add_api_route( + "/api/v1/finetunings/{finetuning_id}", + self.get_finetuning_job, + methods=["GET"], + tags=["Finetunings"], + ) + + self.router.add_api_route( + "/api/v1/finetunings/{finetuning_id}/cancel", + self.cancel_finetuning_job, + methods=["POST"], + tags=["Finetunings"], + ) + self.router.add_api_route( "/api/v1/heartbeat", self.heartbeat, methods=["GET"], tags=["System"] ) @@ -379,3 +404,19 @@ def update_instruction( ) -> Instruction: """Updates an instruction""" return self._api.update_instruction(instruction_id, instruction_request) + + def create_finetuning_job( + self, fine_tuning_request: FineTuningRequest, background_tasks: BackgroundTasks + ) -> Finetuning: + """Creates a fine tuning job""" + return self._api.create_finetuning_job(fine_tuning_request, background_tasks) + + def cancel_finetuning_job( + self, cancel_fine_tuning_request: CancelFineTuningRequest + ) -> Finetuning: + """Cancels a fine tuning job""" + return self._api.cancel_finetuning_job(cancel_fine_tuning_request) + + def get_finetuning_job(self, finetuning_job_id: str) -> Finetuning: + """Gets fine tuning jobs""" + return self._api.get_finetuning_job(finetuning_job_id) diff --git a/dataherald/types.py b/dataherald/types.py index 6cddcf9a..3d69a340 100644 --- a/dataherald/types.py +++ b/dataherald/types.py @@ -139,3 +139,43 @@ class ColumnDescriptionRequest(BaseModel): class TableDescriptionRequest(BaseModel): description: str | None columns: list[ColumnDescriptionRequest] | None + + +class FineTuningStatus(Enum): + QUEUED = "queued" + RUNNING = "running" + SUCCEEDED = "succeeded" + FAILED = "failed" + CANCELLED = "cancelled" + + +class BaseLLM(BaseModel): + model_provider: str | None = None + model_name: str | None = None + model_parameters: dict[str, str] | None = None + + +class Finetuning(BaseModel): + id: str | None = None + db_connection_id: str | None = None + status: str = "queued" + error: str | None = None + base_llm: BaseLLM | None = None + finetuning_file_id: str | None = None + finetuning_job_id: str | None = None + model_id: str | None = None + created_at: datetime = Field(default_factory=datetime.now) + golden_records: list[str] | None = None + metadata: dict[str, str] | None = None + + +class FineTuningRequest(BaseModel): + db_connection_id: str + base_llm: BaseLLM + golden_records: list[str] | None = None + metadata: dict[str, str] | None = None + + +class CancelFineTuningRequest(BaseModel): + finetuning_id: str + metadata: dict[str, str] | None = None