| |
|
"""
patch_sigma_env.py

Idempotent patcher for Sigma VLA experiments.

Patch goals:
  1) LeRobot PI05Policy (modeling_pi05.py):
     1.1 If ckpt omits embed_tokens.weight, tie embed_tokens.weight to lm_head.weight
         *after* load_state_dict runs.
     1.2 Ensure torch is imported if target file lacks it.
     1.3 Downgrade the "incorrect transformer version" hard guard
         (ValueError) to a WARNING so new GPU environments don't crash.
         IMPORTANT: preserve indentation and patch only the intended guard.

  2) LeRobot policies __init__ (lerobot/policies/__init__.py):
     2.1 Make ONLY Groot/Diffusers-related imports optional (wrapped in try/except),
         leaving all other exports untouched.
         This prevents errors like: No module named 'triton.ops'
         or diffusers/peft chain issues on fresh GPUs.

  3) eval_sigma_vla_rollout.py (your /workspace eval script):
     3.1 Force strict=False for PI05Policy.from_pretrained calls:
         - strict=True -> strict=False
         - if a PI05Policy load call has no strict arg, add strict=False
     3.2 Ensure randomized subset evaluation is possible:
         - add --shuffle arg if missing
         - change DataLoader shuffle=False -> shuffle=getattr(args,"shuffle",False)

Safe to run multiple times; no-op if already patched.
"""
|
| |
|
import ast
import os
import pathlib
import re
import sys
from typing import Optional, Tuple, List
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def _read_text(p: pathlib.Path) -> str:
    """Return the full contents of ``p`` decoded as UTF-8."""
    return p.read_text(encoding="utf-8")
|
| |
|
def _write_text(p: pathlib.Path, s: str) -> None:
    """Overwrite ``p`` with the text ``s`` encoded as UTF-8."""
    p.write_text(s, encoding="utf-8")
|
| |
|
def _search_file(
    roots: List[os.PathLike],
    filename: str,
    must_contain: Optional[str] = None
) -> Optional[pathlib.Path]:
    """Return the first file matching ``filename`` under any of ``roots``.

    Args:
        roots: Directories to search recursively, in priority order. Entries
            that are not existing directories are skipped.
        filename: Pattern handed to ``Path.rglob`` (a bare file name, or a
            relative path pattern).
        must_contain: Optional substring that the candidate's path must
            contain. It is compared against the forward-slash form of the
            path so that filters such as ``"/pi05/"`` do not depend on the
            platform's path separator.

    Returns:
        The first matching file, or ``None`` when nothing matches.
    """
    for root in roots:
        root = pathlib.Path(root)
        if not root.is_dir():
            continue
        for candidate in root.rglob(filename):
            # A directory that happens to carry the requested name is not a hit.
            if not candidate.is_file():
                continue
            if must_contain and must_contain not in candidate.as_posix():
                continue
            return candidate
    return None
|
| |
|
def _default_roots():
    """Return the directories searched for LeRobot sources, in priority order.

    The two workspace checkouts come first, followed by the running
    interpreter's ``<sys.prefix>/lib/pythonX.Y/site-packages`` directory.
    """
    version_dir = f"python{sys.version_info.major}.{sys.version_info.minor}"
    return [
        "/workspace/lerobot/src",
        "/workspace/lerobot",
        pathlib.Path(sys.prefix) / "lib" / version_dir / "site-packages",
    ]
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def find_pi05_file() -> pathlib.Path:
    """Locate ``modeling_pi05.py``.

    The ``PI05_FILE`` environment variable wins when it names an existing
    path; otherwise the default roots are searched for a file of that name
    whose path contains ``/pi05/``.

    Raises:
        FileNotFoundError: If no candidate is found.
    """
    override = os.getenv("PI05_FILE")
    if override:
        candidate = pathlib.Path(override)
        if candidate.exists():
            return candidate

    found = _search_file(_default_roots(), "modeling_pi05.py", must_contain="/pi05/")
    if found and found.exists():
        return found

    raise FileNotFoundError("modeling_pi05.py not found. Set PI05_FILE env var to its path.")
