Skip to content

Commit

Permalink
Update model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Gautam-Rajeev authored Dec 12, 2023
1 parent a28168b commit 1c8c6eb
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions src/embeddings/openai/remote/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from request import ModelRequest
from openai import OpenAI
import os
import pandas as pd

class Model:
embedding_model = "text-embedding-ada-002"
Expand All @@ -9,20 +10,34 @@ def __new__(cls, context):
cls.context = context
if not hasattr(cls, 'instance'):
cls.client = OpenAI(
api_key=os.getenv("OPENAI_API_KEY"),
api_key=os.getenv("OPENAI_API_KEY")
)
cls.instance = super(Model, cls).__new__(cls)
return cls.instance

async def inference(self, request: ModelRequest):
# Modify this function according to model requirements such that inputs and output remains the same
query = request.query
if request.df is not None:
data = request.df
data = data.loc[~pd.isnull(data['content']),:]
data['content'] = data['content'].astype(str)

if(query != None):
if data.empty or data['content'].isnull().any():
return 'There are nonzero null rows'

data['embeddings'] = data['content'].apply(
lambda x: self.client.embeddings.create(
input=x,
model=self.embedding_model,
).data[0].embedding
)
csv_string = data.to_csv(index=False)
return str(csv_string)

if request.query is not None:
embedding = self.client.embeddings.create(
input=query,
input=request.query,
model=self.embedding_model,
).data[0].embedding
return [embedding]

return "Invalid input"

0 comments on commit 1c8c6eb

Please sign in to comment.