#!/usr/bin/env python3
"""Compare ``.npy`` arrays in a results directory against reference arrays.

For every ``*.npy`` file under ``<results_dir>/arrays`` the script looks for a
file with the same name under ``--solution_arrays`` and prints a short report
(shape, dtype, NaN-mask agreement, exact equality of the non-NaN entries,
difference statistics and ``allclose`` checks).

Exit status: 0 when no differences were detected, 1 when at least one array
differed or could not be compared, 2 when an input directory does not exist
(or the command line is invalid).
"""
import os
import sys
import argparse
import numpy as np


def _nan_mask(arr: np.ndarray) -> np.ndarray:
    """Return a boolean mask of NaN entries (all False for non-float dtypes)."""
    if np.issubdtype(arr.dtype, np.inexact):
        return np.isnan(arr)
    # np.isnan raises TypeError on string/object dtypes, and integer/bool
    # arrays cannot hold NaN anyway.
    return np.zeros(arr.shape, dtype=bool)


def _is_numeric(arr: np.ndarray) -> bool:
    """Return True when *arr* supports arithmetic (numbers or bools)."""
    return arr.dtype == np.bool_ or np.issubdtype(arr.dtype, np.number)


def _abs_diff(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Return ``|a - b|`` element-wise for two equally-shaped numeric arrays.

    Both operands are promoted to at least float64 before subtracting, so
    unsigned integers do not wrap around (uint8 ``1 - 2`` would give 255) and
    bool arrays are accepted. Positions where the original values compare
    equal get a difference of exactly 0; in particular matching infinities
    (whose difference ``inf - inf`` is NaN) no longer turn max/mean into NaN.
    """
    work = np.result_type(a.dtype, b.dtype, np.float64)
    with np.errstate(invalid="ignore"):
        diffs = np.abs(a.astype(work) - b.astype(work))
    diffs[a == b] = 0
    return diffs


def _compare_arrays(A: np.ndarray, B: np.ndarray, atol: "float | None" = None) -> "tuple[list[str], bool]":
    """Compare two arrays and return ``(report_lines, differs)``.

    ``differs`` is True when the shapes differ, the NaN masks differ, or the
    non-NaN entries are not exactly equal. When *atol* is given (and both
    arrays are numeric), inexact-but-close entries count as matching if
    ``np.allclose(A, B, rtol=0, atol=atol, equal_nan=True)`` holds.
    """
    lines = [f"shape: res {A.shape} vs sol {B.shape}"]
    if A.shape != B.shape:
        lines.append("DIFF: shape mismatch")
        return lines, True
    lines.append(f"dtype: res {A.dtype} vs sol {B.dtype}")

    nan_a = _nan_mask(A)
    nan_b = _nan_mask(B)
    nan_mask_equal = np.array_equal(nan_a, nan_b)
    # "finite" here means "not NaN in either array"; infinities are kept.
    finite_mask = ~nan_a & ~nan_b
    a_fin = A[finite_mask]
    b_fin = B[finite_mask]
    equal_exact = np.array_equal(a_fin, b_fin)
    mismatch_count = int(np.sum(a_fin != b_fin))
    lines.append(f"NaN mask equal: {nan_mask_equal}")
    lines.append(f"Exact equal (finite): {equal_exact}")
    lines.append(f"mismatches (finite): {mismatch_count} / {finite_mask.sum()}")

    within_atol = None
    if not finite_mask.any():
        lines.append("No finite entries to compare")
    elif not (_is_numeric(A) and _is_numeric(B)):
        # Subtraction / allclose are undefined for e.g. strings or datetimes.
        lines.append("diff stats skipped: non-numeric dtype")
    else:
        diffs = _abs_diff(a_fin, b_fin)
        lines.append(f"diff stats (finite): max={diffs.max():.6g}, mean={diffs.mean():.6g}")
        lines.append(f"allclose atol=1e-10: {np.allclose(A, B, rtol=0.0, atol=1e-10, equal_nan=True)}")
        lines.append(f"allclose atol=1e-6: {np.allclose(A, B, rtol=0.0, atol=1e-6, equal_nan=True)}")
        if atol is not None:
            within_atol = bool(np.allclose(A, B, rtol=0.0, atol=atol, equal_nan=True))
            lines.append(f"allclose atol={atol:g}: {within_atol}")

    if not nan_mask_equal:
        differs = True
    elif equal_exact:
        differs = False
    else:
        # Without a user tolerance (or for non-numeric data) any mismatch counts.
        differs = within_atol is not True
    return lines, differs


def _compare_files(res_path: str, sol_path: str, atol: "float | None" = None) -> "tuple[str, bool]":
    """Load two ``.npy`` files and return ``(report_text, differs)``."""
    # np.load keeps its default allow_pickle=False: object arrays are refused
    # rather than unpickled from possibly untrusted files.
    A = np.load(res_path)
    B = np.load(sol_path)
    lines, differs = _compare_arrays(A, B, atol)
    return "\n".join(lines), differs


def compare_pair(res_path: str, sol_path: str, atol: "float | None" = None) -> str:
    """Return a human-readable comparison report for two ``.npy`` files.

    Args:
        res_path: path to the array produced by the run under test.
        sol_path: path to the reference array.
        atol: optional absolute tolerance; when given, an extra
            ``allclose atol=...`` line is added to the report.

    Raises:
        Whatever ``np.load`` raises for missing or unreadable files.
    """
    report, _ = _compare_files(res_path, sol_path, atol)
    return report


def main():
    """Parse the command line, compare every results array, and exit with 0/1/2."""
    p = argparse.ArgumentParser(description="Compare arrays under results/arrays vs solution/arrays")
    p.add_argument("--results_dir", required=True, help="Path to results_XXX directory containing arrays/")
    p.add_argument("--solution_arrays", default="solution/arrays", help="Path to solution/arrays directory")
    p.add_argument(
        "--atol",
        type=float,
        default=None,
        help="Absolute tolerance: arrays whose non-NaN entries agree within ATOL count as matching "
        "(default: require exact equality)",
    )
    args = p.parse_args()
    if args.atol is not None and args.atol < 0:
        p.error("--atol must be non-negative")

    res_arrays = os.path.join(args.results_dir, "arrays")
    sol_arrays = args.solution_arrays
    if not os.path.isdir(res_arrays):
        print(f"ERROR: results arrays dir not found: {res_arrays}", file=sys.stderr)
        sys.exit(2)
    if not os.path.isdir(sol_arrays):
        print(f"ERROR: solution arrays dir not found: {sol_arrays}", file=sys.stderr)
        sys.exit(2)

    files = sorted(f for f in os.listdir(res_arrays) if f.endswith('.npy'))
    any_diff = False
    print(f"Comparing {len(files)} files found in {res_arrays} against {sol_arrays}\n")
    for f in files:
        sol_path = os.path.join(sol_arrays, f)
        res_path = os.path.join(res_arrays, f)
        print(f"== {f} ==")
        if not os.path.exists(sol_path):
            # A missing reference is reported but not counted as a difference.
            print(f"- No reference: {sol_path}")
            print("")
            continue
        try:
            report, differs = _compare_files(res_path, sol_path, args.atol)
        except Exception as e:
            # Top-level boundary: report the failure for this file and go on.
            print(f"ERROR comparing {f}: {e}")
            print("")
            any_diff = True
            continue
        print(report)
        print("")
        any_diff = any_diff or differs

    if any_diff:
        print("SUMMARY: Differences detected (see details above)")
        sys.exit(1)
    else:
        print("SUMMARY: All compared arrays match within tolerance (see details above)")
        sys.exit(0)


if __name__ == "__main__":
    main()