-- UtilsMultiGPU.lua
-- Forked from SeanNaren/deepspeech.torch
-- CUDA / RNN / cuDNN backends plus the project's transposed
-- DataParallelTable implementation used by makeDataParallel below.
require 'cunn'
require 'rnn'
require 'cudnn'
require 'DataParallelTableTrans'
-- NOTE(review): default_GPU is never read in this file — confirm it is
-- used elsewhere before removing.
local default_GPU = 2
--- Prepare a model for (multi-)GPU training.
-- Optionally converts the model to its cuDNN equivalents, and when more
-- than one GPU is requested wraps it in nn.DataParallelTableTrans spread
-- over GPUs 1..nGPU. With nGPU >= 1 the result is moved to CUDA.
-- @param model    nn module to wrap
-- @param nGPU     number of GPUs to use (0 returns the model untouched)
-- @param is_cudnn when true, run cudnn.convert on the model first
-- @return the (possibly wrapped) model
function makeDataParallel(model, nGPU, is_cudnn)
  if nGPU >= 1 then
    if is_cudnn then
      cudnn.fastest = true
      model = cudnn.convert(model, cudnn)
    end
    if nGPU > 1 then
      -- FIX: 'gpus' and 'dpt' were accidental globals; declare them local
      -- so they do not leak into (and collide in) the global environment.
      local gpus = torch.range(1, nGPU):totable()
      -- Our own impl instead of stock nn.DataParallelTable(1).
      local dpt = nn.DataParallelTableTrans(1, true, true)
      dpt:add(model, gpus)
      -- Each worker thread needs these modules loaded to run replicas.
      dpt:threads(function()
        require 'rnn'
        require 'cudnn'
        require 'nnx'
        require 'warp_ctc'
        require 'BNDecorator'
      end)
      dpt.gradInput = nil
      model = dpt
    end
    model:cuda()
  end
  return model
end
--- Release a tensor's storage in place; no-op for a nil/false argument.
local function clear(tensor)
  if not tensor then
    return
  end
  tensor:set()
end
--- Collapse a DataParallelTable to a single-replica copy for saving.
-- Takes the first replica, moves it to float, and clones it sharing the
-- listed parameter/statistics storages, then rewraps it on device 1.
local function cleanDPT(module)
  local replica = module:get(1):float()
  local shared = replica:clone('weight', 'bias', 'running_mean', 'running_var')
  local wrapper = nn.DataParallelTableTrans(1)
  wrapper:add(shared, 1)
  return wrapper
end
--- Serialize a model to disk with multi-GPU wrappers collapsed.
-- Accepts nn.DataParallelTable(/Trans), nn.Sequential (possibly containing
-- DPT submodules), or nn.gModule. The copy written out is float and shares
-- parameter storages via clone(), keeping the checkpoint small. The live
-- model is not modified.
-- @param filename path passed to torch.save
-- @param orgModel the model to serialize
function saveDataParallel(filename, orgModel)
  local model_type = torch.type(orgModel)
  local model
  if model_type == 'nn.DataParallelTable' or
     model_type == 'nn.DataParallelTableTrans' then
    model = cleanDPT(orgModel)
  elseif model_type == 'nn.Sequential' then
    local temp_model = nn.Sequential()
    -- FIX: iterate orgModel.modules — the original read model.modules
    -- while 'model' was still nil, crashing with
    -- "attempt to index a nil value".
    for i, module in ipairs(orgModel.modules) do
      if torch.type(module) == 'nn.DataParallelTable' or
         torch.type(module) == 'nn.DataParallelTableTrans' then
        temp_model:add(cleanDPT(module))
      else
        temp_model:add(module:float():clone('weight', 'bias', 'running_mean', 'running_var'))
      end
    end
    model = temp_model
  else
    assert(model_type == 'nn.gModule',
      'This saving function only works with Sequential, gModule or DataParallelTable modules.')
    -- FIX: the gModule branch never assigned 'model', so the
    -- clearState()/save calls below indexed nil. Save the graph module.
    model = orgModel
  end
  -- clearState() drops intermediate buffers (activations, gradients),
  -- superseding the old hand-rolled per-module clearing code.
  model:clearState()
  collectgarbage()
  torch.save(filename, model)
end
--- Load a checkpoint saved by saveDataParallel and re-wrap it for nGPU GPUs.
-- DPT wrappers (top-level or inside an nn.Sequential) are unwrapped to
-- their single stored replica and redistributed via makeDataParallel;
-- an nn.gModule is wrapped directly.
-- @param filename checkpoint path for torch.load
-- @param nGPU number of GPUs to distribute across
-- @param is_cudnn whether to convert the model to cuDNN modules
-- @return the model ready for (multi-)GPU use
function loadDataParallel(filename, nGPU, is_cudnn)
  local function is_parallel_table(obj)
    local t = torch.type(obj)
    return t == 'nn.DataParallelTable' or t == 'nn.DataParallelTableTrans'
  end

  local model = torch.load(filename)

  if is_parallel_table(model) then
    return makeDataParallel(model:get(1):float(), nGPU, is_cudnn)
  end

  local model_type = torch.type(model)
  if model_type == 'nn.Sequential' then
    for i, submodule in ipairs(model.modules) do
      if is_parallel_table(submodule) then
        model.modules[i] = makeDataParallel(submodule:get(1):float(), nGPU, is_cudnn)
      end
    end
    return model
  elseif model_type == 'nn.gModule' then
    return makeDataParallel(model, nGPU, is_cudnn)
  end

  error('The loaded model is not a Sequential or DataParallelTable module.')
end