World Models
Paper: arXiv:1803.10122
This is a World Models agent trained on the SpaceInvadersNoFrameskip-v4 environment.
World Models is a model-based reinforcement learning approach that learns a compressed representation of the environment and trains a controller to maximize reward in the learned model.
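The architecture consists of three components:
- Vision (V): a variational autoencoder (VAE) that compresses each frame into a 32-dimensional latent vector.
- Memory (M): an MDN-RNN that models how the latent vector evolves given the agent's actions.
- Controller (C): a small policy that selects actions from the latent vector and the RNN hidden state.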
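As a rough illustration of how the three components fit together, here is a minimal PyTorch sketch. It is not the repository's implementation: the class names (`SketchEncoder`, `SketchMDNRNN`, `SketchController`), the 64x64 input resolution, and the number of mixture components (`N_MIX = 5`) are assumptions made for this example. Only the sizes 32, 6, and 256 are taken from the usage snippet in this card.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# 32, 6 and 256 mirror the usage snippet; N_MIX is an assumed value.
LATENT_DIM, ACTION_DIM, HIDDEN_DIM, N_MIX = 32, 6, 256, 5


class SketchEncoder(nn.Module):
    """V: encoder half of a VAE, mapping a 64x64 RGB frame to a latent z."""

    def __init__(self, latent_dim=LATENT_DIM):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2), nn.ReLU(),     # 64 -> 31
            nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),    # 31 -> 14
            nn.Conv2d(64, 128, 4, stride=2), nn.ReLU(),   # 14 -> 6
            nn.Conv2d(128, 256, 4, stride=2), nn.ReLU(),  # 6 -> 2
            nn.Flatten(),
        )
        self.mu = nn.Linear(256 * 2 * 2, latent_dim)
        self.logvar = nn.Linear(256 * 2 * 2, latent_dim)

    def forward(self, frame):
        h = self.conv(frame)
        mu, logvar = self.mu(h), self.logvar(h)
        # Reparameterization trick: z = mu + sigma * eps
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)
        return z, mu, logvar


class SketchMDNRNN(nn.Module):
    """M: LSTM over [z, action] that outputs a Gaussian mixture over the next z."""

    def __init__(self, latent_dim=LATENT_DIM, action_dim=ACTION_DIM,
                 hidden_dim=HIDDEN_DIM, n_mix=N_MIX):
        super().__init__()
        self.latent_dim, self.n_mix = latent_dim, n_mix
        self.lstm = nn.LSTM(latent_dim + action_dim, hidden_dim, batch_first=True)
        self.mdn = nn.Linear(hidden_dim, 3 * n_mix * latent_dim)

    def forward(self, z, action_onehot, state=None):
        out, state = self.lstm(torch.cat([z, action_onehot], dim=-1), state)
        pi, mu, log_sigma = self.mdn(out).chunk(3, dim=-1)
        shape = (*out.shape[:-1], self.n_mix, self.latent_dim)
        # Normalize mixture weights over the mixture axis (log-probabilities).
        log_pi = torch.log_softmax(pi.reshape(shape), dim=-2)
        return log_pi, mu.reshape(shape), log_sigma.reshape(shape), state


class SketchController(nn.Module):
    """C: a single linear layer from [z, h] to action logits."""

    def __init__(self, latent_dim=LATENT_DIM, hidden_dim=HIDDEN_DIM,
                 action_dim=ACTION_DIM):
        super().__init__()
        self.fc = nn.Linear(latent_dim + hidden_dim, action_dim)

    def forward(self, z, h):
        return self.fc(torch.cat([z, h], dim=-1))


# One step through V -> C -> M on a random frame.
frame = torch.rand(1, 3, 64, 64)
enc, mdnrnn, ctrl = SketchEncoder(), SketchMDNRNN(), SketchController()
z, _, _ = enc(frame)
h = torch.zeros(1, HIDDEN_DIM)           # initial RNN hidden state
logits = ctrl(z, h)
action = logits.argmax(dim=-1)
onehot = F.one_hot(action, ACTION_DIM).float()
log_pi, mu, log_sigma, state = mdnrnn(z.unsqueeze(1), onehot.unsqueeze(1))
```

The controller is kept deliberately small so that most of the model capacity sits in V and M, which is the design choice described in the World Models paper.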
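import torch
import gymnasium as gym

# VAE, MDNRNN, and Controller are the model classes defined in the repository;
# import them from there before running this snippet.

# Load the trained weights (map_location lets the checkpoints load on CPU-only machines)
vae = VAE(latent_dim=32)
vae.load_state_dict(torch.load('vae_model.pt', map_location='cpu'))
rnn = MDNRNN(latent_dim=32, action_dim=6)
rnn.load_state_dict(torch.load('mdnrnn_model.pt', map_location='cpu'))
controller = Controller(latent_dim=32, hidden_dim=256)
controller.load_state_dict(torch.load('controller_model.pt', map_location='cpu'))

# Switch all three models to inference mode
for model in (vae, rnn, controller):
    model.eval()

# Run agent (Atari environments require the ale-py package to be installed)
env = gym.make('SpaceInvadersNoFrameskip-v4')
# ... (see repository for full inference code)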
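The full inference code lives in the repository. As a hedged sketch of what an episode loop for this kind of agent might look like, the example below encodes each frame, picks an action from the latent vector and hidden state, and advances the memory. The helper names (`preprocess`, `run_episode`), the 64x64 resize, and the callables `encode`, `act`, and `step_memory` are assumptions for illustration, not the repository's API. A stand-in `FakeEnv` with the Gymnasium `reset`/`step` signature is used so the loop runs without an Atari installation.

```python
import numpy as np
import torch
import torch.nn.functional as F


def preprocess(frame, size=64):
    """Convert an HxWx3 uint8 frame to a 1x3xsizexsize float tensor in [0, 1]."""
    x = torch.from_numpy(np.ascontiguousarray(frame)).float().div(255.0)
    x = x.permute(2, 0, 1).unsqueeze(0)  # HWC -> 1xCxHxW
    return F.interpolate(x, size=(size, size), mode="bilinear", align_corners=False)


def run_episode(env, encode, act, step_memory, hidden_dim=256, max_steps=10_000):
    """Play one episode and return the total reward.

    encode(frame_tensor) -> z, act(z, h) -> int action,
    step_memory(z, action, h) -> next h. All three are hypothetical hooks
    standing in for the VAE, controller, and MDN-RNN respectively.
    """
    obs, _ = env.reset()
    h = torch.zeros(1, hidden_dim)
    total = 0.0
    with torch.no_grad():
        for _ in range(max_steps):
            z = encode(preprocess(obs))
            action = int(act(z, h))
            h = step_memory(z, action, h)
            obs, reward, terminated, truncated, _ = env.step(action)
            total += float(reward)
            if terminated or truncated:
                break
    return total


class FakeEnv:
    """Stand-in environment: blank frames, reward 1.0 per step, fixed length."""

    def __init__(self, length=5):
        self.length, self.t = length, 0

    def reset(self):
        self.t = 0
        return np.zeros((210, 160, 3), dtype=np.uint8), {}

    def step(self, action):
        self.t += 1
        obs = np.zeros((210, 160, 3), dtype=np.uint8)
        return obs, 1.0, self.t >= self.length, False, {}


# Trivial stand-ins for the three models, just to exercise the loop.
total = run_episode(
    FakeEnv(length=5),
    encode=lambda x: torch.zeros(1, 32),
    act=lambda z, h: 0,
    step_memory=lambda z, a, h: h,
)
print(total)
```

With real models, `env` would be the `SpaceInvadersNoFrameskip-v4` environment and the three callables would wrap the loaded `vae`, `controller`, and `rnn`.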
@article{ha2018worldmodels,
title={World Models},
author={Ha, David and Schmidhuber, J{\"u}rgen},
journal={arXiv preprint arXiv:1803.10122},
year={2018}
}