# LLM4HEP / jobs/test_models.py
# ho22joshua — initial commit (cfcbbc8)
import argparse
import os
import re
import shlex
import subprocess
import time
from concurrent.futures import ProcessPoolExecutor, as_completed

import yaml
def sanitize(s):
    """Make *s* safe for use in file and directory names.

    Every character outside ``A-Z``, ``a-z``, ``0-9``, ``_``, ``.`` and ``-``
    (for example ``/``, ``:`` or whitespace in a model name) becomes ``_``.
    """
    disallowed = re.compile(r'[^A-Za-z0-9_.-]')
    return disallowed.sub('_', s)
def run_for_model(supervisor, coder, step, config_filepath, outdir):
    """Run one validation step of the pipeline for a supervisor/coder pair.

    Builds a unique output directory name from the sanitized model names,
    the step number, a timestamp, the current process ID and, when set, the
    ``SLURM_JOB_ID`` environment variable. It then runs
    ``./run_smk_sequential.sh`` through ``/bin/bash``.

    Args:
        supervisor: Supervisor model name.
        coder: Coder model name.
        step: Pipeline step number, passed to the script as ``--step<step>``.
        config_filepath: Path of the YAML config passed via ``--config``.
        outdir: Parent directory under which this run's directory is created.

    Returns:
        ``(supervisor, coder, pid)``, where *pid* is the ID of the process
        that ran the step.

    Raises:
        subprocess.CalledProcessError: if the script exits with a non-zero
            status.
    """
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    pid = os.getpid()
    job_id = f"{sanitize(supervisor)}_{sanitize(coder)}_step{step}_{timestamp}_{pid}"
    slurm_jobid = os.environ.get("SLURM_JOB_ID")
    if slurm_jobid:
        # Tag the directory with the SLURM allocation when running on a batch node.
        job_id += f"_slurm_{slurm_jobid}"
    out_path = os.path.join(outdir, job_id)

    # The script path is relative, so it must exist in the current working
    # directory.
    argv = [
        "./run_smk_sequential.sh",
        f"--step{step}",
        "--out-dir", out_path,
        "--config", config_filepath,
        "--validate",
    ]
    # Quote every argument: outdir comes from the command line and may contain
    # spaces or shell metacharacters that would otherwise split the argument or
    # inject commands. bash stays the launcher so the script starts as before.
    run_cmd = " ".join(shlex.quote(str(arg)) for arg in argv)
    subprocess.run(run_cmd, shell=True, check=True, executable='/bin/bash')
    return supervisor, coder, pid
def main(supervisor, coder, num_tests, outdir, *, temperature=1.5, max_workers=2):
    """Run ``num_tests`` rounds of pipeline steps 1-5 for one model pair.

    Writes a YAML config (supervisor, coder, temperature) to
    ``/dev/shm/config/<supervisor>_<coder>.yml``. It then submits
    ``num_tests * 5`` calls of ``run_for_model`` to a process pool.

    Every run is waited on and reported, even if some fail. Failures are
    printed as they happen, and the first one is re-raised after the pool
    has drained.

    Args:
        supervisor: Supervisor model name.
        coder: Coder model name.
        num_tests: Number of times to repeat the full set of steps.
        outdir: Parent directory for the per-run output directories.
        temperature: Temperature value written to the config (default 1.5).
        max_workers: Maximum number of runs executed concurrently (default 2).

    Raises:
        Exception: the exception from the first failed run (e.g.
            ``subprocess.CalledProcessError``), re-raised after all runs
            have finished.
    """
    config = {"supervisor": supervisor, "coder": coder, "temperature": temperature}
    config_dir = "/dev/shm/config"
    os.makedirs(config_dir, exist_ok=True)
    # NOTE(review): the file name depends only on the model pair, so concurrent
    # invocations for the same pair on one node overwrite each other's config.
    config_filepath = os.path.join(config_dir, f"{sanitize(supervisor)}_{sanitize(coder)}.yml")
    with open(config_filepath, "w", encoding="utf-8") as f:
        yaml.dump(config, f)

    failures = []
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        # Map each future back to its (test index, step) so failures can be
        # attributed. Submission order matches the original: test by test,
        # steps 1 through 5.
        futures = {
            executor.submit(
                run_for_model, supervisor, coder, step, config_filepath, outdir
            ): (test_idx, step)
            for test_idx in range(num_tests)
            for step in [1, 2, 3, 4, 5]
        }
        for future in as_completed(futures):
            test_idx, step = futures[future]
            try:
                _, _, pid = future.result()
            except Exception as exc:  # keep draining; the failure is re-raised below
                failures.append(exc)
                print(f"FAILED test {test_idx} step {step}: {exc!r}")
            else:
                print(f"Completed PID {pid}")

    if failures:
        print(f"{len(failures)} of {len(futures)} runs failed")
        raise failures[0]
if __name__ == "__main__":
    # Command-line entry point: two positional model names, a repeat count,
    # and an optional output directory.
    cli = argparse.ArgumentParser()
    cli.add_argument("supervisor", help="Supervisor name")
    cli.add_argument("coder", help="Coder name")
    cli.add_argument("num_tests", type=int, help="Number of tests")
    cli.add_argument(
        "--outdir",
        default="/global/cfs/projectdirs/atlas/llm4hep/",
        help="Output directory (default: %(default)s)",
    )
    parsed = cli.parse_args()
    main(
        supervisor=parsed.supervisor,
        coder=parsed.coder,
        num_tests=parsed.num_tests,
        outdir=parsed.outdir,
    )