Motion forecasting is one of the hardest sub-problems in autonomous driving. Given a few seconds of history for every agent in the scene — other cars, cyclists, pedestrians — your model must predict a distribution over their future trajectories. Get it wrong and your planner either over-brakes (annoying) or mis-allocates collision-avoidance budget (dangerous).

At Rivian I led the implementation of Wayformer for our offboard planning stack. This post covers the architecture, the key design choices that make it work, and the practical lessons from shipping it at scale.


The core problem

A standard scene has $N$ agents, each with a $T$-step history. Naive joint attention treats every (agent, timestep) pair as a token, so all $N \cdot T$ tokens attend to one another at a cost of:

$$\mathcal{O}(N^2 \cdot T^2)$$

For a dense urban scene with $N = 64$ agents and $T = 50$ timesteps, that is $(64 \cdot 50)^2 \approx 10.2$M attention pairs per attention layer, before you have even touched the HD map. Wayformer's answer is factorized attention.

Architecture overview

Wayformer decomposes the joint attention into two orthogonal axes:

  1. Temporal self-attention — each agent attends to its own history tokens
  2. Social self-attention — at each timestep, agents attend to each other

The input is a sequence of per-agent state vectors $\mathbf{x}_i^t \in \mathbb{R}^d$. After positional encoding these are arranged into a 2-D grid:

$$\mathbf{X} \in \mathbb{R}^{N \times T \times d}$$

Temporal encoder

Temporal attention is applied row-wise (per agent, across time):

$$\mathbf{H}^{(l)}_{\text{temp}} = \text{softmax}\!\left(\frac{Q_t K_t^\top}{\sqrt{d}}\right) V_t, \quad \text{applied for each } i \in [N]$$

Social encoder

Social attention is applied column-wise (across agents, at each timestep):

$$\mathbf{H}^{(l)}_{\text{soc}} = \text{softmax}\!\left(\frac{Q_s K_s^\top}{\sqrt{d}}\right) V_s, \quad \text{applied for each } t \in [T]$$
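
In code, the only trick is reshaping the $N \times T$ grid so that the sequence dimension is time for the temporal pass and agents for the social pass. The following is a minimal PyTorch sketch of one temporal-then-social layer, assuming a (B, N, T, d) input and using stock `nn.TransformerEncoderLayer` blocks for each axis. The class name `FactorizedLayer` and the sizes are hypothetical, and positional encodings and masking are omitted for brevity.

```python
import torch
import torch.nn as nn


class FactorizedLayer(nn.Module):
    """Hypothetical sketch: temporal self-attention, then social self-attention."""

    def __init__(self, d_model: int = 64, n_heads: int = 4):
        super().__init__()
        # batch_first=True: each block sees (batch, sequence, feature)
        self.temporal = nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)
        self.social = nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, t, d = x.shape
        # Temporal pass: every agent is one sequence of T tokens -> (B*N, T, d)
        x = self.temporal(x.reshape(b * n, t, d)).reshape(b, n, t, d)
        # Social pass: every timestep is one sequence of N tokens -> (B*T, N, d)
        x = x.transpose(1, 2).reshape(b * t, n, d)
        x = self.social(x).reshape(b, t, n, d).transpose(1, 2)
        return x  # back to (B, N, T, d)


layer = FactorizedLayer()
x = torch.randn(2, 8, 10, 64)  # B=2 scenes, N=8 agents, T=10 steps, d=64
print(layer(x).shape)  # torch.Size([2, 8, 10, 64])
```

Because neither pass ever builds an $(N \cdot T) \times (N \cdot T)$ score matrix, the largest attention matrices are $T \times T$ and $N \times N$.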

Alternating these two reduces the effective cost to:

$$\mathcal{O}\!\left(N \cdot T^2 + T \cdot N^2\right)$$

For our typical scene sizes this is a ~15× reduction in attention FLOPs.
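
As a quick check, here are the counts for the $N = 64$, $T = 50$ scene above in plain Python. This sketch counts query-key pairs only, for one temporal plus one social pass against one joint pass. A FLOP ratio will differ with scene sizes and with what you include, because the Q/K/V projections and feed-forward layers scale linearly with the number of tokens and are unchanged by factorization.

```python
def attention_pairs(n_agents: int, n_steps: int) -> tuple[int, int]:
    """Query-key pairs for joint attention vs. one temporal + one social pass."""
    tokens = n_agents * n_steps
    joint = tokens ** 2                                                # O(N^2 T^2)
    factorized = n_agents * n_steps ** 2 + n_steps * n_agents ** 2    # O(N T^2 + T N^2)
    return joint, factorized


joint, fact = attention_pairs(64, 50)
print(joint, fact, round(joint / fact, 1))  # 10240000 364800 28.1
```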

Multimodal decoder

The decoder produces $K$ future trajectories (modes) per agent, each $T_f$ steps long, via learned mode queries. Each mode query $\mathbf{q}_k$ cross-attends to the fused encoder output $\mathbf{Z}$:

$$\hat{\mathbf{y}}_k = \text{MLP}\!\left(\text{CrossAttn}(\mathbf{q}_k, \mathbf{Z})\right) \in \mathbb{R}^{T_f \times 2}$$
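
The decoder equation maps onto a few standard PyTorch modules. The following is a minimal sketch for a single target agent, assuming the encoder output $\mathbf{Z}$ arrives as a (B, M, d) tensor of scene tokens. The class name, the sizes, and the linear `score_head` that produces the logits behind the mode probabilities are illustrative assumptions. $T_f = 60$ corresponds to a 6 s horizon at 10 Hz.

```python
import torch
import torch.nn as nn


class ModeDecoder(nn.Module):
    """Hypothetical sketch: K learned mode queries cross-attend to encoder output Z."""

    def __init__(self, d_model: int = 64, n_modes: int = 6, t_future: int = 60, n_heads: int = 4):
        super().__init__()
        self.t_future = t_future
        self.mode_queries = nn.Parameter(torch.randn(n_modes, d_model))  # q_1 .. q_K
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.traj_mlp = nn.Sequential(
            nn.Linear(d_model, d_model), nn.ReLU(), nn.Linear(d_model, t_future * 2)
        )
        self.score_head = nn.Linear(d_model, 1)  # one logit per mode (assumed design)

    def forward(self, z: torch.Tensor):
        b = z.size(0)
        q = self.mode_queries.unsqueeze(0).expand(b, -1, -1)       # (B, K, d)
        h, _ = self.cross_attn(q, z, z)                            # CrossAttn(q_k, Z): (B, K, d)
        trajs = self.traj_mlp(h).view(b, -1, self.t_future, 2)    # (B, K, T_f, 2)
        logits = self.score_head(h).squeeze(-1)                    # (B, K)
        return trajs, logits


dec = ModeDecoder()
trajs, logits = dec(torch.randn(2, 100, 64))  # 2 scenes, M=100 encoder tokens each
print(trajs.shape, logits.shape)  # torch.Size([2, 6, 60, 2]) torch.Size([2, 6])
```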

Loss function

Training uses the winner-takes-all strategy: only the mode closest to ground truth contributes to the regression loss.

$$\mathcal{L} = \underbrace{-\log p_{k^*}}_{\text{classification}} + \underbrace{\sum_{t=1}^{T_f} \left\| \hat{\mathbf{y}}_{k^*}^t - \mathbf{y}^t \right\|_2}_{\text{regression on best mode}}$$

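where $p_k$ is the predicted probability of mode $k$ and $k^* = \arg\min_k \text{FDE}(\hat{\mathbf{y}}_k, \mathbf{y})$ is the mode whose endpoint lands closest to the ground truth.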
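
The loss translates almost term by term into PyTorch. This sketch assumes predictions of shape (B, K, T_f, 2), unnormalized mode logits of shape (B, K), and ground truth of shape (B, T_f, 2). The function name `wta_loss` is hypothetical. Cross-entropy on the logits with target $k^*$ is exactly $-\log p_{k^*}$ under a softmax over modes.

```python
import torch
import torch.nn.functional as F


def wta_loss(trajs: torch.Tensor, logits: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
    """Winner-takes-all: -log p_{k*} + sum_t ||y_hat_{k*}^t - y^t||_2, batch-averaged.

    trajs: (B, K, T_f, 2); logits: (B, K); gt: (B, T_f, 2).
    """
    # Endpoint error of every mode, used to pick k* = argmin_k FDE
    fde = torch.linalg.norm(trajs[:, :, -1] - gt[:, None, -1], dim=-1)  # (B, K)
    k_star = fde.argmin(dim=1)                                           # (B,)
    # Classification term: cross-entropy equals -log softmax(logits)[k*]
    cls = F.cross_entropy(logits, k_star)
    # Regression term on the winning mode only; losing modes get no regression gradient
    best = trajs[torch.arange(trajs.size(0)), k_star]                    # (B, T_f, 2)
    reg = torch.linalg.norm(best - gt, dim=-1).sum(dim=-1).mean()
    return cls + reg


gt = torch.ones(1, 4, 2)
trajs = torch.zeros(1, 3, 4, 2)
trajs[0, 1] = gt[0]  # mode 1 matches the ground truth exactly
loss = wta_loss(trajs, torch.zeros(1, 3), gt)
print(loss.item())  # log(3) ≈ 1.0986: uniform mode scores, zero regression error
```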

If every mode is instead regressed toward the ground truth, all modes collapse to the mean trajectory: the classic mode-averaging failure of multimodal trajectory models.

Key implementation lessons

1. Agent masking is non-trivial

Real scenes have variable $N$. Padding + masking sounds simple, but you need to propagate masks correctly through both temporal and social attention.

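```python
import math
import torch


def masked_attention(q, k, v, mask=None):
    # q, k, v: (B, H, N, D) per-head tensors; mask: (B, N) bool, True = valid agent
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))  # (B, H, N, N)
    if mask is not None:
        pad_mask = ~mask[:, None, None, :]  # (B, 1, 1, N): hide padded keys for every head and query
        # Most negative finite value for this dtype; a literal -1e9 overflows in fp16
        scores = scores.masked_fill(pad_mask, torch.finfo(scores.dtype).min)
    return torch.softmax(scores, dim=-1) @ v
```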
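
The function above handles one axis. In a factorized encoder, the natural mask is per (agent, timestep), and it has to be reshaped the same way as the features for each pass. This sketch assumes a (B, N, T) boolean validity mask and emits masks in PyTorch's `key_padding_mask` convention, where True marks positions to ignore (the opposite of the True = valid convention in `masked_attention`). The helper name `axis_masks` is hypothetical.

```python
import torch


def axis_masks(valid: torch.Tensor):
    """Split a (B, N, T) validity mask (True = observed) into key-padding masks
    (True = ignore) for the temporal and social passes of a factorized encoder."""
    b, n, t = valid.shape
    temporal_valid = valid.reshape(b * n, t)                # one row per agent
    social_valid = valid.transpose(1, 2).reshape(b * t, n)  # one row per timestep
    masks = []
    for v in (temporal_valid, social_valid):
        pad = ~v
        # A row with no valid keys at all would make a -inf-masked softmax return NaN;
        # let such rows attend everywhere and discard their outputs downstream instead.
        empty = pad.all(dim=-1, keepdim=True)
        masks.append(pad & ~empty)
    return masks[0], masks[1]


valid = torch.tensor([[[True, True, False], [False, False, False]]])  # B=1, N=2, T=3
temporal_mask, social_mask = axis_masks(valid)
```

Masking keys alone is not enough: outputs at padded query positions are still computed. Those outputs, and the outputs for rows that had no valid keys, should be zeroed or excluded from the loss.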

2. Coordinate frame matters

We normalize all agent states into an agent-centric frame at the last observed timestep: position becomes $(0, 0)$ and heading becomes $0$. This dramatically improves generalization, because the model no longer has to relearn the same maneuver at every map location and orientation.
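
The transform is a translation followed by a rotation by the negative reference heading. This is a minimal sketch, assuming (N, T, 2) world-frame positions and (N, T) yaw angles in radians. The function name `to_agent_frame` is hypothetical.

```python
import math
import torch


def to_agent_frame(xy: torch.Tensor, heading: torch.Tensor):
    """Re-express each agent's history in its own frame at the last observed step.

    xy: (N, T, 2) world positions; heading: (N, T) yaw in radians.
    Returns local positions (last step at (0, 0)) and headings relative to the last step.
    """
    origin = xy[:, -1:, :]   # (N, 1, 2) last observed position
    yaw = heading[:, -1:]    # (N, 1) last observed heading
    d = xy - origin          # translate so the reference step sits at the origin
    c, s = torch.cos(yaw), torch.sin(yaw)
    # Rotate by -yaw so the agent faces +x at the reference step
    local = torch.stack((c * d[..., 0] + s * d[..., 1],
                         -s * d[..., 0] + c * d[..., 1]), dim=-1)
    # Relative heading, wrapped to [-pi, pi)
    rel = torch.remainder(heading - yaw + math.pi, 2 * math.pi) - math.pi
    return local, rel


xy = torch.tensor([[[0.0, 0.0], [0.0, 1.0], [0.0, 2.0]]])  # one agent driving north
heading = torch.full((1, 3), math.pi / 2)
local, rel = to_agent_frame(xy, heading)
# local ≈ [[-2, 0], [-1, 0], [0, 0]]: the history trails behind the agent along -x
```

The future targets must go through the same transform, using the same origin and yaw, and predictions are mapped back to world coordinates with the inverse rotation and translation.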

3. Gradient instability in long sequences

With $T = 50$ input steps, gradients through temporal attention can explode for early timesteps. Two mitigations worked well:

  • Gradient clipping to a global L2 norm of 1.0
  • Pre-LN (pre-layer-normalization) rather than post-LN in every transformer block
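
Both mitigations are available in stock PyTorch. This sketch uses `nn.TransformerEncoderLayer` with `norm_first=True` (PyTorch 1.10 or later) for Pre-LN blocks and `torch.nn.utils.clip_grad_norm_` for clipping. The model sizes, optimizer, head, and dummy loss are placeholders for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, n_steps = 64, 50
# norm_first=True gives Pre-LN blocks: x + Attn(LN(x)), then x + FFN(LN(x))
layer = nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True, norm_first=True)
# Pre-LN stacks usually end with one extra LayerNorm on the final output
encoder = nn.TransformerEncoder(layer, num_layers=4, norm=nn.LayerNorm(d_model))
head = nn.Linear(d_model, 2)  # placeholder prediction head
params = list(encoder.parameters()) + list(head.parameters())
opt = torch.optim.AdamW(params, lr=1e-4)

x = torch.randn(8, n_steps, d_model)  # 8 agents, T = 50 history tokens each
target = torch.randn(8, 2)            # dummy regression target
loss = nn.functional.mse_loss(head(encoder(x)[:, -1]), target)
opt.zero_grad()
loss.backward()
# Rescale all gradients in place so their combined L2 norm is at most 1.0
pre_clip_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
opt.step()
```

In a Pre-LN block the residual path bypasses every LayerNorm, so gradients reach early layers without passing through a normalization at each step. That is the usual explanation for why it trains more stably than Post-LN.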

Argoverse 2 results

| Model | minADE₆ (m) ↓ | minFDE₆ (m) ↓ | MR₆ ↓ |
|---|---|---|---|
| LSTM baseline | 1.14 | 2.62 | 0.47 |
| MTR | 0.60 | 1.23 | 0.13 |
| Wayformer | 0.58 | 1.16 | 0.12 |
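
The three metrics can be computed from the same tensors the decoder produces. This sketch assumes an Argoverse-style convention in which the "best" mode is the one with the lowest endpoint error (the $k^*$ from the loss), minADE is that mode's average error, and a scenario counts as a miss when its best endpoint error exceeds `miss_threshold`. The 2.0 m default is an assumption. Check the official `av2` evaluation code before comparing against leaderboard numbers. The function name is hypothetical.

```python
import torch


def forecast_metrics(trajs: torch.Tensor, gt: torch.Tensor, miss_threshold: float = 2.0):
    """minADE_K, minFDE_K and miss rate. trajs: (B, K, T_f, 2) in meters; gt: (B, T_f, 2)."""
    dist = torch.linalg.norm(trajs - gt[:, None], dim=-1)  # (B, K, T_f) per-step error
    min_fde, k_star = dist[..., -1].min(dim=1)             # best mode chosen by endpoint
    ade_best = dist.mean(dim=-1)[torch.arange(trajs.size(0)), k_star]
    return {
        "minADE": ade_best.mean().item(),
        "minFDE": min_fde.mean().item(),
        "MR": (min_fde > miss_threshold).float().mean().item(),
    }


gt = torch.zeros(2, 2, 2)  # 2 scenarios, T_f = 2 steps, ground truth at the origin
trajs = torch.tensor([
    [[[1.0, 0.0], [1.0, 0.0]], [[0.0, 0.0], [3.0, 0.0]]],  # scenario 0: modes 0 and 1
    [[[0.0, 0.0], [0.0, 4.0]], [[0.0, 3.0], [0.0, 3.0]]],  # scenario 1: modes 0 and 1
])
print(forecast_metrics(trajs, gt))  # {'minADE': 2.0, 'minFDE': 2.0, 'MR': 0.5}
```

In scenario 1 the endpoint-selected mode has ADE 3.0 even though the other mode's ADE is 2.0, so taking the minimum ADE over modes would report 1.5 instead of 2.0. The selection convention matters when comparing numbers across papers.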

Production considerations

A research benchmark number and a production model are different things. These are some of the gaps we had to close at Rivian:

  • Latency: the reference implementation runs in ~80 ms on a V100, and we needed <20 ms. We got there with TorchScript export, by reducing $K$ from 64 to 16 modes, and with INT8 quantization of the scene encoder.
  • Calibration: softmax mode probabilities are poorly calibrated out of the box. Temperature scaling on a held-out validation set fixed this.
  • Long-horizon stability: beyond 3 s, prediction variance grows quickly, so our planner's collision metrics weight the short-horizon portion of each prediction more heavily.
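
Temperature scaling fits a single scalar $T > 0$ on held-out data and divides every mode logit by it before the softmax. This changes the confidence of each mode but not their ranking. This sketch assumes the held-out labels are the winning mode $k^*$ from the loss above and optimizes $\log T$ with Adam for simplicity. Any one-dimensional optimizer works. The function name and step count are hypothetical.

```python
import torch
import torch.nn.functional as F


def fit_temperature(logits: torch.Tensor, labels: torch.Tensor, steps: int = 300) -> float:
    """Fit one temperature T > 0 minimizing the NLL of softmax(logits / T).

    logits: (B, K) held-out mode scores; labels: (B,) index of the winning mode k*.
    """
    log_t = torch.zeros(1, requires_grad=True)  # optimize log T so T stays positive
    opt = torch.optim.Adam([log_t], lr=0.05)
    for _ in range(steps):
        opt.zero_grad()
        loss = F.cross_entropy(logits / log_t.exp(), labels)
        loss.backward()
        opt.step()
    return log_t.exp().item()


# Synthetic check: labels are drawn from softmax(base), but the model reports 3 * base,
# so it is over-confident by a factor of 3.
torch.manual_seed(0)
base = torch.randn(4000, 6)
labels = torch.multinomial(F.softmax(base, dim=-1), 1).squeeze(1)
temp = fit_temperature(3.0 * base, labels)
print(round(temp, 2))  # should come out close to 3
```

At inference, the fitted temperature divides the mode logits before the softmax. The top-ranked mode stays the same, and only the probabilities handed to the planner change.
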

Key takeaway: Wayformer's factorized attention is the right abstraction for AV forecasting. It respects the 2-D structure of the problem (time × agents) without the $\mathcal{O}(N^2 T^2)$ cost of joint attention over all $N \cdot T$ tokens. The architecture is clean enough that you can understand and modify every component, which matters in production, where you constantly need to add new agent types or input modalities.

References

  • Nayakanti et al., Wayformer: Motion Forecasting via Simple & Efficient Attention Networks, ICRA 2023.
  • github.com/Manojbhat09/Trajformer — earlier transformer-based forecasting work at CMU.