|
| |
|
| |
|
def ensure_torch_import(s: str) -> str:
    """Return ``s`` with a module-level ``import torch`` guaranteed.

    The source is returned unchanged when a top-level ``import`` statement
    already binds the name ``torch`` (``import torch``, ``import torch.nn``,
    ``import x as torch``). Statements such as ``from torch import nn`` or
    ``import torch.nn as nn`` do not bind ``torch`` and therefore do not count.

    Otherwise the import line is inserted:
      * directly after the module docstring and any ``from __future__``
        imports (those must stay first in the file), or
      * when neither exists, directly before the first statement, which keeps
        a shebang, encoding cookie and license comments at the top, or
      * at the end when the file holds no statements at all.

    If the source cannot be parsed, a conservative text-based fallback is
    used: look for a column-0 ``import torch`` line and otherwise insert after
    an optional shebang.
    """
    patch_line = "import torch # PATCH: required for embed/lm_head tying\n"
    lines = s.splitlines(True)

    try:
        tree = ast.parse(s)
    except (SyntaxError, ValueError):
        tree = None

    if tree is None:
        # Fallback for unparseable input: plain text heuristics only.
        if re.search(r"(?m)^import\s+torch\s*(?:#.*)?$", s):
            return s
        insert_idx = 1 if lines and lines[0].startswith("#!") else 0
    else:
        for node in tree.body:
            if isinstance(node, ast.Import):
                for alias in node.names:
                    # ``import a.b`` binds ``a``; ``import a.b as c`` binds ``c``.
                    bound_name = alias.asname or alias.name.split(".")[0]
                    if bound_name == "torch":
                        return s

        header_end = 0  # 1-based line number where docstring/__future__ block ends
        insert_idx = None
        for position, node in enumerate(tree.body):
            is_docstring = (
                position == 0
                and isinstance(node, ast.Expr)
                and isinstance(node.value, ast.Constant)
                and isinstance(node.value.value, str)
            )
            is_future = (
                isinstance(node, ast.ImportFrom)
                and node.module == "__future__"
                and node.level == 0
            )
            if is_docstring or is_future:
                header_end = node.end_lineno
                continue
            if header_end == 0:
                # No header: go right before the first statement (including
                # its decorators) so leading comments stay on top.
                first_line = min(
                    [node.lineno]
                    + [dec.lineno for dec in getattr(node, "decorator_list", [])]
                )
                insert_idx = first_line - 1
            break
        if insert_idx is None:
            insert_idx = header_end if header_end else len(lines)

    insert_idx = min(insert_idx, len(lines))
    # Never glue the import onto a preceding line that lacks a line ending.
    if insert_idx > 0 and not lines[insert_idx - 1].endswith(("\n", "\r")):
        lines[insert_idx - 1] += "\n"

    lines.insert(insert_idx, patch_line)
    return "".join(lines)
