test_gpt2.py
from transformers import GPT2Tokenizer, GPT2Model
from transformers import AutoTokenizer, AutoModel
import torch
import pdb
# Load a pre-trained model and tokenizer (the GPT-2 pair is kept commented
# for reference; the script actually uses the MiniLM sentence encoder)
# tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# model = GPT2Model.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
# Sentences to encode (earlier toy examples kept commented for reference)
# sentence_1 = "This is a sample sentence to encode."
# sentence_2 = "Please encode the example sentence"
# sentence_3 = "I have had an apple today."
sentence_1 = "Measurements of electric power consumption in one household with a one-minute sampling rate over a period of almost 4 years. The sampling period is minutely. The min value is 0.09, the max value is 100."
# sentence_2 = "The Electricity Transformer Temperature (ETT) is a crucial indicator in the electric power long-term deployment. This dataset consists of 2 years data from two separated counties in China. To explore the granularity on the Long sequence time-series forecasting (LSTF) problem, different subsets are created, {ETTh1, ETTh2} for 1-hour-level and ETTm1 for 15-minutes-level. Each data point consists of the target value ”oil temperature” and 6 power load features. The train/val/test is 12/4/4 months."
sentence_2 = "Measurements of electric power consumption in one household with a one-minute sampling rate over a period of almost 4 years. The sampling period is hourly. The sampling period is hourly. The min value is 0.09, the max value is 100."
sentence_3 = "I have had an apple today."
def get_gpt_embedding(sentence):
    # Tokenize a single sentence (no padding needed for one sequence)
    inputs = tokenizer(sentence, return_tensors="pt")
    # Get hidden states from the model
    with torch.no_grad():
        outputs = model(**inputs)
    # The hidden states are in the `last_hidden_state` tensor
    hidden_states = outputs.last_hidden_state
    # Typically, you might average the hidden states or take the CLS token's
    # hidden state for a sentence embedding; with a single unpadded sequence,
    # a plain mean over the token dimension has no pad tokens to skew it
    sentence_embedding = hidden_states.mean(dim=1)
    # L2-normalize so dot products between embeddings equal cosine similarity
    return sentence_embedding / sentence_embedding.norm()

def dot_prod(emb_1, emb_2):
    # Dot product of two already-normalized embeddings = cosine similarity
    return (emb_1 * emb_2).sum()
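
# --- Hedged sketch (added; not the author's code): the model card for
# sentence-transformers/all-MiniLM-L6-v2 recommends attention-mask-weighted
# mean pooling, which matters once sentences are batched with padding.
# `get_masked_embedding` is an illustrative helper name assumed here.
def get_masked_embedding(sentences):
    # Tokenize a list of sentences, padding so they form one batch
    batch = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        out = model(**batch)
    hidden = out.last_hidden_state                        # (batch, seq, dim)
    mask = batch["attention_mask"].unsqueeze(-1).float()  # (batch, seq, 1)
    # Average only over real tokens; pad positions contribute zero
    emb = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
    # L2-normalize per row so dot products equal cosine similarity
    return emb / emb.norm(dim=1, keepdim=True)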
emb_1 = get_gpt_embedding(sentence_1)
emb_2 = get_gpt_embedding(sentence_2)
emb_3 = get_gpt_embedding(sentence_3)
# The near-duplicate descriptions (1 vs. 2) should score higher than the
# unrelated pair (1 vs. 3)
print(
    dot_prod(emb_1, emb_2), dot_prod(emb_1, emb_3)
)
# Drop into the debugger to inspect the embeddings interactively
pdb.set_trace()
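
# --- Hedged usage sketch (added; runs after continuing from the pdb
# breakpoint above): re-computes the same two similarities with the
# mask-aware batched helper for comparison.
embs = get_masked_embedding([sentence_1, sentence_2, sentence_3])
print(dot_prod(embs[0], embs[1]), dot_prod(embs[0], embs[2]))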