
purejaxgcrl

Python ★ 24 updated 25d ago

GCRL in JAX. Official repository for LEO (ICML 2026).

Research code for training a single AI agent to pursue hundreds of different goals at once, from mining resources in a Minecraft-like game to navigating a gridworld. Includes multiple learning algorithms with GPU-accelerated JAX implementations.

Python · JAX · setup: moderate · complexity: 4/5

This is the official code release for the research paper "Goal-Conditioned Agents that Learn Everything All at Once," accepted at ICML 2026, a major machine learning conference. The project trains AI agents that can pursue many different goals at the same time, rather than a single fixed objective.

Goal-conditioned reinforcement learning (GCRL) is a family of techniques in which an agent learns to behave differently depending on which goal it is currently pursuing. Instead of training one specialized agent per task, you train a single agent that takes the goal as an additional input and adapts its behavior accordingly. The challenge is making this work reliably when the number of possible goals is large.
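As a minimal sketch of the idea, assuming a small MLP and a one-hot goal encoding (neither is taken from the repository), a goal-conditioned policy can simply concatenate the goal with the observation before the first layer. OBS_DIM, NUM_ACTIONS, HIDDEN, init_params, and policy_logits are hypothetical names and sizes; NUM_GOALS = 136 only mirrors the paper's goal count for the simpler Craftax game.

```python
import jax
import jax.numpy as jnp

OBS_DIM = 8      # hypothetical observation size
NUM_GOALS = 136  # mirrors the simpler Craftax goal count; illustrative only
NUM_ACTIONS = 4  # hypothetical number of discrete actions
HIDDEN = 64      # hypothetical hidden width


def init_params(key):
    """Initialise a two-layer MLP whose input is [observation, one-hot goal]."""
    k1, k2 = jax.random.split(key)
    in_dim = OBS_DIM + NUM_GOALS
    return {
        "w1": jax.random.normal(k1, (in_dim, HIDDEN)) / jnp.sqrt(in_dim),
        "b1": jnp.zeros(HIDDEN),
        "w2": jax.random.normal(k2, (HIDDEN, NUM_ACTIONS)) / jnp.sqrt(HIDDEN),
        "b2": jnp.zeros(NUM_ACTIONS),
    }


def policy_logits(params, obs, goal_id):
    """Action logits for one observation, conditioned on an integer goal id."""
    goal = jax.nn.one_hot(goal_id, NUM_GOALS)  # goal enters as an extra input
    x = jnp.concatenate([obs, goal])
    h = jax.nn.relu(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]


params = init_params(jax.random.PRNGKey(0))
obs = jnp.ones(OBS_DIM)

# Same network, same observation, different goals: different action preferences.
logits_a = policy_logits(params, obs, 3)
logits_b = policy_logits(params, obs, 42)

# Score every goal at once by vectorising over the goal argument only.
all_goal_logits = jax.vmap(policy_logits, in_axes=(None, None, 0))(
    params, obs, jnp.arange(NUM_GOALS)
)
```

Because the goal is just another input, a single set of parameters serves every goal, and jax.vmap evaluates all of them in one call.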

The research tests these ideas in two environments. The first is Craftax, a game inspired by Minecraft in which an agent can mine resources, craft tools, and explore. The paper defines 136 distinct goals for the simpler version of the game and 512 goals for the full version. The second is a simple gridworld, included for quick experiments whose results are easier to interpret. According to the repository, a capable agent can be trained on the gridworld in under a minute.
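The gridworld's actual rules are not described here, so the following is a hypothetical stand-in rather than the repository's environment: the goal is a target cell, and the reward is 1 only when the agent lands on it (Craftax goals are achievements, not cells). GRID_SIZE, MOVES, and step are illustrative names. The sketch also shows the batching JAX makes cheap: jax.vmap steps many environments, each with its own goal, in one compiled call.

```python
import jax
import jax.numpy as jnp

GRID_SIZE = 5
# Actions 0..3 = up, down, left, right, as (row, col) deltas.
MOVES = jnp.array([[-1, 0], [1, 0], [0, -1], [0, 1]])


def step(pos, goal, action):
    """Move one cell (clipped to the grid) and return a sparse goal-reaching reward."""
    new_pos = jnp.clip(pos + MOVES[action], 0, GRID_SIZE - 1)
    reached = jnp.all(new_pos == goal)
    reward = reached.astype(jnp.float32)
    return new_pos, reward, reached


# Vectorise over environments (each with its own goal), then compile.
batched_step = jax.jit(jax.vmap(step))

positions = jnp.zeros((3, 2), dtype=jnp.int32)  # all agents start at (0, 0)
goals = jnp.array([[0, 1], [1, 0], [4, 4]])      # one goal cell per environment
actions = jnp.array([3, 1, 3])                   # right, down, right
new_pos, rewards, done = batched_step(positions, goals, actions)
# rewards -> [1., 1., 0.]
```

Sparse rewards like this are what make many-goal training hard: most goals are rarely reached by chance, which motivates techniques such as hindsight relabelling.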

The code implements several learning algorithms: two baselines, Proximal Policy Optimization (PPO) and PQN; Hindsight Experience Replay (HER); and the paper's new methods, LEO and Dual LEO. Each algorithm lives in a single Python file so that researchers can read and modify it without navigating a complex codebase. All implementations use JAX, a Python library for numerical computing developed at Google, which compiles array programs with XLA so they run efficiently on accelerators such as GPUs.
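Hindsight Experience Replay turns failed episodes into useful training data by pretending that a goal the agent actually reached was the goal it had been pursuing. Below is a minimal sketch of the standard "future" relabelling strategy for one trajectory, assuming goals are small integer vectors compared for exact equality. It is not the repository's implementation; her_relabel and relabel_prob are hypothetical names, and 0.8 is an illustrative default.

```python
import jax
import jax.numpy as jnp


def her_relabel(key, achieved, goals, relabel_prob=0.8):
    """'Future'-strategy hindsight relabelling for a single trajectory.

    achieved: (T, G) goal achieved after each step's transition.
    goals:    (T, G) goal the agent was pursuing at each step.
    Returns the (possibly) relabelled goals and rewards recomputed for them.
    """
    T = achieved.shape[0]
    k_future, k_mask = jax.random.split(key)
    t = jnp.arange(T)

    # For each step t, sample a future index uniformly from [t, T - 1].
    u = jax.random.uniform(k_future, (T,))
    future = t + jnp.floor(u * (T - t)).astype(jnp.int32)
    future = jnp.minimum(future, T - 1)  # guard against float rounding at the edge

    # With probability relabel_prob, replace the goal with one achieved later.
    relabel = jax.random.uniform(k_mask, (T,)) < relabel_prob
    new_goals = jnp.where(relabel[:, None], achieved[future], goals)

    # Sparse reward: 1 when the transition at step t achieved the (new) goal.
    rewards = jnp.all(achieved == new_goals, axis=-1).astype(jnp.float32)
    return new_goals, rewards


# A trajectory that never reached its original goal (4, 4).
achieved = jnp.array([[0, 1], [0, 2], [1, 2], [2, 2]])
goals = jnp.tile(jnp.array([4, 4]), (4, 1))
new_goals, rewards = her_relabel(jax.random.PRNGKey(0), achieved, goals)
```

Wrapping it as jax.jit(jax.vmap(her_relabel)) relabels a whole batch of trajectories at once, given one PRNG key per trajectory.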

Installation requires Python 3.10 or later and a specific version of JAX. The default hyperparameters in each script are set to the values that performed best in the paper's experiments. Trained agents can be visualized with an included renderer.
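The exact JAX version the repository pins is not given in this summary, so the following only checks the two stated requirements in a generic way: the Python version, and whether JAX is installed and which backend it would use. check_environment is a hypothetical helper; compare the reported jax_version against the version in the repository's own install instructions.

```python
import importlib.util
import sys


def check_environment(min_python=(3, 10)):
    """Report whether Python meets the minimum and whether JAX is importable."""
    python_ok = sys.version_info[:2] >= min_python
    jax_installed = importlib.util.find_spec("jax") is not None
    report = {"python_ok": python_ok, "jax_installed": jax_installed}
    if jax_installed:
        import jax

        report["jax_version"] = jax.__version__
        # "gpu" when an accelerator-enabled jaxlib finds a device, otherwise "cpu".
        report["backend"] = jax.default_backend()
    return report


print(check_environment())
```

A "cpu" backend on a machine with a GPU usually means a CPU-only jaxlib is installed.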

Where it fits