| | |
| | |
| | |
| | |
| | @@ -989,6 +989,13 @@ class PI05Policy(PreTrainedPolicy): |
| | if remap_count > 0: |
| | print(f"Remapped {remap_count} state dict keys") |
| | |
| | # Load the remapped state dict into the model |
| | missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict) |
| | + |
| | + # --- FIX: share lm_head.weight into embed_tokens (weight tying) when ckpt omits embed_tokens --- |
| | + if any("embed_tokens.weight" in k for k in missing_keys): |
| | + with torch.no_grad(): |
| | + embed = model.model.paligemma_with_expert.paligemma.model.language_model.embed_tokens |
| | + lm_head = model.model.paligemma_with_expert.paligemma.lm_head |
| | + embed.weight = lm_head.weight |
| | |
| | return model |
| |
|