|
| |
|
| |
|
def patch_pi05_embed_tie(p: pathlib.Path) -> Tuple[bool, str]:
    """Inject embed_tokens/lm_head weight tying after ``load_state_dict``.

    The target file is first passed through ``ensure_torch_import``. Then,
    directly after the line
    ``missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)``
    a snippet is added that, when ``embed_tokens.weight`` is among the missing
    keys, points ``embed_tokens.weight`` at ``lm_head.weight``.

    Args:
        p: Path of the ``modeling_pi05.py`` file to patch in place.

    Returns:
        ``(changed, message)``. ``changed`` is True only when the tying snippet
        was injected by this call. The file is written only when its content
        actually differs from what was read.
    """
    original = _read_text(p)
    s = ensure_torch_import(original)

    def _persist(text: str) -> None:
        # Avoid rewriting (and touching the mtime of) an unchanged file.
        if text != original:
            _write_text(p, text)

    marker = "PATCH: tie embed_tokens to lm_head if ckpt omitted embed_tokens"
    if marker in s:
        _persist(s)
        return False, f"PI05 embed-tie patch already present: {p}"

    # ``[ \t]*`` (not ``\s*``) so the captured indent can never swallow
    # preceding blank lines, which would corrupt the injected indentation.
    pat = (
        r"(?m)^([ \t]*)missing_keys,\s*unexpected_keys\s*=\s*model\.load_state_dict"
        r"\(\s*remapped_state_dict\s*,\s*strict\s*=\s*strict\s*\)[ \t]*$"
    )
    m = re.search(pat, s)
    if not m:
        _persist(s)
        return False, f"Could not find load_state_dict line to patch in PI05 file: {p}"

    indent = m.group(1)
    inject = (
        f"\n{indent}# --- PATCH: tie embed_tokens to lm_head if ckpt omitted embed_tokens ---\n"
        f"{indent}if any('embed_tokens.weight' in k for k in missing_keys):\n"
        f"{indent}    try:\n"
        f"{indent}        with torch.no_grad():\n"
        f"{indent}            embed = model.model.paligemma_with_expert.paligemma.model.language_model.embed_tokens\n"
        f"{indent}            lm_head = model.model.paligemma_with_expert.paligemma.lm_head\n"
        f"{indent}            if embed is not None and lm_head is not None:\n"
        f"{indent}                embed.weight = lm_head.weight # {marker}\n"
        f"{indent}    except Exception as _e:\n"
        f"{indent}        print('[patch_pi05] Could not tie embed_tokens to lm_head:', _e)\n"
    )

    # A callable replacement keeps backslashes in ``inject`` literal.
    s2 = re.sub(pat, lambda mm: mm.group(0) + inject, s, count=1)
    _persist(s2)
    return True, f"Patched PI05 embed-tie in: {p}"
|
| |
|
| |
|
def patch_pi05_transformers_guard(p: pathlib.Path) -> Tuple[bool, str]:
    """
    Downgrade ONLY the PI05 hard guard:
        ValueError: An incorrect transformer version is used...
    to a WARNING print, preserving indentation.

    Strategy:
      - Find ``raise ValueError(msg) from None`` lines.
      - Only patch the first one whose nearby context (the raise line plus
        the 8 lines before it) contains "incorrect transformer version",
        compared case-insensitively.

    Args:
        p: Path of the ``modeling_pi05.py`` file to patch in place.

    Returns:
        ``(changed, message)``; the file is written only when ``changed`` is True.
    """
    s = _read_text(p)
    marker = "PATCH: downgrade transformer version guard"
    if marker in s:
        return False, f"PI05 transformers-guard patch already present: {p}"

    needle = "incorrect transformer version"
    # Case-insensitive, consistent with the per-window check below.
    if needle not in s.lower():
        return False, f"No transformers guard message found to patch in: {p}"

    lines = s.splitlines(True)
    raise_pat = re.compile(r"^(\s*)raise\s+ValueError\(\s*msg\s*\)\s*from\s*None\s*$")

    target_idx = None
    target_indent = ""
    for idx, line in enumerate(lines):
        hit = raise_pat.match(line)
        if hit is None:
            continue
        context = "".join(lines[max(0, idx - 8):idx + 1]).lower()
        if needle in context:
            target_idx = idx
            target_indent = hit.group(1)
            break

    if target_idx is None:
        return False, f"Guard raise line with context not found in: {p}"

    # The print keeps the enclosing block non-empty, so the surrounding
    # try/except or if-body stays syntactically valid.
    lines[target_idx] = (
        f"{target_indent}# --- PATCH: downgrade transformer version guard ---\n"
        f"{target_indent}print('[patch_pi05] WARNING:', msg) # {marker}\n"
        f"{target_indent}# continues execution despite version mismatch\n"
    )

    _write_text(p, "".join(lines))
    return True, f"Patched PI05 transformers guard (raise->warn) in: {p}"
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def find_policies_init() -> pathlib.Path:
    """Locate ``lerobot/policies/__init__.py``.

    The ``POLICIES_INIT_FILE`` environment variable wins when it names an
    existing path; otherwise the default roots are searched.

    The search uses the exact relative path ``lerobot/policies/__init__.py``
    as the recursive glob pattern. Searching for a bare ``__init__.py`` and
    filtering on the substring ``/lerobot/policies/`` could instead return a
    sub-package's file (e.g. ``lerobot/policies/<name>/__init__.py``).

    Raises:
        FileNotFoundError: If no candidate is found.
    """
    override = os.getenv("POLICIES_INIT_FILE")
    if override:
        candidate = pathlib.Path(override)
        if candidate.exists():
            return candidate

    found = _search_file(_default_roots(), "lerobot/policies/__init__.py")
    if found and found.exists():
        return found

    raise FileNotFoundError("lerobot/policies/__init__.py not found. Set POLICIES_INIT_FILE env var.")
