[BugFix] PPOs with composite distribution (#2791)
Co-authored-by: Louis Faury <louis.faury@helsing.ai>
(cherry picked from commit edfa25d)
louisfaury authored and vmoens committed Feb 17, 2025
1 parent 2ebcb2e commit 882dc79
Showing 1 changed file with 13 additions and 24 deletions.
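Context for the change below: when a policy uses tensordict's CompositeDistribution with log-prob aggregation disabled, dist.log_prob(...) returns a TensorDict of per-head log-probs rather than a single tensor. This commit moves the reduction of that structure into _log_weight (via _sum_td_features), so the individual PPO forward passes no longer each need their own aggregation branch. A minimal sketch of the kind of reduction involved, assuming flat keys and a hypothetical sum_leaves helper standing in for torchrl's _sum_td_features:

```python
import torch
from tensordict import TensorDict

# Per-head log-probs, as a composite distribution returns them when
# tensordict.nn.set_composite_lp_aggregate(False) is in effect.
log_prob = TensorDict(
    {"head0_log_prob": torch.randn(32), "head1_log_prob": torch.randn(32)},
    batch_size=[32],
)

def sum_leaves(td: TensorDict) -> torch.Tensor:
    # Independent heads factorize, so their log-probs add up.
    return sum(td.values(include_nested=True, leaves_only=True))

total_log_prob = sum_leaves(log_prob)  # torch.Size([32])
```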
torchrl/objectives/ppo.py
@@ -527,12 +527,9 @@ def _log_weight(
             self.actor_network
         ) if self.functional else contextlib.nullcontext():
             dist = self.actor_network.get_dist(tensordict)
-        if isinstance(dist, CompositeDistribution):
-            is_composite = True
-        else:
-            is_composite = False
-
         # current log_prob of actions
+        is_composite = isinstance(dist, CompositeDistribution)
+
         if is_composite:
             action = tensordict.select(
                 *(
@@ -562,25 +559,26 @@ def _log_weight(
         log_prob = dist.log_prob(action)
         if is_composite:
             with set_composite_lp_aggregate(False):
-                if log_prob.batch_size != adv_shape:
-                    log_prob.batch_size = adv_shape
                 if not is_tensor_collection(prev_log_prob):
-                    # this isn't great, in general multihead actions should have a composite log-prob too
+                    # this isn't great: in general, multi-head actions should have a composite log-prob too
                     warnings.warn(
                         "You are using a composite distribution, yet your log-probability is a tensor. "
                         "Make sure you have called tensordict.nn.set_composite_lp_aggregate(False).set() at "
                         "the beginning of your script to get a proper composite log-prob.",
                         category=UserWarning,
                     )
+                if log_prob.batch_size != adv_shape:
+                    log_prob.batch_size = adv_shape
-        if (
-            is_composite
-            and not is_tensor_collection(prev_log_prob)
-            and is_tensor_collection(log_prob)
-        ):
-            log_prob = _sum_td_features(log_prob)
-            log_prob.view_as(prev_log_prob)
+
+        if is_tensor_collection(log_prob):
+            log_prob = _sum_td_features(log_prob)
+            log_prob.view_as(prev_log_prob)
+
         log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
+        if is_tensor_collection(log_weight):
+            log_weight = _sum_td_features(log_weight)
+            log_weight = log_weight.view(adv_shape).unsqueeze(-1)
 
         kl_approx = (prev_log_prob - log_prob).unsqueeze(-1)
         if is_tensor_collection(kl_approx):
             kl_approx = _sum_td_features(kl_approx)
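This hunk is the heart of the fix: the per-head log-weight is aggregated and reshaped inside _log_weight itself, instead of leaving a TensorDict for each caller to reduce. A rough sketch of the new path under assumed shapes (two hypothetical heads; the plain leaf sum stands in for _sum_td_features):

```python
import torch
from tensordict import TensorDict

adv_shape = torch.Size([32])
log_prob = TensorDict({"head0": torch.randn(32), "head1": torch.randn(32)}, [32])
prev_log_prob = TensorDict({"head0": torch.randn(32), "head1": torch.randn(32)}, [32])

log_weight = log_prob - prev_log_prob  # still a TensorDict, one entry per head
# Aggregate the heads into a single tensor, then match the advantage's batch
# shape plus a trailing singleton dim for broadcasting:
log_weight = sum(log_weight.values(include_nested=True, leaves_only=True))
log_weight = log_weight.view(adv_shape).unsqueeze(-1)  # -> torch.Size([32, 1])
```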
@@ -691,9 +689,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         log_weight, dist, kl_approx = self._log_weight(
             tensordict, adv_shape=advantage.shape[:-1]
         )
-        if is_tensor_collection(log_weight):
-            log_weight = _sum_td_features(log_weight)
-            log_weight = log_weight.view(advantage.shape)
         neg_loss = log_weight.exp() * advantage
         td_out = TensorDict({"loss_objective": -neg_loss})
         td_out.set("kl_approx", kl_approx.detach().mean())  # for logging
@@ -987,8 +982,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         # to different, unrelated trajectories, which is not standard. Still, it can give an idea of the weights'
         # dispersion.
         lw = log_weight.squeeze()
-        if not isinstance(lw, torch.Tensor):
-            lw = _sum_td_features(lw)
         ess = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp()
         batch = log_weight.shape[0]
 
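For reference, the ess line computes the effective sample size of the importance weights w_i = exp(lw_i), ESS = (sum_i w_i)^2 / sum_i w_i^2, in log-space for numerical stability: 2 * lw.logsumexp(0) is log((sum w)^2) and (2 * lw).logsumexp(0) is log(sum w^2). A small numerical check of that identity:

```python
import torch

lw = torch.randn(1024)  # per-sample log importance weights
w = lw.exp()

ess_direct = w.sum().pow(2) / w.pow(2).sum()
ess_logspace = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp()
assert torch.allclose(ess_direct, ess_logspace, rtol=1e-4)
```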
@@ -1000,8 +993,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         gain2 = ratio * advantage
 
         gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
-        if is_tensor_collection(gain):
-            gain = _sum_td_features(gain)
         td_out = TensorDict({"loss_objective": -gain})
         td_out.set("clip_fraction", clip_fraction)
         td_out.set("kl_approx", kl_approx.detach().mean())  # for logging
@@ -1291,8 +1282,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             tensordict_copy, adv_shape=advantage.shape[:-1]
         )
         neg_loss = log_weight.exp() * advantage
-        if is_tensor_collection(neg_loss):
-            neg_loss = _sum_td_features(neg_loss)
 
         with self.actor_network_params.to_module(
             self.actor_network
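In the KL-penalized variant, neg_loss is likewise a plain tensor now, and the context manager that follows rebuilds the actor distribution for the penalty term. A heavily hedged sketch of how the pieces combine; beta and the kl tensor are assumptions standing in for the module's internals:

```python
import torch

advantage = torch.randn(32, 1)
log_weight = torch.randn(32, 1)  # plain tensor after this fix
kl = torch.rand(32, 1)           # assumed KL(previous dist, current dist)
beta = 1.0                       # assumed penalty coefficient

neg_loss = log_weight.exp() * advantage - beta * kl
loss_objective = -neg_loss.mean()
```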
