''' Agents: stop/random/shortest/seq2seq '''
import json
import sys
import random

import numpy as np
import torch
import torch.nn as nn
import torch.distributions as D
import torch.nn.functional as F
from torch.autograd import Variable

from utils import padding_idx
class BaseAgent(object):
''' Base class for an R2R agent to generate and save trajectories. '''
def __init__(self, env, results_path):
self.env = env
self.results_path = results_path
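        # Fixed seed keeps RandomAgent rollouts reproducible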
random.seed(1)
self.results = {}
self.losses = [] # For learning agents
def write_results(self):
output = [{'instr_id': k, 'trajectory': v}
for k, v in self.results.items()]
with open(self.results_path, 'w') as f:
json.dump(output, f)
def rollout(self):
''' Return a list of dicts containing instr_id:'xx', path:[(viewpointId, heading_rad, elevation_rad)] '''
raise NotImplementedError
@staticmethod
def get_agent(name):
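        # Resolve the class by name, e.g. get_agent('Stop') -> StopAgent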
return globals()[name+"Agent"]
def test(self):
self.env.reset_epoch()
self.losses = []
self.results = {}
# We rely on env showing the entire batch before repeating anything
#print('Testing %s' % self.__class__.__name__)
looped = False
while True:
for traj in self.rollout():
if traj['instr_id'] in self.results:
looped = True
else:
self.results[traj['instr_id']] = traj['path']
if looped:
break
class StopAgent(BaseAgent):
''' An agent that doesn't move! '''
def rollout(self):
traj = [{
'instr_id': ob['instr_id'],
'path': [(ob['viewpoint'], ob['heading'], ob['elevation'])]
} for ob in self.env.reset()]
return traj
class RandomAgent(BaseAgent):
    ''' An agent that picks a random direction, tries to go straight for
        five viewpoint steps, and then stops. '''
def rollout(self):
obs = self.env.reset()
traj = [{
'instr_id': ob['instr_id'],
'path': [(ob['viewpoint'], ob['heading'], ob['elevation'])]
} for ob in obs]
        # Draw each start-up counter independently so any batch size works
        # (random.sample without replacement fails for batches larger than 12)
        self.steps = [random.randint(-11, 0) for _ in obs]
ended = [False] * len(obs)
for t in range(30):
actions = []
for i, ob in enumerate(obs):
if self.steps[i] >= 5:
actions.append((0, 0, 0)) # do nothing, i.e. end
ended[i] = True
elif self.steps[i] < 0:
actions.append((0, 1, 0)) # turn right (direction choosing)
self.steps[i] += 1
elif len(ob['navigableLocations']) > 1:
actions.append((1, 0, 0)) # go forward
self.steps[i] += 1
else:
actions.append((0, 1, 0)) # turn right until we can go forward
obs = self.env.step(actions)
for i, ob in enumerate(obs):
if not ended[i]:
traj[i]['path'].append(
(ob['viewpoint'], ob['heading'], ob['elevation']))
return traj
class ShortestAgent(BaseAgent):
''' An agent that always takes the shortest path to goal. '''
def rollout(self):
obs = self.env.reset()
traj = [{
'instr_id': ob['instr_id'],
'path': [(ob['viewpoint'], ob['heading'], ob['elevation'])]
} for ob in obs]
ended = np.array([False] * len(obs))
while True:
actions = [ob['teacher'] for ob in obs]
obs = self.env.step(actions)
for i, a in enumerate(actions):
if a == (0, 0, 0):
ended[i] = True
for i, ob in enumerate(obs):
if not ended[i]:
traj[i]['path'].append(
(ob['viewpoint'], ob['heading'], ob['elevation']))
if ended.all():
break
return traj
class Seq2SeqAgent(BaseAgent):
''' An agent based on an LSTM seq2seq model with attention. '''
# For now, the agent can't pick which forward move to make - just the one in the middle
model_actions = ['left', 'right', 'up', 'down',
'forward', '<end>', '<start>', '<ignore>']
env_actions = [
(0, -1, 0), # left
(0, 1, 0), # right
(0, 0, 1), # up
(0, 0, -1), # down
(1, 0, 0), # forward
(0, 0, 0), # <end>
(0, 0, 0), # <start>
(0, 0, 0) # <ignore>
]
feedback_options = ['teacher', 'argmax', 'sample']
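    # 'teacher' follows the ground-truth action (teacher forcing), 'argmax'
    # follows the model's greedy prediction (student forcing), and 'sample'
    # draws an action from the model's output distribution.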
def __init__(self, env, results_path, encoder, decoder, episode_len=20):
super(Seq2SeqAgent, self).__init__(env, results_path)
self.encoder = encoder
self.decoder = decoder
self.episode_len = episode_len
self.losses = []
self.criterion = nn.CrossEntropyLoss(
ignore_index=self.model_actions.index('<ignore>'))
@staticmethod
def n_inputs():
return len(Seq2SeqAgent.model_actions)
@staticmethod
def n_outputs():
# Model doesn't output start or ignore
return len(Seq2SeqAgent.model_actions)-2
def _sort_batch(self, obs):
''' Extract instructions from a list of observations and sort by descending
sequence length (to enable PyTorch packing). '''
seq_tensor = np.array([ob['instr_encoding'] for ob in obs])
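        # The first padding position gives the sequence length; np.argmax
        # returns 0 for rows that contain no padding at all (handled below)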
seq_lengths = np.argmax(seq_tensor == padding_idx, axis=1)
seq_lengths[seq_lengths == 0] = seq_tensor.shape[1] # Full length
seq_tensor = torch.from_numpy(seq_tensor)
seq_lengths = torch.from_numpy(seq_lengths)
# Sort sequences by lengths
seq_lengths, perm_idx = seq_lengths.sort(0, True)
sorted_tensor = seq_tensor[perm_idx]
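        # True at padding positions, truncated to the longest sequence in the batch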
mask = (sorted_tensor == padding_idx)[:, :seq_lengths[0]]
return Variable(sorted_tensor, requires_grad=False).long().cuda(), \
mask.byte().cuda(), \
list(seq_lengths), list(perm_idx)
def _feature_variable(self, obs):
''' Extract precomputed features into variable. '''
feature_size = obs[0]['feature'].shape[0]
features = np.empty((len(obs), feature_size), dtype=np.float32)
for i, ob in enumerate(obs):
features[i, :] = ob['feature']
return Variable(torch.from_numpy(features), requires_grad=False).cuda()
def _teacher_action(self, obs, ended):
''' Extract teacher actions into variable. '''
a = torch.LongTensor(len(obs))
for i, ob in enumerate(obs):
# Supervised teacher only moves one axis at a time
ix, heading_chg, elevation_chg = ob['teacher']
if heading_chg > 0:
a[i] = self.model_actions.index('right')
elif heading_chg < 0:
a[i] = self.model_actions.index('left')
elif elevation_chg > 0:
a[i] = self.model_actions.index('up')
elif elevation_chg < 0:
a[i] = self.model_actions.index('down')
elif ix > 0:
a[i] = self.model_actions.index('forward')
elif ended[i]:
a[i] = self.model_actions.index('<ignore>')
else:
a[i] = self.model_actions.index('<end>')
return Variable(a, requires_grad=False).cuda()
def rollout(self):
obs = np.array(self.env.reset())
batch_size = len(obs)
# Reorder the language input for the encoder
seq, seq_mask, seq_lengths, perm_idx = self._sort_batch(obs)
perm_obs = obs[perm_idx]
# Record starting point
traj = [{
'instr_id': ob['instr_id'],
'path': [(ob['viewpoint'], ob['heading'], ob['elevation'])]
} for ob in perm_obs]
# Forward through encoder, giving initial hidden state and memory cell for decoder
ctx, h_t, c_t = self.encoder(seq, seq_lengths)
# Initial action
a_t = Variable(torch.ones(batch_size).long() * self.model_actions.index('<start>'),
requires_grad=False).cuda()
        # Indices match the permutation of the model input, not the env
ended = np.array([False] * batch_size)
# Do a sequence rollout and calculate the loss
self.loss = 0
env_action = [None] * batch_size
for t in range(self.episode_len):
f_t = self._feature_variable(perm_obs) # Image features from obs
h_t, c_t, alpha, logit = self.decoder(
a_t.view(-1, 1), f_t, h_t, c_t, ctx, seq_mask)
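            # alpha: attention weights over the instruction (not used in the rollout)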
# Mask outputs where agent can't move forward
for i, ob in enumerate(perm_obs):
if len(ob['navigableLocations']) <= 1:
logit[i, self.model_actions.index('forward')] = -float('inf')
# Supervised training
target = self._teacher_action(perm_obs, ended)
self.loss += self.criterion(logit, target)
# Determine next model inputs
if self.feedback == 'teacher':
a_t = target # teacher forcing
elif self.feedback == 'argmax':
_, a_t = logit.max(1) # student forcing - argmax
a_t = a_t.detach()
elif self.feedback == 'sample':
probs = F.softmax(logit, dim=1)
m = D.Categorical(probs)
a_t = m.sample() # sampling an action from model
else:
sys.exit('Invalid feedback option')
            # Update the 'ended' list and make the environment action
for i, idx in enumerate(perm_idx):
action_idx = a_t[i].item()
if action_idx == self.model_actions.index('<end>'):
ended[i] = True
env_action[idx] = self.env_actions[action_idx]
obs = np.array(self.env.step(env_action))
perm_obs = obs[perm_idx]
# Save trajectory output
for i, ob in enumerate(perm_obs):
if not ended[i]:
traj[i]['path'].append(
(ob['viewpoint'], ob['heading'], ob['elevation']))
# Early exit if all ended
if ended.all():
break
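        # Normalize by the full episode length, even when the rollout ends early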
self.losses.append(self.loss.item() / self.episode_len)
return traj
def test(self, use_dropout=False, feedback='argmax', allow_cheat=False):
''' Evaluate once on each instruction in the current environment '''
        if not allow_cheat:  # teacher feedback is only permitted for computing validation loss
# no cheating by using teacher at test time!
assert feedback in ['argmax', 'sample']
self.feedback = feedback
if use_dropout:
self.encoder.train()
self.decoder.train()
else:
self.encoder.eval()
self.decoder.eval()
super(Seq2SeqAgent, self).test()
def train(self, encoder_optimizer, decoder_optimizer, n_iters, feedback='teacher'):
''' Train for a given number of iterations '''
assert feedback in self.feedback_options
self.feedback = feedback
self.encoder.train()
self.decoder.train()
self.losses = []
        for _ in range(1, n_iters + 1):
encoder_optimizer.zero_grad()
decoder_optimizer.zero_grad()
self.rollout()
self.loss.backward()
encoder_optimizer.step()
decoder_optimizer.step()
def save(self, encoder_path, decoder_path):
''' Snapshot models '''
torch.save(self.encoder.state_dict(), encoder_path)
torch.save(self.decoder.state_dict(), decoder_path)
def load(self, encoder_path, decoder_path):
        ''' Load model parameters (but not training state) '''
self.encoder.load_state_dict(torch.load(encoder_path))
self.decoder.load_state_dict(torch.load(decoder_path))
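

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original training pipeline: it
    # assumes env.R2RBatch from the companion codebase, and the constructor
    # arguments below are illustrative placeholders rather than the verified API.
    from env import R2RBatch

    env = R2RBatch(None, batch_size=8, splits=['val_seen'])  # placeholder feature store/args
    agent = BaseAgent.get_agent('Shortest')(env, 'shortest_val_seen.json')
    agent.test()           # rolls out until every instruction in the split is covered once
    agent.write_results()  # dumps {'instr_id', 'trajectory'} records to the results path

    # A Seq2SeqAgent would additionally need an encoder/decoder pair (see the
    # companion model.py) and two optimizers passed to agent.train(...).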