|
| |
|
| |
|
def patch_policies_optional_imports(p: pathlib.Path) -> Tuple[bool, str]:
    """
    Make ONLY Groot imports optional by wrapping them in try/except.
    This avoids wrapping unrelated exports/imports.

    Each Groot import statement is wrapped as a whole: a parenthesised or
    backslash-continued import spanning several physical lines stays intact.
    Adjacent Groot imports with the same indentation share one ``try`` block.
    The ``try``/``except`` are emitted at the indentation of the import they
    wrap, so imports nested in an ``if``/``try`` remain valid.

    Args:
        p: Path of the ``lerobot/policies/__init__.py`` file to patch in place.

    Returns:
        ``(changed, message)``; the file is written only when ``changed`` is True.
    """
    s = _read_text(p)
    marker = "PATCH: optional Groot/Diffusers imports"
    if marker in s:
        return False, f"Policies optional-import patch already present: {p}"

    lines = s.splitlines(True)
    groot_import = re.compile(r"\s*from\s+\.\s*groot\b|\s*import\s+.*\bgroot\b")

    def _statement_end(start: int) -> int:
        """Return the index of the last physical line of the import at ``start``."""
        depth = 0
        for idx in range(start, len(lines)):
            # Import statements hold no string literals, so '#' starts a comment.
            code = lines[idx].split("#", 1)[0]
            depth += code.count("(") - code.count(")")
            if depth <= 0 and not code.rstrip().endswith("\\"):
                return idx
        return len(lines) - 1

    # Collect (first_line, last_line, indent) spans of Groot import statements.
    spans = []
    idx = 0
    while idx < len(lines):
        if groot_import.match(lines[idx]):
            end = _statement_end(idx)
            indent = re.match(r"[ \t]*", lines[idx]).group(0)
            if spans and spans[-1][1] == idx - 1 and spans[-1][2] == indent:
                spans[-1] = (spans[-1][0], end, indent)
            else:
                spans.append((idx, end, indent))
            idx = end + 1
        else:
            idx += 1

    if not spans:
        return False, f"No Groot imports found to wrap in: {p}"

    new_lines = []
    cursor = 0
    for first, last, indent in spans:
        new_lines.extend(lines[cursor:first])

        step = "\t" if "\t" in indent else "    "
        new_lines.append(f"{indent}# --- PATCH: optional Groot/Diffusers imports ---\n")
        new_lines.append(f"{indent}try: # {marker}\n")
        for raw in lines[first:last + 1]:
            # Shift one level deeper while keeping relative indentation of
            # continuation lines; guarantee a line ending before 'except'.
            body = raw if raw.endswith(("\n", "\r")) else raw + "\n"
            new_lines.append(step + body if body.strip() else body)
        new_lines.append(f"{indent}except Exception as _e:\n")
        new_lines.append(
            f"{indent}{step}print('[policies_init] WARNING: optional groot deps missing:', _e)\n"
        )

        cursor = last + 1

    new_lines.extend(lines[cursor:])

    _write_text(p, "".join(new_lines))
    return True, f"Patched policies __init__ optional imports in: {p}"
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def find_eval_file() -> pathlib.Path:
    """Locate ``eval_sigma_vla_rollout.py``.

    Lookup order: the ``EVAL_FILE`` environment variable (when it names an
    existing path), then ``/workspace/eval_sigma_vla_rollout.py``, then a
    recursive search under ``/workspace`` and ``/workspace/lerobot``.

    Raises:
        FileNotFoundError: If no candidate is found.
    """
    override = os.getenv("EVAL_FILE")
    if override:
        candidate = pathlib.Path(override)
        if candidate.exists():
            return candidate

    default_path = pathlib.Path("/workspace/eval_sigma_vla_rollout.py")
    if default_path.exists():
        return default_path

    found = _search_file(["/workspace", "/workspace/lerobot"], "eval_sigma_vla_rollout.py")
    if found and found.exists():
        return found

    raise FileNotFoundError("eval_sigma_vla_rollout.py not found. Set EVAL_FILE env var.")
