
Gradient-based Planning for World Models at Longer Horizons



GRASP is a new gradient-based planner for learned dynamics models (“world models”) that makes long-horizon planning practical by (1) lifting the trajectory into virtual states so that optimization is parallel across time, (2) adding stochasticity directly to the state iterates for exploration, and (3) reshaping gradients so that actions get clean signals while avoiding brittle “state-input” gradients through high-dimensional vision models.

Large, learned world models are becoming increasingly capable. They can predict long sequences of future observations in high-dimensional visual spaces and generalize across tasks in ways that were difficult to imagine a few years ago. As these models scale, they start to look less like task-specific predictors and more like general-purpose simulators.

But having a powerful predictive model is not the same as being able to use it effectively for control/learning/planning. In practice, long-horizon planning with modern world models remains fragile: optimization becomes ill-conditioned, non-greedy structure creates bad local minima, and high-dimensional latent spaces introduce subtle failure modes.

In this blog post, I describe the problems that motivated this project and our approach to address them: why planning with modern world models can be surprisingly fragile, why long horizons are the real stress test, and what we changed to make gradient-based planning much more robust.


This blog post discusses work done with Mike Rabbat, Aditi Krishnapriyan, Yann LeCun, and Amir Bar (* denotes equal advisorship), where we propose GRASP.


What is a world model?

These days, the term “world model” is quite overloaded, and depending on the context can either mean an explicit dynamics model or some implicit, reliable internal state that a generative model relies on (e.g. when an LLM generates chess moves, whether there is some internal representation of the board). We give our loose working definition below.

Suppose you take actions $a_t \in \mathcal{A}$ and observe states $s_t \in \mathcal{S}$ (images, latent vectors, proprioception). A world model is a learned model that, given the current state and a sequence of future actions, predicts what will happen next. Formally, it defines a predictive distribution over the next state, conditioned on a history of observed states $s_{t-h:t}$ and the current action $a_t$:

\[P_\theta(s_{t+1} \mid s_{t-h:t},\; a_t)\]

that approximates the environment’s true conditional $P(s_{t+1} \mid s_{t-h:t},\; a_t)$. For this blog post, we’ll assume a Markovian model $P_\theta(s_{t+1} \mid s_t,\; a_t)$ for simplicity (all results here can be extended to the more general case), and when the model is deterministic it reduces to a map over states:

\[s_{t+1} = F_\theta(s_t, a_t).\]

In practice the state $s_t$ is often a learned latent representation (e.g., encoded from pixels), so the model operates in a (theoretically) compact, differentiable space. The key point is that a world model gives you a differentiable simulator; you can roll it forward under hypothetical action sequences and backpropagate through the predictions.


Planning: choosing actions by optimizing through the model

Given a start $s_0$ and a goal $g$, the simplest planner chooses an action sequence $\mathbf{a}=(a_0,\dots,a_{T-1})$ by rolling out the model and minimizing terminal error:

\[\min_{\mathbf{a}} \; \| s_T(\mathbf{a}) - g \|_2^2, \quad \text{where } s_T(\mathbf{a}) = \mathcal{F}_{\theta}^{T}(s_0,\mathbf{a}).\]

Here we use $\mathcal{F}^T$ as shorthand for the full rollout through the world model (dependence on model parameters $\theta$ is implicit):

\[\mathcal{F}_{\theta}^{T}(s_0, \mathbf{a}) = F_\theta(F_\theta(\cdots F_\theta(s_0, a_0), \cdots, a_{T-2}), a_{T-1}).\]

At short horizons and in low-dimensional systems, this can work reasonably well. But as horizons grow and models become larger and more expressive, the weaknesses of this approach are amplified.
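To make the serial objective concrete, here is a minimal NumPy sketch. It assumes a hypothetical linear toy world model $F_\theta(s, a) = A s + B a$ so that gradients can be written by hand; with a neural world model you would use an autodiff framework instead. The helper names (`rollout`, `serial_loss_and_grad`, `plan_serial`) are ours, not from the paper. The backward loop is backpropagation through time: each step multiplies the adjoint by one more state Jacobian $D_s F_\theta = A$.

```python
import numpy as np

# Hypothetical toy world model F(s, a) = A s + B a, chosen so gradients can be
# written by hand. A learned world model would be a network plus autodiff.
def step(A, B, s, a):
    return A @ s + B @ a

def rollout(A, B, s0, actions):
    """Serial rollout F^T(s0, a); returns the stacked states s_0, ..., s_T."""
    states = [s0]
    for a in actions:
        states.append(step(A, B, states[-1], a))
    return np.stack(states)

def serial_loss_and_grad(A, B, s0, actions, goal):
    """Terminal loss ||s_T - g||^2 and its gradient w.r.t. every action (BPTT)."""
    states = rollout(A, B, s0, actions)
    residual = states[-1] - goal
    loss = float(residual @ residual)
    grads = np.zeros_like(actions)
    lam = 2.0 * residual                  # adjoint dL/ds_T
    for t in reversed(range(len(actions))):
        grads[t] = B.T @ lam              # dL/da_t = (D_a F)^T dL/ds_{t+1}
        lam = A.T @ lam                   # multiply by one more D_s F = A
    return loss, grads

def plan_serial(A, B, s0, goal, T, lr=0.05, iters=500):
    """Plain gradient descent on the rollout objective, starting from zero actions."""
    actions = np.zeros((T, B.shape[1]))
    for _ in range(iters):
        _, grads = serial_loss_and_grad(A, B, s0, actions, goal)
        actions -= lr * grads
    return actions

# Usage: a 2-D point that moves by its action at each step.
A, B = np.eye(2), np.eye(2)
s0, goal = np.zeros(2), np.array([1.0, 2.0])
actions = plan_serial(A, B, s0, goal, T=4)
print(rollout(A, B, s0, actions)[-1])  # close to the goal [1, 2]
```

This toy problem is easy because the horizon is short and $A$ is the identity; the difficulties described next come from long horizons and badly conditioned state Jacobians.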

So why doesn’t this just work at scale?


Why long-horizon planning is hard (even when everything is differentiable)

There are two separate pain points for the more general world model, plus a third that is specific to learned, deep learning-based models.

1) Long-horizon rollouts create deep, ill-conditioned computation graphs

Those familiar with backpropagation through time (BPTT) may notice that we’re differentiating through a model applied to itself repeatedly, which leads to the exploding/vanishing gradients problem. Namely, take derivatives with respect to an early action such as $a_0$ (we’re differentiating vector-valued functions, so the derivatives are Jacobians, which we denote by $D_x(\cdots)$). The chain rule gives a product of state Jacobians:

\[D_{a_0} \mathcal{F}_{\theta}^{T}(s_0, \mathbf{a}) = D_s F_\theta(s_{T-1}, a_{T-1}) \cdots D_s F_\theta(s_1, a_1)\, D_{a_0}F_\theta(s_0, a_0).\]

Heuristically, if every state Jacobian has similar extreme singular values, the extreme singular values of this product scale exponentially with the horizon $T$:

\[\sigma_{\text{max/min}}(D_{a_0}\mathcal{F}_{\theta}^{T}) \sim \sigma_{\text{max/min}}(D_s F_\theta)^{T-1},\]

leading to exploding or vanishing gradients.
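For a hypothetical linear toy model $F_\theta(s, a) = A s + B a$, every state Jacobian equals $A$, so the product above collapses to a matrix power and the scaling becomes exact rather than heuristic. The short NumPy sketch below (the diagonal $A$ is an arbitrary illustrative choice) prints the largest and smallest singular values of $D_{a_0}\mathcal{F}_\theta^T$ for a few horizons.

```python
import numpy as np

def action_jacobian(A, B, T):
    """D_{a_0} F^T for the toy linear model s_{t+1} = A s_t + B a_t.

    Every state Jacobian D_s F equals A here, so the product of the T-1 state
    Jacobians is the matrix power A^(T-1).
    """
    return np.linalg.matrix_power(A, T - 1) @ B

A = np.diag([1.1, 0.9])  # one expanding and one contracting direction (illustrative)
B = np.eye(2)
for T in (1, 10, 50, 100):
    sv = np.linalg.svd(action_jacobian(A, B, T), compute_uv=False)
    print(f"T={T:3d}  sigma_max={sv.max():.3g}  sigma_min={sv.min():.3g}  "
          f"cond={sv.max() / sv.min():.3g}")
```

At $T = 100$ the extreme singular values are $1.1^{99} \approx 1.25 \times 10^4$ and $0.9^{99} \approx 2.95 \times 10^{-5}$, a condition number above $10^8$, even though a single step is perfectly well conditioned.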

2) The landscape is non-greedy and full of traps

At short horizons, the greedy solution, where we move straight toward the goal at every step, is often good enough. If you only need to plan a few steps ahead, the optimal trajectory usually doesn’t deviate much from “head toward $g$” at each step.

As horizons grow, two things happen. First, longer tasks are more likely to require non-greedy behavior: going around a wall, repositioning before pushing, backing up to take a better path. And as horizons grow, more of these non-greedy steps are typically needed. Second, the optimization space itself scales with horizon: $\mathrm{dim}(\mathcal{A} \times \cdots \times \mathcal{A}) = T\,\mathrm{dim}(\mathcal{A})$, further expanding the space of local minima for the optimization problem.

Distance to goal along the optimal path is non-monotonic, and the resulting loss landscape can be rough.

A long-horizon fix: lifting the dynamics constraint

Suppose we treat the dynamics constraint $s_{t+1} = F_{\theta}(s_t, a_t)$ as a soft constraint, and instead optimize the following penalty function over both actions $(a_0,\ldots,a_{T-1})$ and states $(s_0,\ldots,s_T)$:

\[\min_{\mathbf{s},\mathbf{a}} \mathcal{L}(\mathbf{s}, \mathbf{a}) = \sum_{t=0}^{T-1} \big\|F_\theta(s_t,a_t) - s_{t+1}\big\|_2^2,
\quad \text{with } s_0 \text{ fixed and } s_T=g.\]

This is sometimes called collocation in the planning/robotics literature. Note that the lifted formulation shares the same global minimizers as the original rollout objective: both are zero exactly when the trajectory is dynamically feasible and ends at the goal. But the optimization landscapes are very different, and we get two immediate benefits:

  • Each world model evaluation $F_{\theta}(s_t,a_t)$ depends only on local variables, so all $T$ terms can be computed in parallel across time, resulting in a huge speed-up for longer horizons, and
  • You no longer backpropagate through a single deep $T$-step composition to get a learning signal: the loss is a sum of local terms, so the product of Jacobians is replaced by a single local Jacobian per term. For example, $a_0$ appears only in the first term:

\[D_{a_0} \mathcal{L} = 2\bigl(F_\theta(s_0, a_0) - s_1\bigr)^\top D_{a_0}F_\theta(s_0, a_0).\]
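Here is a minimal NumPy sketch of the lifted objective, assuming a hypothetical linear toy model $F_\theta(s, a) = A s + B a$ with hand-written gradients (the function name is ours). The point to notice is structural: all $T$ model evaluations happen in one batched call, and the gradient for $a_t$ uses only the residual of term $t$. The state gradient, however, still contains a $D_s F_\theta$ factor (here `A`), which matters for the discussion that follows.

```python
import numpy as np

def lifted_loss_and_grads(A, B, states, actions):
    """Collocation loss sum_t ||F(s_t, a_t) - s_{t+1}||^2 for F(s, a) = A s + B a.

    states: shape (T+1, n), with states[0] = s_0 and states[-1] = g held fixed.
    actions: shape (T, m). All T model calls happen in one batched operation.
    """
    pred = states[:-1] @ A.T + actions @ B.T   # F(s_t, a_t) for every t at once
    resid = pred - states[1:]                  # residual r_t, shape (T, n)
    loss = float(np.sum(resid ** 2))
    grad_a = 2.0 * resid @ B                   # row t is 2 (D_a F)^T r_t: local to term t
    grad_s = np.zeros_like(states)
    grad_s[:-1] += 2.0 * resid @ A             # s_t as model input: uses D_s F = A
    grad_s[1:] -= 2.0 * resid                  # s_{t+1} as the target of term t
    grad_s[0] = 0.0                            # s_0 is fixed
    grad_s[-1] = 0.0                           # s_T = g is fixed
    return loss, grad_s, grad_a

# Usage with a random model and a random (infeasible) initial trajectory.
rng = np.random.default_rng(0)
A, B = rng.normal(size=(3, 3)), rng.normal(size=(3, 2))
states, actions = rng.normal(size=(6, 3)), rng.normal(size=(5, 2))
loss, grad_s, grad_a = lifted_loss_and_grads(A, B, states, actions)
```

Changing a late action, say `actions[4]`, leaves `grad_a[0]` unchanged, which is exactly the locality that the serial rollout objective lacks.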

Being able to optimize states directly also helps with exploration, as we can temporarily navigate through unphysical domains to find the optimal plan:

Collocation-based planning allows us to directly perturb states and explore midpoints more effectively.

However, lunch is never free. And indeed, especially for deep learning-based world models, there is a critical issue that makes the above optimization quite difficult in practice.

An issue for deep learning-based world models: sensitivity of state-input gradients

The tl;dr of this section: directly optimizing states through a deep learning-based $F_{\theta}$ is incredibly brittle, much like classifiers are vulnerable to adversarial examples. Even if you train your world model in a lower-dimensional state space, training leaves the model very sharp away from the data, whether at an unseen state itself or simply along a direction normal (orthogonal) to the data manifold.

Adversarial robustness and the “dimpled manifold” model

Adversarial robustness originally looked at classification models $f_\theta : \mathbb{R}^{w\times h \times c} \to \mathbb{R}^K$, and showed that by following the gradient of a particular logit $\nabla f_\theta^k$ from a base image $x$ (not of class $k$), you did not have to move far along $x' = x + \epsilon\nabla f_\theta^k$ to make $f_\theta$ classify $x'$ as $k$ (Szegedy et al., 2014; Goodfellow et al., 2015):

Depiction of the classic example from (Goodfellow et al., 2015).

Later work has painted a geometric picture of what’s going on: for data near a low-dimensional manifold $\mathcal{M}$, the training process controls behavior in tangential directions but does not regularize behavior in orthogonal directions, which leads to sensitive behavior (Stutz et al., 2019). Stated another way: $f_\theta$ has a reasonable Lipschitz constant when restricted to directions tangent to the data manifold $\mathcal{M}$, but can have very high Lipschitz constants in normal directions. In fact, it often benefits the model to be sharper in these normal directions, so that it can fit more complicated functions more precisely.


As a result, such adversarial examples are incredibly common even for a single given model. Further, this is not just a computer vision phenomenon; adversarial examples also appear in LLMs (Wallace et al., 2019) and in RL (Gleave et al., 2019).

While there are methods to train for more adversarially robust models, there is a known trade-off between model performance and adversarial robustness (Tsipras et al., 2019): especially in the presence of many weakly-correlated variables, the model must be sharper to achieve higher performance. Indeed, most modern training algorithms, whether in computer vision or LLMs, do not train adversarial robustness out. Thus, at least until deep learning sees a major regime change, this is a problem we’re stuck with.

Why is adversarial robustness an issue for world model planning?

Consider a single component of the dynamics loss we’re optimizing in the lifted state approach:

\[\min_{s_t, a_t, s_{t+1}} \|F_\theta(s_t, a_t) - s_{t+1}\|_2^2\]

Let’s further focus on just the base state:

\[\min_{s_t} \|F_\theta(s_t, a_t) - s_{t+1}\|_2^2.\]

Since world models are typically trained on state/action trajectories $(s_1, a_1, s_2, a_2, \ldots)$, the dimensionality of the state-data manifold for $F_{\theta}$ is bounded in terms of the action space:

\[\mathrm{dim}(\mathcal{M}_s) \le \mathrm{dim}(\mathcal{A}) + 1 + \mathrm{dim}(\mathcal{R}),\]

where $\mathcal{R}$ is some optional space of augmentations (e.g. translations/rotations). Thus, we can typically expect $\mathrm{dim}(\mathcal{M}_s)$ to be much lower than $\mathrm{dim}(\mathcal{S})$, and consequently it is very easy to find adversarial examples that hack any state into any other desired state.

As a result, the dynamics optimization

\[\sum_{t=0}^{T-1} \big\|F_\theta(s_t,a_t) - s_{t+1}\big\|_2^2\]

feels incredibly “sticky,” as the base points $s_t$ can easily trick $F_{\theta}$ into thinking it has already reached its local goal (see footnote 1).
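The stickiness can be reproduced in a hypothetical 2-D toy model (entirely our construction, not from the paper). Training states are assumed to lie on the line $s[1] = 0$, where the model simply advances $s[0]$ by the action; off that line the model is untrained, and we give it an arbitrary large gain $k$ along the normal direction. Gradient descent on the base state alone then “reaches” a target one unit away by moving the state only about $1/k$ off the manifold, instead of changing the action.

```python
import numpy as np

# Hypothetical 2-D toy model (our construction). Training states are assumed to
# lie on the line s[1] = 0, where the model just advances s[0] by the action.
# Off that line the model was never trained; we give it a large gain k there.
k = 100.0

def F(s, a):
    # The second output ignores s[1], so the off-manifold move leaves no trace.
    return np.array([s[0] + a + k * s[1], 0.0])

def state_jacobian(s, a):
    return np.array([[1.0, k], [0.0, 0.0]])   # D_s F: gain k along the normal s[1]

s_t, a_t = np.array([0.0, 0.0]), 0.0
target = np.array([1.0, 0.0])   # an honest plan would change a_t (or s_t[0]) by 1

s = s_t.copy()
lr = 0.25 / (1.0 + k ** 2)      # stable step size for this quadratic loss
for _ in range(200):
    r = F(s, a_t) - target
    s = s - lr * 2.0 * state_jacobian(s, a_t).T @ r   # gradient step on s_t only

print(np.sum((F(s, a_t) - target) ** 2))  # essentially zero
print(np.linalg.norm(s - s_t))            # about 0.01 (roughly 1/k), versus 1 for the honest fix
```

The dynamics residual vanishes, yet nothing physically meaningful happened: the state slid a tiny distance along a direction the model never saw in training.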



1. This adversarial robustness issue, while particularly bad for lifted-state approaches, is not unique to them. Even serial methods that optimize through the full rollout map $\mathcal{F}^T$ can wander into unseen states, where the input to $D_s F_{\theta}$ easily picks up a component along its sensitive normal directions. The chain-rule expansion of the action Jacobian is

\[D_s F_\theta(s_{T-1}, a_{T-1}) \cdots D_s F_\theta(s_1, a_1)\, D_{a_0}F_\theta(s_0, a_0).\]

If any factor in this product acts on a component normal to the data manifold, that sensitivity is carried through the rest of the product.


Our fix

This is where our new planner GRASP comes in. The main observation: while $D_s F_{\theta}$ is untrustworthy and adversarial, the action space is usually low-dimensional and exhaustively trained, so $D_a F_{\theta}$ is actually reasonable to optimize through and doesn’t suffer from the adversarial robustness issue!

The action input is usually lower-dimensional and densely trained (the model has seen every action direction), so action gradients are much better behaved.

At its core, GRASP is a first-order lifted-state (collocation-based) planner that depends on the world model only through its action Jacobians. We thus exploit the differentiability of learned world models $F_{\theta}$ without falling victim to the inherent sensitivity of the state Jacobians $D_s F_{\theta}$.

GRASP: Gradient RelAxed Stochastic Planner

As noted before, we start with the collocation planning objective, where we lift the states and relax dynamics into a penalty:

\[\min_{\mathbf{s},\mathbf{a}} \mathcal{L}(\mathbf{s}, \mathbf{a}) = \sum_{t=0}^{T-1} \big\|F_\theta(s_t,a_t) - s_{t+1}\big\|_2^2,
\quad \text{with } s_0 \text{ fixed and } s_T=g.\]

We then make two key additions.

Ingredient 1: Exploration by noising the state iterates

Even with a smoother objective, planning is nonconvex. We introduce exploration by injecting Gaussian noise into the virtual state updates during optimization.

A simple version:

\[s_t \leftarrow s_t - \eta_s \nabla_{s_t}\mathcal{L} + \sigma_{\text{state}}\, \xi, \qquad \xi\sim\mathcal{N}(0,I).\]

Actions are still updated by non-stochastic descent:

\[a_t \leftarrow a_t - \eta_a \nabla_{a_t}\mathcal{L}.\]

The state noise helps you “hop” between basins in the lifted space, while the actions remain guided by gradients. We found that noising the states specifically (as opposed to the actions) strikes a good balance between exploration and the ability to find sharper minima (see footnote 2).
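The two update rules can be sketched in NumPy as follows. Assumptions, all ours: a hypothetical toy model $F_\theta(s, a) = s + a$, the plain lifted dynamics loss, straight-line state initialization, and a linearly annealed noise scale $\sigma_{\text{state}}$ that reaches zero before the end of the budget. The step sizes and schedule are illustrative, not values from the paper. Only the states receive noise; actions follow deterministic gradient steps.

```python
import numpy as np

# Hypothetical toy world model F(s, a) = s + a (stands in for a learned network).
def lifted_grads(states, actions):
    resid = states[:-1] + actions - states[1:]   # F(s_t, a_t) - s_{t+1}
    grad_a = 2.0 * resid                         # D_a F = I
    grad_s = np.zeros_like(states)
    grad_s[:-1] += 2.0 * resid                   # s_t as model input (D_s F = I)
    grad_s[1:] -= 2.0 * resid                    # s_{t+1} as target
    return float(np.sum(resid ** 2)), grad_s, grad_a

def plan_noisy(s0, goal, T, iters=2000, eta_s=0.1, eta_a=0.1, sigma0=0.05, seed=0):
    rng = np.random.default_rng(seed)
    states = np.linspace(s0, goal, T + 1)        # straight-line initialization
    actions = np.zeros((T, s0.shape[0]))
    for i in range(iters):
        _, grad_s, grad_a = lifted_grads(states, actions)
        # Noise scale annealed linearly to zero at 80% of the budget (our choice).
        sigma = sigma0 * max(0.0, 1.0 - i / (0.8 * iters))
        xi = rng.standard_normal(states.shape)
        states = states - eta_s * grad_s + sigma * xi   # noisy state update
        actions = actions - eta_a * grad_a              # deterministic action update
        states[0], states[-1] = s0, goal                # s_0 fixed, s_T = g
    return states, actions

# Usage: plan from the origin to (3, -1) over 5 steps.
states, actions = plan_noisy(np.zeros(2), np.array([3.0, -1.0]), T=5)
print(lifted_grads(states, actions)[0])  # near zero once the noise is annealed away
```

Annealing the noise to zero is one simple way to let the final iterates settle into a minimum; other schedules are equally plausible.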


2. Because we only noise the states (and not the actions), the corresponding dynamics are not truly Langevin dynamics.


Ingredient 2: Reshape gradients: stop brittle state-input gradients, keep action gradients

As discussed, the fragile pathway is the gradient that flows into the state input of the world model, $D_s F_{\theta}$. The most straightforward way to avoid it is to stop gradients from flowing into the state input of $F_{\theta}$ altogether:

  • Let $\bar{s}_t$ be the same value as $s_t$, but with gradients stopped.

Define the stop-gradient dynamics loss:

\[\mathcal{L}_{\text{dyn}}^{\text{sg}}(\mathbf{s},\mathbf{a})
= \sum_{t=0}^{T-1} \big\|F_\theta(\bar{s}_t, a_t) - s_{t+1}\big\|_2^2.\]

On its own, this does not work. Each state $s_{t+1}$ is now only pulled toward the prediction from the previous step, and nothing forces the base states to chase the next ones. As a result, there are trivial minima in which the states simply stay at the origin and only the final action tries to reach the goal in one step.

Dense goal shaping

We can view the above issue as the goal’s signal being cut off entirely from previous states. One way to fix this is to simply add a dense goal term throughout prediction:

\[\mathcal{L}_{\text{goal}}^{\text{sg}}(\mathbf{s},\mathbf{a})
= \sum_{t=0}^{T-1} \big\|F_\theta(\bar{s}_t, a_t) - g\big\|_2^2.\]

In most settings this would over-bias the planner toward the greedy solution of heading straight for the goal, but here it is balanced by the stop-gradient dynamics loss, which biases the plan toward feasible dynamics. The final objective is then:

\[\mathcal{L}(\mathbf{s},\mathbf{a}) = \mathcal{L}_{\text{dyn}}^{\text{sg}}(\mathbf{s},\mathbf{a}) + \gamma \, \mathcal{L}_{\text{goal}}^{\text{sg}}(\mathbf{s},\mathbf{a}).\]

The result is a planning objective whose gradients never pass through the state input of the world model, i.e. it does not depend on the state Jacobians $D_s F_{\theta}$.
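In an autodiff framework, $\bar{s}_t$ would be produced by a stop-gradient operation (for example `jax.lax.stop_gradient` or `torch.Tensor.detach`). The NumPy sketch below instead writes the resulting gradients by hand for a hypothetical linear toy model $F_\theta(s, a) = A s + B a$, which makes the key property visible: neither gradient contains the state Jacobian $D_s F_\theta = A$. The function name and the value of $\gamma$ are our own illustrative choices.

```python
import numpy as np

def grasp_style_grads(A, B, states, actions, goal, gamma):
    """Gradients of L_dyn^sg + gamma * L_goal^sg for the toy F(s, a) = A s + B a.

    The model input s_t is treated as the constant s_bar_t (stop-gradient), so the
    state Jacobian D_s F = A never appears in the returned gradients.
    """
    pred = states[:-1] @ A.T + actions @ B.T    # F(s_bar_t, a_t) for all t in parallel
    r_dyn = pred - states[1:]                   # dynamics residuals
    r_goal = pred - goal                        # dense goal residuals
    loss = float(np.sum(r_dyn ** 2) + gamma * np.sum(r_goal ** 2))
    grad_a = 2.0 * (r_dyn + gamma * r_goal) @ B # only D_a F = B is used
    grad_s = np.zeros_like(states)
    grad_s[1:] = -2.0 * r_dyn                   # s_{t+1} only enters as a target
    grad_s[0] = 0.0                             # s_0 is fixed (no gradient anyway)
    grad_s[-1] = 0.0                            # s_T = g is fixed
    return loss, grad_s, grad_a

# Usage with a random model; the last lifted state doubles as the goal g.
rng = np.random.default_rng(0)
A, B = rng.normal(size=(3, 3)), rng.normal(size=(3, 2))
states, actions = rng.normal(size=(5, 3)), rng.normal(size=(4, 2))
loss, grad_s, grad_a = grasp_style_grads(A, B, states, actions, goal=states[-1], gamma=0.1)
```

Note that `grad_s` only contains terms in which a state acts as a target. This is why the dense goal term matters: without it, nothing pulls the earlier states toward the goal.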


Periodic “sync”: briefly return to true rollout gradients

The lifted stop-gradient objective is great for fast, guided exploration, but it’s still an approximation of the original serial rollout objective.

So every $K_{\text{sync}}$ iterations, GRASP does a short refinement phase:

  1. Roll out from $s_0$ using the current actions $\mathbf{a}$, and take a few small gradient steps on the original serial loss:

\[\mathbf{a} \leftarrow \mathbf{a} - \eta_{\text{sync}}\,\nabla_{\mathbf{a}}\,\|s_T(\mathbf{a})-g\|_2^2.\]

The lifted-state optimization still does the bulk of the work, while this refinement step helps keep states and actions grounded in real trajectories. This refinement step can of course be replaced with a serial planner of your choice (e.g. CEM); the core idea is to retain some of the benefit of the full-path synchronization of serial planners while mostly relying on lifted-state planning.
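A minimal NumPy sketch of one sync phase, under a hypothetical linear toy model $F_\theta(s, a) = A s + B a$: roll out from $s_0$, backpropagate the terminal loss through time, and take a few small action steps. The step size `eta_sync` and the number of steps are placeholder values, the function names are ours, and the sketch leaves the lifted states untouched; whether they are re-initialized from the rollout is left open here.

```python
import numpy as np

def rollout(A, B, s0, actions):
    """Serial rollout s_{t+1} = A s_t + B a_t; returns the list [s_0, ..., s_T]."""
    states = [s0]
    for a in actions:
        states.append(A @ states[-1] + B @ a)
    return states

def sync_refine(A, B, s0, goal, actions, eta_sync=0.01, steps=5):
    """A few gradient steps on the serial loss ||s_T(a) - g||^2 (BPTT)."""
    actions = actions.copy()                  # do not modify the caller's array
    for _ in range(steps):
        states = rollout(A, B, s0, actions)
        lam = 2.0 * (states[-1] - goal)       # adjoint dL/ds_T
        grads = np.zeros_like(actions)
        for t in reversed(range(len(actions))):
            grads[t] = B.T @ lam              # dL/da_t
            lam = A.T @ lam                   # propagate back one step through D_s F
        actions -= eta_sync * grads
    return actions

# Usage: refine actions handed over by a (hypothetical) lifted phase.
A, B = np.eye(2), np.eye(2)
s0, goal = np.zeros(2), np.array([1.0, 1.0])
a_lifted = np.full((10, 2), 0.05)
a_synced = sync_refine(A, B, s0, goal, a_lifted)
before = np.sum((rollout(A, B, s0, a_lifted)[-1] - goal) ** 2)
after = np.sum((rollout(A, B, s0, a_synced)[-1] - goal) ** 2)
print(before, after)  # the serial loss drops after the sync steps
```

In a full planner, a function like `sync_refine` would be called every $K_{\text{sync}}$ iterations of the lifted loop, and a serial planner such as CEM could take its place.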


How GRASP addresses long-range planning

Collocation-based planners offer a natural fix for long-horizon planning, but this optimization is quite difficult through modern world models due to adversarial robustness issues. GRASP proposes a simple solution for a smoother collocation-based planner, alongside stable stochasticity for exploration. As a result, longer-horizon planning not only succeeds more often, but also finds those successes faster:

Push-T demo: longer-horizon planning with GRASP.
Horizon | CEM           | GD             | LatCo          | GRASP
H=40    | 61.4% / 35.3s | 51.0% / 18.0s  | 15.0% / 598.0s | 59.0% / 8.5s
H=50    | 30.2% / 96.2s | 37.6% / 76.3s  | 4.2% / 1114.7s | 43.4% / 15.2s
H=60    | 7.2% / 83.1s  | 16.4% / 146.5s | 2.0% / 231.5s  | 26.2% / 49.1s
H=70    | 7.8% / 156.1s | 12.0% / 103.1s | 0.0% / n/a     | 16.0% / 79.9s
H=80    | 2.8% / 132.2s | 6.4% / 161.3s  | 0.0% / n/a     | 10.4% / 58.9s

Push-T results: success rate (%) / median time to success (s). Note that median time to success tends to rise with success rate; GRASP is nonetheless the fastest at every horizon while achieving the highest success rate at most horizons.


What’s next?

There is still plenty of work to be done for modern world model planners. We want to exploit the gradient structure of learned world models, and collocation (lifted-state optimization) is a natural approach for long-horizon planning, but it’s crucial to understand typical gradient structure here: smooth and informative action gradients and brittle state gradients. We view GRASP as an initial iteration for such planners.

Extension to diffusion-based world models (deeper latent timesteps can be viewed as smoothed versions of the world model itself), more sophisticated optimizers and noising strategies, and integrating GRASP into either a closed-loop system or RL policy learning for adaptive long-horizon planning are all natural and interesting next steps.

I do genuinely think it’s an exciting time to be working on world model planners. It’s a funny sweet spot where the background literature (planning and control overall) is incredibly mature and well-developed, but the current setting (pure planning optimization over modern, large-scale world models) is still heavily underexplored. But, once we figure out all the right ideas, world model planners will likely become as commonplace as RL.


For more details, read the full paper or visit the project website.


Citation

@article{psenka2026grasp,
  title={Parallel Stochastic Gradient-Based Planning for World Models},
  author={Michael Psenka and Michael Rabbat and Aditi Krishnapriyan and Yann LeCun and Amir Bar},
  year={2026},
  eprint={2602.00475},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2602.00475}
}