Skip to content

Commit

Permalink
Merge pull request #24 from ChatWithPDF/up-lift
Browse files Browse the repository at this point in the history
feat: update model for up-lift
  • Loading branch information
techsavvyash authored Jan 22, 2024
2 parents 0a7af27 + db59160 commit 0e9a8e5
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions src/text_classification/flow_classification/local/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
from request import ModelRequest
from torch.nn.functional import softmax

class Model():
def __new__(cls, context):
Expand All @@ -20,5 +21,14 @@ async def inference(self, request: ModelRequest):
inputs = {key: value.to(self.device) for key, value in inputs.items()}
with torch.no_grad():
logits = self.model(**inputs).logits
predicted_class_id = logits.argmax().item()
return self.model.config.id2label[predicted_class_id]

probabilities = softmax(logits, dim=1)

output = []
for idx, score in enumerate(probabilities[0]):
label = self.model.config.id2label[idx]
output.append({"label": label, "score": score.item()})

sorted_output = sorted(output, key=lambda x: x['score'], reverse=True)

return [[item for item in sorted_output]]

0 comments on commit 0e9a8e5

Please sign in to comment.