|
| |
|
| |
|
def patch_eval_force_strict_false(p: pathlib.Path) -> Tuple[bool, str]:
    """Force ``strict=False`` on ``policy_cls.from_pretrained(...)`` calls.

    Two rewrites are applied to the eval script:
      1. ``strict=True`` inside a ``policy_cls.from_pretrained(...)`` call
         becomes ``strict=False``.
      2. The two known call shapes without a ``strict`` argument
         (``(repo_id, token=hf_token)`` and
         ``(pretrained_name_or_path=repo_id, token=hf_token)``) get
         ``, strict=False`` appended.

    When something changed, a marker comment is appended at the END of the
    physical line holding the first ``strict=False`` call. Putting the comment
    right after the closing parenthesis would comment out anything following
    the call on that line (e.g. ``.to(device)`` or a trailing comma).

    Args:
        p: Path of the eval script to patch in place.

    Returns:
        ``(changed, message)``; the file is written only when ``changed`` is True.
    """
    s = _read_text(p)
    marker = "PATCH: force strict=False for PI05Policy"

    # (1) strict=True -> strict=False
    pat_strict_true = r"(policy_cls\.from_pretrained\([^)]*\bstrict\s*=\s*)True\b([^)]*\))"
    patched, n_true = re.subn(pat_strict_true, r"\1False\2", s)

    # (2) add strict=False to the known strict-less call shapes
    no_strict_patterns = (
        r"policy_cls\.from_pretrained\(\s*repo_id\s*,\s*token\s*=\s*hf_token\s*\)",
        r"policy_cls\.from_pretrained\(\s*pretrained_name_or_path\s*=\s*repo_id\s*,\s*token\s*=\s*hf_token\s*\)",
    )
    n_added = 0
    for pattern in no_strict_patterns:
        patched, count = re.subn(
            pattern, lambda mm: mm.group(0)[:-1] + ", strict=False)", patched
        )
        n_added += count

    if n_true + n_added == 0:
        if marker in s:
            return False, f"Eval strict patch already present: {p}"
        return False, f"Eval already strict=False or no PI05 strict targets found: {p}"

    if marker not in patched:
        call = re.search(
            r"policy_cls\.from_pretrained\([^)]*\bstrict\s*=\s*False\b[^)]*\)", patched
        )
        if call is not None:
            eol = patched.find("\n", call.end())
            if eol == -1:
                eol = len(patched)
            # A comment cannot follow a backslash continuation; skip the
            # (purely informational) marker in that case.
            if not patched[call.end():eol].rstrip().endswith("\\"):
                patched = patched[:eol] + f" # {marker}" + patched[eol:]

    _write_text(p, patched)
    return True, f"Patched eval PI05 strict=False in: {p}"
