-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
110 lines (86 loc) · 3.01 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import json
import os

import requests
# Prefer the project's shared logger; fall back to a standard-library logger
# named after this module when utils.logger cannot be imported.
try:
    from utils.logger import get_logger
    logger = get_logger()
except ImportError:
    import logging
    logger = logging.getLogger(__name__)

# Module-level availability flag, initialised to False.
API_UNAVAILABLE = False

# NLU training files; relative paths, so they resolve against the current
# working directory.
nlu_files = [
    'data/user-data/user_data.md',
    'data/user-data/user_data_2.md',
    'data/manually-generated/commands.md',
    'data/manually-generated/conversation.md',
    'data/auto-generated/commands.md',
    'data/auto-generated/clarification.md',
    'data/rasa-interactive/nlu.md',
]

# Stories training files, resolved the same way as nlu_files.
stories_files = [
    'data/rasa-interactive/stories.md',
    'data/manually-generated/stories.md',
]
def _read_concatenated(paths):
    """Return the text of every file in *paths*, each followed by a newline.

    Files are read in the given order. Raises OSError if a file cannot be
    opened and UnicodeDecodeError if one is not valid UTF-8.
    """
    parts = []
    for path in paths:
        # Explicit UTF-8 so the result does not depend on the machine's
        # locale default encoding.
        with open(path, 'r', encoding='utf-8') as f:
            parts.append(f.read() + '\n')
    # join() instead of repeated += keeps the concatenation linear.
    return ''.join(parts)


def _fetch_domain_and_config(base_url, config_path, timeout=30):
    """Return ``(domain_text, config_text)`` for a training request.

    Downloads the domain from ``<base_url>/domain`` (printing the HTTP status
    code) and reads the pipeline configuration from *config_path*.

    Raises:
        requests.RequestException: connection failure, timeout, or a non-2xx
            reply (raised as ``requests.HTTPError``).
        OSError / UnicodeDecodeError: the config file cannot be read.
    """
    # The Accept header value is only the media type, not "Accept: ...".
    # A timeout keeps an unresponsive server from hanging the script forever.
    response = requests.get(base_url + "/domain",
                            headers={'Accept': "application/json"},
                            timeout=timeout)
    print(response.status_code)
    # Stop on an error reply instead of sending the server's error body to
    # the training endpoint as if it were the domain.
    response.raise_for_status()
    with open(config_path, 'r', encoding='utf-8') as f:
        config = f.read()
    return response.text, config


def _request_training(base_url, domain, config, nlu_data, stories_data):
    """POST a training job to ``<base_url>/model/train`` and return the response.

    No timeout is set here because training can legitimately take a long time.
    """
    payload = {
        "domain": domain,
        "config": config,
        "nlu": nlu_data,
        "stories": stories_data,
        "force": False,
        "save_to_default_model_directory": False,
    }
    headers = {
        'Content-Type': "application/json",
        "Accept": "*/*",
        # No explicit Accept-Encoding: requests then advertises only the
        # encodings it can decode. Hard-coding "br" could yield an undecoded
        # Brotli body when the brotli package is not installed.
    }
    return requests.post(base_url + "/model/train",
                         headers=headers,
                         data=json.dumps(payload))


def _save_and_load_model(base_url, model_path, model_bytes):
    """Write *model_bytes* to *model_path*, then ask the server to load it.

    Sends ``{"model_file": model_path}`` via ``PUT <base_url>/model`` and
    prints the HTTP status code of that request.
    """
    model_dir = os.path.dirname(model_path)
    if model_dir:
        # Create e.g. "models/" on first run instead of failing the write.
        os.makedirs(model_dir, exist_ok=True)
    with open(model_path, 'wb') as f:
        f.write(model_bytes)
    # NOTE(review): a relative model_path is resolved by the server process,
    # so this assumes the server shares this script's working directory.
    load_response = requests.put(base_url + "/model",
                                 headers={'Accept': "*/*"},
                                 json={"model_file": model_path})
    print(load_response.status_code)


def train_model(*, base_url="http://localhost:5005",
                model_path="models/test.tar.gz",
                config_path='pipelines/supervised_embeddings.yml'):
    """Retrain the model on the server at *base_url* and load the result.

    Steps:
      1. Fetch the domain from ``<base_url>/domain`` and read *config_path*.
         If either fails, log a warning and return without training.
      2. Concatenate the files listed in ``nlu_files`` and ``stories_files``
         and POST them, with the domain and config, to
         ``<base_url>/model/train``.
      3. On HTTP 200, write the returned archive to *model_path* and ask the
         server to load it via ``PUT <base_url>/model``. Otherwise print the
         first 300 characters of the response body.

    HTTP status codes are printed as each request completes. Network, file
    and decoding errors are logged as warnings rather than raised.

    Args:
        base_url: Root URL of the server, without a trailing slash.
        model_path: Where the trained archive is saved; also the path sent
            to the server in the load request.
        config_path: Pipeline configuration file sent with the training
            request.

    Returns:
        None.
    """
    try:
        domain, config = _fetch_domain_and_config(base_url, config_path)
    except (requests.RequestException, OSError, UnicodeDecodeError) as e:
        # Nothing to train against: report and bail out immediately.
        logger.warning(e)
        logger.warning("API unavailable, didn't re-train model")
        return

    try:
        nlu_data = _read_concatenated(nlu_files)
        stories_data = _read_concatenated(stories_files)
        train_response = _request_training(base_url, domain, config,
                                           nlu_data, stories_data)
        print(train_response.status_code)
        if train_response.status_code == 200:
            _save_and_load_model(base_url, model_path, train_response.content)
        else:
            print(train_response.text[:300])
    except (requests.RequestException, OSError, UnicodeDecodeError) as e:
        logger.warning(e)
if __name__ == '__main__':
    # Script entry point: run a single retrain cycle with the default settings.
    train_model()