ml-text-analysis.py
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from scipy.special import softmax

# Model architecture must be the same as during training
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=5)

# Load the fine-tuned weights from the saved state dictionary
# (map_location keeps the load working on a CPU-only machine even if the
# checkpoint was saved from a GPU)
model.load_state_dict(torch.load('585_bert_model.pth', map_location=torch.device('cpu')))

# Put the model into evaluation mode (disables dropout)
model.eval()

# Load the tokenizer saved alongside the model
tokenizer = BertTokenizer.from_pretrained('./tokenizer')
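# Note (illustrative assumption, not part of this script): the checkpoint file and
# tokenizer directory loaded above are assumed to have been produced at training
# time with something along these lines:
#     torch.save(model.state_dict(), '585_bert_model.pth')
#     tokenizer.save_pretrained('./tokenizer')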
def analyze_sentiment(text):
    # Tokenize and encode the input text (BERT accepts at most 512 tokens)
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)

    # Run the text through the model to get sentiment prediction logits
    with torch.no_grad():
        outputs = model(**inputs)

    # Convert the logits to probabilities using softmax
    logits = outputs.logits
    probs = softmax(logits.numpy()[0])

    # Get the index of the label with the highest probability
    sentiment_class = probs.argmax()

    # Map the model output index to a sentiment label
    labels = ["very negative", "negative", "neutral", "positive", "very positive"]
    return labels[sentiment_class]
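
# Example usage (illustrative; the label returned depends on the fine-tuned weights):
#     analyze_sentiment("I really enjoyed this movie!")  # e.g. "positive" or "very positive"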
def main():
    while True:
        try:
            # Take input from the user
            input_text = input("Enter a text (or type 'exit' to quit): ")
            if input_text.lower() == 'exit':
                break
            # Analyze the sentiment of the input text
            sentiment = analyze_sentiment(input_text)
            print(f"Sentiment: {sentiment}")
        except KeyboardInterrupt:
            print("\nExiting.")
            break


if __name__ == "__main__":
    main()