Vision Guided Imitation Learning using Action Chunk Transformer#
This work applies the Action Chunking with Transformers (ACT) technique to single-arm robotic manipulation for vision-guided pick-and-place tasks. ACT employs a Conditional Variational Autoencoder (CVAE) to predict sequences of actions, termed "action chunks": groups of actions predicted together so that complex tasks can be executed efficiently. Unlike traditional methods that rely solely on joint-position data and predict individual actions, our approach integrates visual data to enrich the learning context and enhance execution precision. We acquired the expert data by providing manual demonstrations of the task, allowing the model to learn from real, complex action sequences. By predicting action chunks instead of single actions, and by adapting ACT from its original dual-arm setup to a single-arm configuration, the approach improves the robot's speed, precision, and reliability, underscoring the critical role of vision in advancing robotic control systems.
Transformer#
The transformer architecture comes from the paper "Attention Is All You Need."
It is designed to encode a sequence into a set of context-aware vector representations, relying heavily on self-attention and position-wise operations, without using recurrence (as in RNNs) or convolution.
The transformer architecture is divided into two parts:
- Encoder
- Decoder
Encoder#
First, the encoder transforms the input tokens into continuous embeddings and adds positional encodings, producing a sequence of vectors that carry both semantic and positional information.
Multi-head Attention#
- Take an input embedding with a dimension of 256.
- Split the embedding into 8 segments, each with a dimension of 32.
- Pass each segment through its own linear layer.
- Feed the outputs of these linear layers into the Scaled Dot Product Attention mechanism.
- **Scaled Dot Product Attention**
$$
\mathrm{Attention}(Q,K,V) = \text{softmax}\!\biggl(\frac{QK^{T}}{\sqrt{d_{k}}}\biggr)V
$$

We then concatenate the outputs of the segments we initially divided, restoring them to the original input embedding size. This concatenated output is passed through a final linear layer, which produces the output of the attention mechanism.
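A minimal PyTorch sketch of these steps, assuming an embedding size of 256 and 8 heads of 32 dimensions each. A single linear layer per query/key/value covers all heads at once, which is equivalent to giving each 32-dimensional segment its own linear layer; the toy input at the end is illustrative.

```python
import torch
import torch.nn.functional as F
from torch import nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=256, num_heads=8):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_k = d_model // num_heads          # 256 / 8 = 32 dims per head
        # one linear layer each for queries, keys, values (covers all heads at once)
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)  # final linear after concatenation

    def forward(self, x):
        bs, seq_len, d_model = x.shape
        # project, then split the 256-dim embedding into 8 segments of 32
        q = self.q_proj(x).view(bs, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        k = self.k_proj(x).view(bs, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        v = self.v_proj(x).view(bs, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        # scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V
        scores = q @ k.transpose(-2, -1) / (self.d_k ** 0.5)
        attn = F.softmax(scores, dim=-1)
        out = attn @ v
        # concatenate the heads back to 256 dims, then apply the final linear layer
        out = out.transpose(1, 2).contiguous().view(bs, seq_len, d_model)
        return self.out_proj(out)

x = torch.randn(2, 10, 256)           # (batch, sequence length, embedding dim)
print(MultiHeadAttention()(x).shape)  # torch.Size([2, 10, 256])
```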
Decoder#
The decoder receives the encoder's output and feeds it into its multi-head attention mechanism.
The main change in the decoder is that, within this cross-attention, the keys and values are sourced from the encoder, while the queries are derived from the decoder's preceding output.
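A short sketch of this cross-attention using `nn.MultiheadAttention`; the tensor shapes are toy placeholders, with queries coming from the decoder side and keys/values from the encoder output.

```python
import torch
from torch import nn

d_model, num_heads = 256, 8
cross_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=True)

encoder_output = torch.randn(2, 300, d_model)   # keys and values come from the encoder
decoder_queries = torch.randn(2, 4, d_model)    # queries come from the decoder's previous layer

out, attn_weights = cross_attn(query=decoder_queries, key=encoder_output, value=encoder_output)
print(out.shape)  # torch.Size([2, 4, 256])
```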
Main Overview of the ACT#
The main architecture of the Action Chunk Transformer incorporates two distinct encoders. On the left side, one encoder is dedicated to processing action sequences along with position embeddings, which are compressed into a style variable.
STEP 1#
Encoding Latent Variations: The encoder on the left side compresses the action sequence and joint observations into a latent variable z. This variable captures the "style" of the action, focusing on unique motion patterns while ignoring specific details.
Facilitating Stochasticity: z introduces necessary stochasticity, allowing the model to account for multiple plausible action sequences from the same input, which is vital during training.
Test-Time Simplicity: At test time, z is set to the mean of its prior distribution, usually zero, to simplify the inference process by:
- Eliminating the need for random sampling.
- Streamlining the generation of deterministic action sequences.
In essence, z aids in learning a probabilistic representation of action variations during training but is simplified to a deterministic value at test time.
Finally, the encoder on the right (described in Step 2) conditions on this style variable together with the remaining inputs, the camera observations and joint positions, during training.
This structure allows the transformer to manage and learn from complex action sequences effectively.
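The behavior of z at training versus test time can be summarized in a short sketch; the `reparametrize` helper mirrors the one in the Model section below, and the shapes are illustrative.

```python
import torch

def reparametrize(mu, logvar):
    # sample z = mu + sigma * eps using the reparameterization trick
    std = logvar.div(2).exp()
    eps = torch.randn_like(std)
    return mu + std * eps

latent_dim, bs = 32, 4
mu, logvar = torch.zeros(bs, latent_dim), torch.zeros(bs, latent_dim)

# training: z is sampled, so the same observation can map to different action "styles"
z_train = reparametrize(mu, logvar)

# test time: z is fixed to the prior mean (zero), so inference is deterministic
z_test = torch.zeros(bs, latent_dim)
```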
STEP 2#
System Overview#
The system is designed to process visual input from cameras to predict the movement or configuration of robot joints. It integrates convolutional neural networks (CNNs) and transformers to achieve this, using image data to determine how the robot should adjust its joints.
Image Input and CNN Processing#
Input Images: The input consists of RGB images with a resolution of 640x480 pixels. If there are multiple cameras, each provides an image of this size.
CNN Backbone (ResNet18): The images are first processed by a convolutional neural network (CNN), specifically using ResNet18 as the backbone. This CNN is tasked with extracting relevant features from the raw images. The CNN reduces the spatial dimensions of the image while increasing the depth of features:
Output Dimension: After passing through the CNN, each image is reduced to a 15x20 spatial feature map with 512 channels, which is flattened into a 300x512 feature map (300 spatial positions, each carrying 512 features).
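As a sanity check on these numbers, here is a small sketch using a plain torchvision ResNet18; the project builds its backbone in backbone.py, so this standalone version only shows how a 640x480 image becomes 300 tokens of 512 features.

```python
import torch
from torch import nn
import torchvision

# ResNet18 without its average-pool and classification head, so it returns feature maps
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-2])

images = torch.randn(1, 3, 480, 640)           # one 640x480 RGB camera image
features = backbone(images)                    # (1, 512, 15, 20): stride-32 feature map
tokens = features.flatten(2).transpose(1, 2)   # (1, 300, 512): 15*20 = 300 tokens of 512 dims
print(features.shape, tokens.shape)
```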
Feature Tokenization and Transformer Encoding#
Tokenization: The output from the CNN is treated as a series of tokens. Each of the 300 positions in the 300x512 map is an individual token carrying a 512-dimensional feature vector.
Inputs for the Transformer:
- Observations: the tokens derived from the CNN output.
- Joint States: the current or previous states of the robot's joints, input into the transformer as a 1x512 vector. This informs the model of the robot's current configuration.
- Latent Style Variable z: a style variable (also of size 1x512) that captures additional contextual or style-specific information, which may influence how the output should be adapted.
Decoding and Output Prediction#
Decoder: The decoder part of the transformer then takes the encoded data and, through a series of transformations, predicts the sequence of joint angles. These predictions are formatted as sequences of shape k x 8, where k is the number of predicted time steps and 8 is the dimensionality of each action (matching the 8-dimensional joint-state vector recorded in the dataset).
Output: The final output is a sequence predicting how the joints should move or be positioned based on the visual input and the given conditions (latent variable and joint states).
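To make these shapes concrete, here is a hedged sketch of how the pieces line up at the transformer's input and output. The numbers follow the description above; k = 4 and the single decoder layer are illustrative stand-ins, and the real wiring lives in the DETRVAE model shown later.

```python
import torch
from torch import nn

hidden_dim, k, action_dim = 512, 4, 8     # k = chunk size, 8-dim action per step (illustrative)

image_tokens = torch.randn(1, 300, hidden_dim)  # flattened 15x20 ResNet18 feature map
joint_token  = torch.randn(1, 1, hidden_dim)    # current joint state projected to 512
latent_token = torch.randn(1, 1, hidden_dim)    # style variable z projected to 512
encoder_input = torch.cat([latent_token, joint_token, image_tokens], dim=1)  # (1, 302, 512)

# the decoder attends to the encoder output from k learned query embeddings;
# a linear "action head" maps each of the k decoder outputs to an 8-dim action
decoder = nn.TransformerDecoderLayer(d_model=hidden_dim, nhead=8, batch_first=True)
queries = nn.Embedding(k, hidden_dim).weight.unsqueeze(0)   # (1, k, 512)
action_head = nn.Linear(hidden_dim, action_dim)

decoder_out = decoder(tgt=queries, memory=encoder_input)    # (1, k, 512)
a_hat = action_head(decoder_out)                            # (1, k, 8): predicted action chunk
print(a_hat.shape)
```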
Temporal Ensembling#
Action Chunking#
Action chunking involves predicting a sequence of actions for multiple future time steps, starting at t=0. Initially, the model predicts actions for four future time steps, which the robot then executes sequentially. Once these actions are completed, at t=4, the model again predicts the next four actions. This pattern continues, allowing the robot to perform tasks by planning several steps ahead.
Action Chunking with Temporal Ensembling#
In the enhanced method combining action chunking with temporal ensembling, the model initially predicts four actions at t=0 and executes the first action immediately. At each subsequent time step, such as t=1, the model predicts another set of four actions. To determine the next action to execute, the predictions made for the current time step by earlier chunks are averaged with the newest prediction for that time step. This averaging is applied at every time step to smooth the trajectory of actions, ensuring that the robot's motion stays consistent.
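A minimal NumPy sketch of this overlap-and-average idea, using a plain average over every chunk that covers the current time step; the `predict_chunk` stub and the horizon of 12 steps are illustrative stand-ins for the trained policy and the episode length.

```python
import numpy as np

chunk_size, action_dim, horizon = 4, 8, 12

def predict_chunk(t):
    # stand-in for the policy: returns actions for time steps t, t+1, ..., t+chunk_size-1
    return np.full((chunk_size, action_dim), float(t))

# all_preds[t, s] holds the action predicted for time step s by the chunk generated at time t
all_preds = np.full((horizon, horizon + chunk_size, action_dim), np.nan)

for t in range(horizon):
    all_preds[t, t:t + chunk_size] = predict_chunk(t)
    # average every available prediction for the current time step t
    preds_for_t = all_preds[:t + 1, t]
    valid = ~np.isnan(preds_for_t[:, 0])
    action_t = preds_for_t[valid].mean(axis=0)
    # action_t would now be sent to the robot
```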
Algorithm#
CODE#
Dataset Preparation#
The dataset has the following structure:
```
Contents of the HDF5 file:
  action: <HDF5 dataset "action": shape (149, 8), type "<f8">
    - Shape: (149, 8), Type: float64
  observations: <HDF5 group "/observations" (2 members)>
    images: <HDF5 group "/observations/images" (1 members)>
      top: <HDF5 dataset "top": shape (149, 480, 640, 3), type "|u1">
        - Shape: (149, 480, 640, 3), Type: uint8
    qpos: <HDF5 dataset "qpos": shape (149, 8), type "<f8">
      - Shape: (149, 8), Type: float64
```
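This listing can be reproduced with a small h5py helper along these lines; the episode path is a placeholder for one of the files written by live_record.py below.

```python
import h5py

def print_hdf5_structure(path):
    """Recursively print every group and dataset in an episode file."""
    with h5py.File(path, 'r') as f:
        print("Contents of the HDF5 file:")
        def visitor(name, obj):
            if isinstance(obj, h5py.Dataset):
                print(f"{name}: shape {obj.shape}, dtype {obj.dtype}")
            else:
                print(f"{name}: group with {len(obj)} member(s)")
        f.visititems(visitor)

print_hdf5_structure('real_dir2/episode_0.hdf5')  # path matches the recorder below
```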
This code is in live_record.py
import cv2
import os
import h5py
import numpy as np
from controller.robot_state import *


class CameraController:
    def __init__(self, camera_index=0):
        self.capture = cv2.VideoCapture(camera_index)
        self.capture.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
        self.capture.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
        self.data = []
        self.robot_state_data = []
        self.recording = False
        self.franka = RobotController()
        # self.franka.initial_pose()

    def capture_frames(self):
        print("Press 's' to start/stop recording. Press 'q' to quit.")
        while True:
            ret, frame = self.capture.read()
            if ret:
                cv2.imshow('Camera Feed', frame)
                key = cv2.waitKey(1) & 0xFF
                if key == ord('s'):
                    self.recording = not self.recording
                    if self.recording:
                        print("Recording started.")
                    else:
                        print("Recording stopped. Saving data...")
                        self.record_extra_frames(5)
                        self.save_data()
                        self.data = []
                        self.robot_state_data = []
                elif key == ord('q'):
                    break
                if self.recording:
                    self.data.append(frame)
                    robot_state = self.get_robot_state(0)  # Initial state with '0'
                    self.robot_state_data.append(robot_state)
        cv2.destroyAllWindows()

    def get_robot_state(self, end_marker=0):
        angles = self.franka.angles()
        return np.concatenate((angles, [end_marker]))

    def record_extra_frames(self, count):
        last_frame = self.data[-1]
        last_state = self.get_robot_state(1)  # Final state with '1'
        for _ in range(count):
            self.data.append(last_frame)
            self.robot_state_data.append(last_state)

    def save_data(self):
        if not self.data:
            print("No data to save.")
            return
        episode_idx = 0
        directory = "real_dir2"
        if not os.path.exists(directory):
            os.makedirs(directory)
        while os.path.exists(os.path.join(directory, f'episode_{episode_idx}.hdf5')):
            episode_idx += 1
        file_path = os.path.join(directory, f'episode_{episode_idx}.hdf5')
        with h5py.File(file_path, 'w') as root:
            root.attrs['sim'] = True
            obs = root.create_group('observations')
            images = obs.create_group('images')
            camera_names = ['top']
            for cam_name, data in zip(camera_names, [self.data]):
                image_data = np.array(data, dtype='uint8')
                images.create_dataset(cam_name, data=image_data, dtype='uint8', chunks=(1, 480, 640, 3))
            robot_data = np.array(self.robot_state_data, dtype='float64')
            obs.create_dataset('qpos', data=robot_data)
            root.create_dataset('action', data=robot_data)


def main():
    camera_index = 0
    camera_controller = CameraController(camera_index)
    camera_controller.capture_frames()


if __name__ == '__main__':
    main()
Dataset Sampler#
import os
import h5py
import torch
import numpy as np
from einops import rearrange
from torch.utils.data import DataLoader
# from policy import ACTPolicy , CNNMLPPolicy
from policy import ACTPolicy, CNNMLPPolicy
import IPython
e = IPython.embed


class EpisodicDataset(torch.utils.data.Dataset):
    def __init__(self, episode_ids, dataset_dir, camera_names, norm_stats):
        super(EpisodicDataset).__init__()
        self.episode_ids = episode_ids
        self.dataset_dir = dataset_dir
        self.camera_names = camera_names
        self.norm_stats = norm_stats
        self.is_sim = None
        # self.__getitem__(0) # initialize self.is_sim

    def __len__(self):
        return len(self.episode_ids)

    def __getitem__(self, index):
        sample_full_episode = False  # hardcode
        episode_id = self.episode_ids[index]
        dataset_path = os.path.join(self.dataset_dir, f'episode_{episode_id}.hdf5')
        with h5py.File(dataset_path, 'r') as root:
            is_sim = root.attrs['sim']
            original_action_shape = root['/action'].shape
            episode_len = original_action_shape[0]
            if sample_full_episode:
                start_ts = 0
            else:
                start_ts = np.random.choice(episode_len)
            # get observation at start_ts only
            qpos = root['/observations/qpos'][start_ts]
            # qvel = root['/observations/qvel'][start_ts]
            image_dict = dict()
            for cam_name in self.camera_names:
                image_dict[cam_name] = root[f'/observations/images/{cam_name}'][start_ts]
            # get all actions after and including start_ts
            if is_sim:
                action = root['/action'][start_ts:]
                action_len = episode_len - start_ts
            else:
                action = root['/action'][max(0, start_ts - 1):]  # hack, to make timesteps more aligned
                action_len = episode_len - max(0, start_ts - 1)  # hack, to make timesteps more aligned

        self.is_sim = is_sim
        padded_action = np.zeros(original_action_shape, dtype=np.float32)
        padded_action[:action_len] = action
        is_pad = np.zeros(episode_len)
        is_pad[action_len:] = 1

        # new axis for different cameras
        all_cam_images = []
        for cam_name in self.camera_names:
            all_cam_images.append(image_dict[cam_name])
        all_cam_images = np.stack(all_cam_images, axis=0)

        # construct observations
        image_data = torch.from_numpy(all_cam_images)
        qpos_data = torch.from_numpy(qpos).float()
        action_data = torch.from_numpy(padded_action).float()
        is_pad = torch.from_numpy(is_pad).bool()

        # channel last
        image_data = torch.einsum('k h w c -> k c h w', image_data)

        # normalize image and change dtype to float
        image_data = image_data / 255.0
        action_data = (action_data - self.norm_stats["action_mean"]) / self.norm_stats["action_std"]
        qpos_data = (qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats["qpos_std"]

        return image_data, qpos_data, action_data, is_pad
Here:
- `image_data`: the sampled camera images
- `qpos_data`: joint positions (e.g., joint angles)
- `action_data`: the action sequence
- `is_pad`: a mask indicating which time steps of the padded action sequence are padding
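For context, here is a hedged sketch of how `EpisodicDataset` might be wrapped into training and validation DataLoaders. The `norm_stats` dictionary (action/qpos mean and std) and the episode count are illustrative placeholders; in practice the statistics are computed from the recorded episodes.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader

# illustrative normalization statistics; in practice these come from the dataset itself
norm_stats = {
    "action_mean": torch.zeros(8), "action_std": torch.ones(8),
    "qpos_mean": torch.zeros(8), "qpos_std": torch.ones(8),
}

num_episodes = 50                      # assumed number of recorded episodes
shuffled = np.random.permutation(num_episodes)
train_ids = shuffled[:int(0.8 * num_episodes)]
val_ids = shuffled[int(0.8 * num_episodes):]

train_set = EpisodicDataset(train_ids, 'real_dir2', ['top'], norm_stats)
val_set = EpisodicDataset(val_ids, 'real_dir2', ['top'], norm_stats)

train_loader = DataLoader(train_set, batch_size=8, shuffle=True, pin_memory=True)
val_loader = DataLoader(val_set, batch_size=8, shuffle=False, pin_memory=True)
```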
Policy#
import torch.nn as nn
from torch.nn import functional as F
import torchvision.transforms as transforms
from detr.main import build_ACT_model_and_optimizer, build_CNNMLP_model_and_optimizer
import IPython
e = IPython.embed


class ACTPolicy(nn.Module):
    def __init__(self, args_override):
        super().__init__()
        model, optimizer = build_ACT_model_and_optimizer(args_override)
        self.model = model  # CVAE decoder (Conditional Variational Autoencoder)
        self.optimizer = optimizer
        self.kl_weight = args_override['kl_weight']
        print(f'KL Weight {self.kl_weight}')

    def __call__(self, qpos, image, actions=None, is_pad=None):
        env_state = None
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        image = normalize(image)
        if actions is not None:  # training time
            actions = actions[:, :self.model.num_queries]
            is_pad = is_pad[:, :self.model.num_queries]
            a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
            total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
            loss_dict = dict()
            all_l1 = F.l1_loss(actions, a_hat, reduction='none')
            l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean()
            loss_dict['l1'] = l1  # regression loss
            loss_dict['kl'] = total_kld[0]
            loss_dict['loss'] = loss_dict['l1'] + loss_dict['kl'] * self.kl_weight
            return loss_dict
        else:  # inference time
            a_hat, _, (_, _) = self.model(qpos, image, env_state)
            return a_hat

    def configure_optimizers(self):
        return self.optimizer
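`ACTPolicy` calls a `kl_divergence` helper that is not shown above. Consistent with how its return values are used (total, dimension-wise, and mean KL, with `total_kld[0]` entering the loss), a sketch of the standard closed-form KL between the diagonal Gaussian N(mu, sigma^2) and a unit Gaussian prior looks like this:

```python
import torch

def kl_divergence(mu, logvar):
    # closed-form KL( N(mu, sigma^2) || N(0, I) ) for a diagonal Gaussian, per latent dim
    klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())   # (batch, latent_dim)
    total_kld = klds.sum(1).mean(0, True)        # summed over latent dims, averaged over batch
    dimension_wise_kld = klds.mean(0)            # per-dimension KL, averaged over batch
    mean_kld = klds.mean(1).mean(0, True)        # averaged over both dims and batch
    return total_kld, dimension_wise_kld, mean_kld
```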
Training code#
train.py file
from settings.var import *
import os
import pickle
import argparse
from copy import deepcopy
import numpy as np
import torch
import matplotlib.pyplot as plt
from training.utils import *

parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='')
args = parser.parse_args()
task = args.task

# configs
task_cfg = TASK_CONFIG
train_cfg = TRAIN_CONFIG
policy_config = POLICY_CONFIG
checkpoint_dir = os.path.join(train_cfg['checkpoint_dir'], task)

# device
device = os.environ['DEVICE']


def forward_pass(data, policy):
    image_data, qpos_data, action_data, is_pad = data
    qpos_data = qpos_data.float()
    action_data = action_data.float()
    image_data, qpos_data, action_data, is_pad = image_data.to(device), \
        qpos_data.to(device), action_data.to(device), is_pad.to(device)
    return policy(qpos_data, image_data, action_data, is_pad)  # TODO remove None


def plot_history(train_history, validation_history, num_epochs, ckpt_dir, seed):
    # save training curves
    for key in train_history[0]:
        plot_path = os.path.join(ckpt_dir, f'train_val_{key}_seed_{seed}.png')
        plt.figure()
        train_values = [summary[key].item() for summary in train_history]
        val_values = [summary[key].item() for summary in validation_history]
        plt.plot(np.linspace(0, num_epochs-1, len(train_history)), train_values, label='train')
        plt.plot(np.linspace(0, num_epochs-1, len(validation_history)), val_values, label='validation')
        # plt.ylim([-0.1, 1])
        plt.tight_layout()
        plt.legend()
        plt.title(key)
        plt.savefig(plot_path)
    print(f'Saved plots to {ckpt_dir}')


def train_bc(train_dataloader, val_dataloader, policy_config):
    # load policy
    policy = make_policy(policy_config['policy_class'], policy_config)
    policy.to(device)
    # load optimizer
    optimizer = make_optimizer(policy_config['policy_class'], policy)
    # create checkpoint dir if not exists
    os.makedirs(checkpoint_dir, exist_ok=True)

    train_history = []
    validation_history = []
    min_val_loss = np.inf
    best_ckpt_info = None
    for epoch in range(train_cfg['num_epochs']):
        print(f'\nEpoch {epoch}')
        # validation
        with torch.inference_mode():
            policy.eval()
            epoch_dicts = []
            for batch_idx, data in enumerate(val_dataloader):
                forward_dict = forward_pass(data, policy)
                epoch_dicts.append(forward_dict)
            epoch_summary = compute_dict_mean(epoch_dicts)
            validation_history.append(epoch_summary)

            epoch_val_loss = epoch_summary['loss']
            if epoch_val_loss < min_val_loss:
                min_val_loss = epoch_val_loss
                best_ckpt_info = (epoch, min_val_loss, deepcopy(policy.state_dict()))
        print(f'Val loss: {epoch_val_loss:.5f}')
        summary_string = ''
        for k, v in epoch_summary.items():
            summary_string += f'{k}: {v.item():.3f} '
        print(summary_string)

        # training
        policy.train()
        optimizer.zero_grad()
        for batch_idx, data in enumerate(train_dataloader):
            forward_dict = forward_pass(data, policy)
            # backward
            loss = forward_dict['loss']
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            train_history.append(detach_dict(forward_dict))
        epoch_summary = compute_dict_mean(train_history[(batch_idx+1)*epoch:(batch_idx+1)*(epoch+1)])
        epoch_train_loss = epoch_summary['loss']
        print(f'Train loss: {epoch_train_loss:.5f}')
        summary_string = ''
        for k, v in epoch_summary.items():
            summary_string += f'{k}: {v.item():.3f} '
        print(summary_string)

        if epoch % 200 == 0:
            ckpt_path = os.path.join(checkpoint_dir, f"policy_epoch_{epoch}_seed_{train_cfg['seed']}.ckpt")
            torch.save(policy.state_dict(), ckpt_path)
            plot_history(train_history, validation_history, epoch, checkpoint_dir, train_cfg['seed'])

    ckpt_path = os.path.join(checkpoint_dir, f'policy_last.ckpt')
    torch.save(policy.state_dict(), ckpt_path)


if __name__ == '__main__':
    # set seed
    set_seed(train_cfg['seed'])
    # create ckpt dir if not exists
    os.makedirs(checkpoint_dir, exist_ok=True)
    # number of training episodes
    data_dir = os.path.join(task_cfg['dataset_dir'], task)
    num_episodes = len(os.listdir(data_dir))

    # load data
    train_dataloader, val_dataloader, stats, _ = load_data(data_dir, num_episodes, task_cfg['camera_names'],
                                                           train_cfg['batch_size_train'], train_cfg['batch_size_val'])
    # save stats
    stats_path = os.path.join(checkpoint_dir, f'dataset_stats.pkl')
    with open(stats_path, 'wb') as f:
        pickle.dump(stats, f)
    # train
    train_bc(train_dataloader, val_dataloader, policy_config)

'''
The data loader is very important because it is where we plug in our input data.
It acts like an iterator or sampler that samples each part of the data,
giving us the training data, the validation data, and the statistics used to normalize our data.
'''
Model#
import torch
from torch import nn
from torch.autograd import Variable
from .backbone import build_backbone
from .transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer

import numpy as np
import IPython
e = IPython.embed

device = torch.device('cuda')


def reparametrize(mu, logvar):
    std = logvar.div(2).exp()
    eps = Variable(std.data.new(std.size()).normal_())
    return mu + std * eps


def get_sinusoid_encoding_table(n_position, d_hid):
    def get_position_angle_vec(position):
        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
    return torch.FloatTensor(sinusoid_table).unsqueeze(0)


class DETRVAE(nn.Module):
    """ DETR-style CVAE module that predicts chunks of future actions """
    def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names):
        """ Initializes the model.
        Parameters:
            backbones: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            encoder: transformer encoder used as the CVAE encoder (produces the style variable z)
            state_dim: robot state dimension of the environment
            num_queries: number of action queries, i.e. the chunk size (number of future
                         actions predicted at once)
            camera_names: names of the cameras providing image observations
        """
        super().__init__()
        self.num_queries = num_queries
        self.camera_names = camera_names
        self.transformer = transformer
        self.encoder = encoder
        hidden_dim = transformer.d_model
        self.action_head = nn.Linear(hidden_dim, state_dim)
        self.is_pad_head = nn.Linear(hidden_dim, 1)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        if backbones is not None:
            self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
            self.backbones = nn.ModuleList(backbones)
            self.input_proj_robot_state = nn.Linear(8, hidden_dim)
        else:
            # input_dim = 14 + 7 # robot_state + env_state
            self.input_proj_robot_state = nn.Linear(8, hidden_dim)
            self.input_proj_env_state = nn.Linear(8, hidden_dim)
            self.pos = torch.nn.Embedding(2, hidden_dim)
            self.backbones = None

        # encoder extra parameters
        self.latent_dim = 32  # final size of latent z # TODO tune
        self.cls_embed = nn.Embedding(1, hidden_dim)  # extra cls token embedding
        self.encoder_action_proj = nn.Linear(8, hidden_dim)  # project action to embedding
        self.encoder_joint_proj = nn.Linear(8, hidden_dim)  # project qpos to embedding
        self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2)  # project hidden state to latent std, var
        self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim))  # [CLS], qpos, a_seq

        # decoder extra parameters
        self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim)  # project latent sample to embedding
        self.additional_pos_embed = nn.Embedding(2, hidden_dim)  # learned position embedding for proprio and latent

    def forward(self, qpos, image, env_state, actions=None, is_pad=None):
        """
        qpos: batch, qpos_dim
        image: batch, num_cam, channel, height, width
        env_state: None
        actions: batch, seq, action_dim
        """
        is_training = actions is not None  # train or val
        bs, _ = qpos.shape
        # actions = actions.to(torch.float32) if actions is not None else actions
        # qpos = qpos.to(torch.float32)
        # image = image.to(torch.float32)
        ### Obtain latent z from action sequence
        if is_training:
            # project action sequence to embedding dim, and concat with a CLS token
            action_embed = self.encoder_action_proj(actions)  # (bs, seq, hidden_dim)
            qpos_embed = self.encoder_joint_proj(qpos)  # (bs, hidden_dim)
            qpos_embed = torch.unsqueeze(qpos_embed, axis=1)  # (bs, 1, hidden_dim)
            cls_embed = self.cls_embed.weight  # (1, hidden_dim)
            cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1)  # (bs, 1, hidden_dim)
            encoder_input = torch.cat([cls_embed, qpos_embed, action_embed], axis=1)  # (bs, seq+1, hidden_dim)
            encoder_input = encoder_input.permute(1, 0, 2)  # (seq+1, bs, hidden_dim)
            # do not mask cls token
            cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device)  # False: not a padding
            is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1)  # (bs, seq+1)
            # obtain position embedding
            pos_embed = self.pos_table.clone().detach()
            pos_embed = pos_embed.permute(1, 0, 2)  # (seq+1, 1, hidden_dim)
            # query model
            encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
            encoder_output = encoder_output[0]  # take cls output only
            latent_info = self.latent_proj(encoder_output)
            mu = latent_info[:, :self.latent_dim]
            logvar = latent_info[:, self.latent_dim:]
            latent_sample = reparametrize(mu, logvar)
            latent_input = self.latent_out_proj(latent_sample)
        else:
            mu = logvar = None
            latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
            latent_input = self.latent_out_proj(latent_sample)

        if self.backbones is not None:
            # Image observation features and position embeddings
            all_cam_features = []
            all_cam_pos = []
            for cam_id, cam_name in enumerate(self.camera_names):
                features, pos = self.backbones[0](image[:, cam_id])  # HARDCODED for single backbone
                # features, pos = self.backbones[cam_id](image[:, cam_id])
                features = features[0]  # take the last layer feature
                pos = pos[0]
                all_cam_features.append(self.input_proj(features))
                all_cam_pos.append(pos)
            # proprioception features
            proprio_input = self.input_proj_robot_state(qpos)
            # fold camera dimension into width dimension
            src = torch.cat(all_cam_features, axis=3)
            pos = torch.cat(all_cam_pos, axis=3)
            hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
        else:
            qpos = self.input_proj_robot_state(qpos)
            env_state = self.input_proj_env_state(env_state)
            transformer_input = torch.cat([qpos, env_state], axis=1)  # seq length = 2
            hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
        a_hat = self.action_head(hs)
        is_pad_hat = self.is_pad_head(hs)
        return a_hat, is_pad_hat, [mu, logvar]


def build_encoder(args):
    d_model = args.hidden_dim  # 256
    dropout = args.dropout  # 0.1
    nhead = args.nheads  # 8
    dim_feedforward = args.dim_feedforward  # 2048
    num_encoder_layers = args.enc_layers  # 4 # TODO shared with VAE decoder
    normalize_before = args.pre_norm  # False
    activation = "relu"

    encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                            dropout, activation, normalize_before)
    encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
    encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
    return encoder


def build(args):
    state_dim = 8  # TODO hardcode

    # From state
    # backbone = None # from state for now, no need for conv nets
    # From image
    backbones = []
    backbone = build_backbone(args)
    backbones.append(backbone)

    transformer = build_transformer(args)
    encoder = build_encoder(args)

    model = DETRVAE(
        backbones,
        transformer,
        encoder,
        state_dim=state_dim,
        num_queries=args.num_queries,
        camera_names=args.camera_names,
    )

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (n_parameters/1e6,))

    return model
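For reference, here is a minimal illustration of the arguments `build_encoder` reads, filled in with the default values noted in its comments; the full `build` function additionally needs whatever fields `build_backbone` and `build_transformer` consume, which are not shown here.

```python
from argparse import Namespace

# the fields read by build_encoder above, using the values noted in its comments
encoder_args = Namespace(
    hidden_dim=256,       # d_model of the CVAE encoder
    dropout=0.1,
    nheads=8,
    dim_feedforward=2048,
    enc_layers=4,
    pre_norm=False,
)

encoder = build_encoder(encoder_args)
print(sum(p.numel() for p in encoder.parameters()) / 1e6, "M encoder parameters")
```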