|
| |
|
| |
|
def patch_eval_shuffle_support(p: pathlib.Path) -> Tuple[bool, str]:
    """Add a ``--shuffle`` CLI flag and wire it into the first DataLoader.

    Steps (each skipped when already done or when no target is found):
      1. If no ``add_argument("--shuffle" ...)`` exists, insert one after the
         last single-line ``parser.add_argument(...)`` statement, using that
         statement's indentation so it stays inside the same block.
      2. Replace the first ``shuffle=False`` that follows ``DataLoader(``
         (within 1200 characters) by ``shuffle=getattr(args, "shuffle", False)``.
         The marker comment goes at the end of that physical line; placing it
         right after the value would comment out the remaining arguments.

    Args:
        p: Path of the eval script to patch in place.

    Returns:
        ``(changed, message)``; the file is written only when ``changed`` is True.
    """
    s = _read_text(p)
    marker_arg = "PATCH: add --shuffle arg"
    marker_dl = "PATCH: DataLoader shuffle uses args.shuffle"

    changed = False

    # 1) add the --shuffle argument if missing
    if re.search(r'add_argument\(\s*["\']--shuffle["\']', s) is None:
        arg_pat = re.compile(r"(?m)^([ \t]*)parser\.add_argument\(.+?\)[ \t]*$")
        # Only complete one-line calls qualify: unbalanced parentheses mean
        # the statement continues on the next line.
        complete = [
            mm for mm in arg_pat.finditer(s)
            if mm.group(0).count("(") == mm.group(0).count(")")
        ]
        if complete:
            last = complete[-1]
            insert_pos = last.end()
            insert_text = (
                f"\n{last.group(1)}parser.add_argument("
                "\"--shuffle\", action=\"store_true\", "
                "help=\"Shuffle dataset order to sample different subsets per seed.\")"
                f" # {marker_arg}"
            )
            tail = s[insert_pos:]
            if not tail:
                insert_text += "\n"
            s = s[:insert_pos] + insert_text + tail
            changed = True

    # 2) DataLoader shuffle=False -> shuffle=getattr(args, "shuffle", False)
    already_wired = marker_dl in s or re.search(
        r'shuffle\s*=\s*getattr\(\s*args\s*,\s*["\']shuffle["\']', s
    ) is not None
    if not already_wired:
        pat_dl = re.compile(r"(DataLoader\([\s\S]{0,1200}?\bshuffle\s*=\s*)False\b")
        hit = pat_dl.search(s)
        if hit is not None:
            replacement = 'getattr(args, "shuffle", False)'
            value_end = hit.start() + len(hit.group(1)) + len(replacement)
            s = s[:hit.start()] + hit.group(1) + replacement + s[hit.end():]

            eol = s.find("\n", value_end)
            if eol == -1:
                eol = len(s)
            # A comment cannot follow a backslash continuation.
            if not s[value_end:eol].rstrip().endswith("\\"):
                s = s[:eol] + f" # {marker_dl}" + s[eol:]
            changed = True

    if changed:
        _write_text(p, s)
        return True, f"Patched eval shuffle support in: {p}"

    return False, f"Eval shuffle support already present or no targets found: {p}"
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def main():
    """Run every patch step in order and print a summary.

    Each step locates its target file and applies one patch. A failing step
    (for example, target file not found) is reported and does not stop the
    remaining steps. The closing message distinguishes "patches applied",
    "nothing to change" and "nothing applied because steps were skipped".
    """
    steps = (
        ("PI05 embed-tie patch", find_pi05_file, patch_pi05_embed_tie),
        ("PI05 transformers-guard patch", find_pi05_file, patch_pi05_transformers_guard),
        ("policies __init__ patch", find_policies_init, patch_policies_optional_imports),
        ("Eval strict patch", find_eval_file, patch_eval_force_strict_false),
        ("Eval shuffle patch", find_eval_file, patch_eval_shuffle_support),
    )

    changed_any = False
    skipped = 0

    for label, locate, apply_patch in steps:
        # Top-level boundary: one broken step must not abort the others.
        try:
            changed, msg = apply_patch(locate())
            print(msg)
            changed_any |= changed
        except Exception as e:
            skipped += 1
            print(f"[patch_sigma_env] {label} skipped:", e)

    if changed_any:
        print("[patch_sigma_env] Done. Patches applied.")
    elif skipped:
        # Do not claim "already patched" when steps could not even run.
        print(f"[patch_sigma_env] Done. No patches applied; {skipped} step(s) skipped due to errors.")
    else:
        print("[patch_sigma_env] Done. Nothing to change (already patched).")
|
| |
|
| |
|
# Run the patcher only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
| |
|