a2_training_and_testing.py
'''
This code is provided solely for the personal and private use of students
taking the CSC401H/2511H course at the University of Toronto. Copying for
purposes other than this use is expressly prohibited. All forms of
distribution of this code, including but not limited to public repositories on
GitHub, GitLab, Bitbucket, or any other online platform, whether as given or
with any changes, are expressly prohibited.
Authors: Sean Robertson, Jingcheng Niu, Zining Zhu, and Mohamed Abdalla
Updated by: Raeid Saqur <[email protected]>
All of the files in this directory and all subdirectories are:
Copyright (c) 2022 University of Toronto
'''
'''Functions related to training and testing.
You don't need anything more than what's been imported here.
'''
from tqdm import tqdm
import typing
import torch
import a2_bleu_score
import a2_dataloader
import a2_encoder_decoder
def train_for_epoch(
model: a2_encoder_decoder.EncoderDecoder,
dataloader: a2_dataloader.HansardDataLoader,
optimizer: torch.optim.Optimizer,
device: torch.device) -> float:
'''Train an EncoderDecoder for an epoch
An epoch is one full loop through the training data. This function:
1. Defines a loss function using :class:`torch.nn.CrossEntropyLoss`,
keeping track of what id the loss considers "padding"
2. For every iteration of the `dataloader` (which yields triples
``F, F_lens, E``)
1. Sends ``F`` to the appropriate device via ``F = F.to(device)``. Same
for ``F_lens`` and ``E``.
2. Zeros out the model's previous gradient with ``optimizer.zero_grad()``
3. Calls ``logits = model(F, F_lens, E)`` to determine next-token
probabilities.
4. Modifies ``E`` for the loss function, getting rid of a token and
replacing excess end-of-sequence tokens with padding using
``model.get_target_padding_mask()`` and ``torch.masked_fill``
5. Flattens out the sequence dimension into the batch dimension of both
``logits`` and ``E``
6. Calls ``loss = loss_fn(logits, E)`` to calculate the batch loss
7. Calls ``loss.backward()`` to backpropagate gradients through
``model``
       8. Calls ``optimizer.step()`` to update model parameters
3. Returns the average loss over sequences
Parameters
----------
model : EncoderDecoder
The model we're training.
dataloader : HansardDataLoader
Serves up batches of data.
device : torch.device
A torch device, like 'cpu' or 'cuda'. Where to perform computations.
optimizer : torch.optim.Optimizer
Implements some algorithm for updating parameters using gradient
calculations.
Returns
-------
avg_loss : float
        The total loss divided by the total number of sequences
'''
# If you want, instead of looping through your dataloader as
# for ... in dataloader: ...
# you can wrap dataloader with "tqdm":
# for ... in tqdm(dataloader): ...
# This will update a progress bar on every iteration that it prints
# to stdout. It's a good gauge for how long the rest of the epoch
# will take. This is entirely optional - we won't grade you differently
# either way.
# If you are running into CUDA memory errors part way through training,
# try "del F, F_lens, E, logits, loss" at the end of each iteration of
# the loop.
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=model.source_pad_id)
    total_loss, batches = 0, 0
    print("Begin Training")
    for F, F_lens, E in dataloader:
        F = F.to(device)
        F_lens = F_lens.to(device)
        E = E.to(device)
        optimizer.zero_grad()  # clear gradients from the previous iteration
        logits = model(F, F_lens, E)  # next-token logits
        # Drop the start-of-sequence row and replace excess end-of-sequence
        # tokens with the same padding id that the loss function ignores
        E = E[1:, :]
        E = E.masked_fill(model.get_target_padding_mask(E), model.source_pad_id)
        # Flatten the sequence dimension into the batch dimension
        logits = logits.flatten(0, 1)
        flat_E = torch.flatten(E, start_dim=0)
        loss = loss_fn(logits, flat_E)  # cross-entropy over the batch
        loss.backward()  # backpropagate gradients through the model
        optimizer.step()  # update model parameters
        total_loss += loss.item()
        batches += 1
    avg_loss = total_loss / batches  # mean of per-batch losses over the epoch
    print("End Training")
    return avg_loss
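

# A minimal usage sketch (not part of the assignment template): assuming a
# model, dataloader, and optimizer have already been constructed elsewhere,
# a training driver might call train_for_epoch once per epoch, e.g.
#
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     model = model.to(device)
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#     for epoch in range(num_epochs):
#         model.train()
#         avg_loss = train_for_epoch(model, train_dataloader, optimizer, device)
#         print(f'epoch {epoch + 1}: average loss {avg_loss:.4f}')
#
# Here num_epochs, train_dataloader, and the learning rate are illustrative
# placeholders, not values prescribed by the assignment.
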
def compute_batch_total_bleu(
E_ref: torch.LongTensor,
E_cand: torch.LongTensor,
target_sos: int,
target_eos: int) -> float:
'''Compute the total BLEU score over elements in a batch
Parameters
----------
E_ref : torch.LongTensor
A batch of reference transcripts of shape ``(T, M)``, including
start-of-sequence tags and right-padded with end-of-sequence tags.
E_cand : torch.LongTensor
A batch of candidate transcripts of shape ``(T', M)``, also including
start-of-sequence and end-of-sequence tags.
target_sos : int
The ID of the start-of-sequence tag in the target vocabulary.
target_eos : int
The ID of the end-of-sequence tag in the target vocabulary.
Returns
-------
total_bleu : float
        The BLEU scores summed across all elements in the batch, computed with
        n-gram precision 4.
'''
    total_bleu, n_gram_precision = 0, 4
    # Work with one sequence per row: (T, M) -> (M, T)
    E_ref, E_cand = E_ref.permute(1, 0).tolist(), E_cand.permute(1, 0).tolist()
    for ref, cand in zip(E_ref, E_cand):
        # Strip start- and end-of-sequence tags, then pass the remaining ids
        # to BLEU_score as strings
        ref = [str(i) for i in ref if i != target_eos and i != target_sos]
        cand = [str(j) for j in cand if j != target_eos and j != target_sos]
        total_bleu += a2_bleu_score.BLEU_score(ref, cand, n_gram_precision)
    return total_bleu
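

# Illustrative only: with made-up token ids (say sos=1, eos=2), the function
# strips the tags, converts the remaining ids to strings, and sums the
# per-sequence BLEU scores, e.g.
#
#     E_ref = torch.tensor([[1, 1], [3, 4], [5, 6], [2, 2]])   # (T, M) = (4, 2)
#     E_cand = torch.tensor([[1, 1], [3, 4], [5, 7], [2, 2]])  # (T', M) = (4, 2)
#     total = compute_batch_total_bleu(E_ref, E_cand, target_sos=1, target_eos=2)
#
# The exact value of total depends on how a2_bleu_score.BLEU_score handles
# sequences shorter than the n-gram order, so no particular number is implied.
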
def compute_average_bleu_over_dataset(
model: a2_encoder_decoder.EncoderDecoder,
dataloader: a2_dataloader.HansardDataLoader,
target_sos: int,
target_eos: int,
device: torch.device) -> float:
'''Determine the average BLEU score across sequences
This function computes the average BLEU score across all sequences in
a single loop through the `dataloader`.
1. For every iteration of the `dataloader` (which yields triples
``F, F_lens, E_ref``):
1. Sends ``F`` to the appropriate device via ``F = F.to(device)``. Same
for ``F_lens``. No need for ``E_cand``, since it will always be
compared on the CPU.
2. Performs a beam search by calling ``b_1 = model(F, F_lens)``
3. Extracts the top path per beam as ``E_cand = b_1[..., 0]``
4. Computes the total BLEU score of the batch using
:func:`compute_batch_total_bleu`
2. Returns the average per-sequence BLEU score
Parameters
----------
model : EncoderDecoder
The model we're testing.
dataloader : HansardDataLoader
Serves up batches of data.
target_sos : int
The ID of the start-of-sequence tag in the target vocabulary.
target_eos : int
The ID of the end-of-sequence tag in the target vocabulary.
Returns
-------
avg_bleu : float
The total BLEU score summed over all sequences divided by the number of
sequences
'''
print("Starting Avg BLEU")
points, total = 0,0
for F, F_lens, E in dataloader:
F , F_lens= F.to(device), F_lens.to(device)
b_1 = model(F, F_lens)
E_cand = b_1[:,:,0]
total += compute_batch_total_bleu(E, E_cand, target_sos, target_eos)
points += F_lens.shape[0]
print("End Avg BLEU")
print("%%%%%")
avg_bleu = total/points
return avg_bleu
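

# A hedged end-to-end sketch (names like test_dataloader are placeholders, not
# part of the assignment code): after training, evaluation would typically be
# run without gradient tracking, e.g.
#
#     model.eval()
#     with torch.no_grad():
#         avg_bleu = compute_average_bleu_over_dataset(
#             model, test_dataloader, target_sos, target_eos, device)
#     print(f'average BLEU: {avg_bleu:.4f}')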