-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathencoder_init.py
29 lines (24 loc) · 1.07 KB
/
encoder_init.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import sys
import torch
from autoencoder.encoder import VariationalEncoder
class EncodeState:
    """Frozen ``VariationalEncoder`` wrapper that turns observations into flat feature tensors.

    On construction the encoder is built with ``latent_dim``, moved to CUDA when
    available (otherwise CPU), loaded via its ``load()`` method, put in eval
    mode, and every parameter is frozen (``requires_grad = False``).
    """

    def __init__(self, latent_dim):
        """Build, load and freeze the encoder.

        Args:
            latent_dim: Forwarded to the ``VariationalEncoder`` constructor and
                stored on ``self.latent_dim``.

        Raises:
            SystemExit: With exit status 1 if building, loading or freezing the
                encoder raises any ``Exception``.
        """
        self.latent_dim = latent_dim
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        try:
            self.conv_encoder = VariationalEncoder(self.latent_dim).to(self.device)
            self.conv_encoder.load()
            self.conv_encoder.eval()
            for param in self.conv_encoder.parameters():
                param.requires_grad = False
        except Exception as exc:
            # Catch Exception rather than using a bare except so Ctrl-C and
            # SystemExit still propagate. Report the actual cause on stderr and
            # exit non-zero so shells/launchers see this as a failure.
            print(f"Encoder could not be initialized: {exc!r}", file=sys.stderr)
            sys.exit(1)

    def process(self, observation):
        """Encode one observation into a single 1-D feature tensor.

        Args:
            observation: Indexable pair. ``observation[0]`` is the image and
                ``observation[1]`` the navigation features; both must be
                accepted by ``torch.tensor``. The image is assumed to be one
                unbatched (H, W, C) array -- TODO confirm against the env.

        Returns:
            1-D float32 tensor on ``self.device``: the flattened encoder output
            followed by the navigation values.
        """
        # Build directly on the target device instead of CPU-then-copy.
        image = torch.tensor(observation[0], dtype=torch.float, device=self.device)
        # Add a batch dimension of size 1.
        image = image.unsqueeze(0)
        # (N, H, W, C) -> (N, C, W, H). NOTE(review): H and W end up swapped
        # relative to the usual NCHW layout; presumably the encoder weights were
        # trained on this layout, so do not change it without retraining.
        image = image.permute(0, 3, 2, 1)
        # The encoder is frozen, so skip autograd bookkeeping entirely.
        with torch.no_grad():
            latent = self.conv_encoder(image)
        navigation = torch.tensor(observation[1], dtype=torch.float, device=self.device)
        return torch.cat((latent.view(-1), navigation), -1)