import os
import re
import sys
import argparse

import numpy as np
import matplotlib.pyplot as plt

# ATLAS style only needed for plotting
try:
    import atlas_mpl_style as ampl
    ampl.use_atlas_style()
    plt.rcParams['font.family'] = 'DejaVu Sans'
except ImportError:
    print("Warning: ATLAS style not available, using default matplotlib style")
    plt.style.use('default')

# Plotting helpers are not used in array-only validation, keep import disabled to reduce deps
# from utils_plot import plot_myy_comparison, plot_scores_comparison

parser = argparse.ArgumentParser()
add_arg = parser.add_argument
add_arg('--out_dir', help='output directory')
add_arg('--step', type=int, choices=[1, 2, 3, 4, 5],
        help='Validate only specific step (1-5)')
args = parser.parse_args()
out_dir = args.out_dir
specific_step = args.step


def arrays_match(generated, reference, name: str, atol: float = 1e-10) -> bool:
    """
    Compare two numpy arrays element-wise with a strict absolute tolerance.

    - NaNs are considered equal when they appear at the same positions.
    - rtol is set to 0.0 so only absolute tolerance matters.

    Prints a concise status and returns True/False.
    """
    print(f"Validating {name}...")
    if generated.shape != reference.shape:
        print(f" āŒ Shape mismatch: {generated.shape} vs {reference.shape}")
        return False
    ok = np.allclose(generated, reference, rtol=0.0, atol=atol, equal_nan=True)
    if ok:
        print(f" āœ… {name} matches (atol={atol})")
        return True
    # Brief diff stats to aid debugging
    nan_mask_equal = np.array_equal(np.isnan(generated), np.isnan(reference))
    finite = (~np.isnan(generated)) & (~np.isnan(reference))
    # Count of exactly-unequal finite entries (not tolerance-based; debug info only)
    mismatches = int(np.sum(generated[finite] != reference[finite]))
    print(f" āŒ {name} differs: NaN mask equal={nan_mask_equal}, finite mismatches={mismatches}/{int(finite.sum())}")
    if finite.any():
        diffs = np.abs(generated[finite] - reference[finite])
        print(f" diff stats: max={diffs.max():.6g}, mean={diffs.mean():.6g}")
    # Additional debug: show sample mismatches
    print("šŸ” Running detailed mismatch analysis...")
    analyze_array_differences(generated, reference, name)
    return False


def calculate_adaptive_tolerance(values, significant_digits=4):
    """
    Calculate adaptive tolerance based on the magnitude of values to achieve
    desired significant digits.

    For each value, the tolerance is set to preserve the specified number of
    significant digits. Examples:
    - Value 123000 with 4 sig digits: tolerance = 1000 (1e3)
    - Value 0.00014 with 4 sig digits: tolerance = 0.0000014 (1.4e-6)
    - Value 0 with 4 sig digits: tolerance = 1e-10 (small default)
    """
    # Handle zero values: they get a small fixed default tolerance
    non_zero_mask = values != 0
    tolerances = np.full_like(values, 1e-10, dtype=float)  # Default for zeros
    if np.any(non_zero_mask):
        # Tolerance = |value| / 10^(significant_digits), which preserves the
        # desired number of significant digits for each element
        abs_values = np.abs(values[non_zero_mask])
        tolerances[non_zero_mask] = abs_values / (10 ** significant_digits)
    return tolerances


def analyze_array_differences(generated, reference, array_name, significant_digits=4):
    """
    Analyze differences between generated and reference numpy arrays.

    Uses adaptive tolerance based on significant digits rather than fixed
    tolerance. Diagnostic only: prints a report, returns None.
    """
    print(f"\nšŸ” Detailed analysis for {array_name} (using {significant_digits} significant digit tolerance):")
    print(f" Generated shape: {generated.shape}, Reference shape: {reference.shape}")
    print(f" Tolerance: Adaptive based on {significant_digits} significant digits per value")
    # Check for shape differences first
    if generated.shape != reference.shape:
        print(f" āŒ Shape mismatch: {generated.shape} vs {reference.shape}")
        return
    # Calculate adaptive tolerances for each element
    # NOTE(review): only the generated-array half of the combined tolerances is
    # used below; reference magnitudes do not influence the per-element atol.
    combined_values = np.abs(np.concatenate([generated.flatten(), reference.flatten()]))
    adaptive_tolerances = calculate_adaptive_tolerance(combined_values, significant_digits)
    # Reshape tolerances to match original arrays
    atol_array = adaptive_tolerances[:generated.size].reshape(generated.shape)
    # Use absolute tolerance only (relative tolerance not used)
    # Find differences and identify where tolerances are exceeded
    diff = generated - reference
    abs_diff = np.abs(diff)
    not_close = abs_diff > atol_array
    # Remove any comparisons involving NaNs (gen or ref)
    invalid = np.isnan(generated) | np.isnan(reference)
    not_close = not_close & ~invalid
    total_different = np.sum(not_close)
    if total_different == 0:
        print(" āœ… All elements match within tolerance")
        return
    print(f" āŒ {total_different} elements differ (out of {generated.size} total)")
    # Show numeric mismatches only (exclude any NaN comparisons)
    flat_gen = generated.flatten()
    flat_ref = reference.flatten()
    flat_not_close = not_close.flatten()
    # Mask to include only finite mismatches
    numeric_mask = (~np.isnan(flat_gen)) & (~np.isnan(flat_ref))
    mismatch_mask = flat_not_close & numeric_mask
    if np.any(mismatch_mask):
        diff_indices = np.where(mismatch_mask)[0][:10]
        print(" šŸ“Š Sample numeric mismatches (first 10 indices):")
        for idx in diff_indices:
            gen_val = flat_gen[idx]
            ref_val = flat_ref[idx]
            diff_val = gen_val - ref_val
            print(f" Index {idx}: gen={gen_val}, ref={ref_val}, diff={diff_val}")
    else:
        print(" āœ… No numeric mismatches (all differences involve NaNs)")
    # Skip overall statistics for now - they may not be meaningful for all data types
    # Analyze differences by column (if 2D array)
    if generated.ndim == 2:
        col_diffs = np.sum(not_close, axis=0)
        cols_with_diffs = np.where(col_diffs > 0)[0]
        if len(cols_with_diffs) > 0:
            print(f" šŸ“Š Columns with differences: {cols_with_diffs[:10]} (showing first 10)")
            # Show side-by-side entries for first 10 differing columns
            num_cols_to_show = min(10, len(cols_with_diffs))
            num_rows_to_show = min(5, generated.shape[0])  # Show first 5 rows
            print(f" šŸ“‹ Sample entries (first {num_rows_to_show} rows, first {num_cols_to_show} differing columns):")
            print(" Row | Column | Generated Value | Reference Value | Difference")
            print(" ----|--------|----------------|-----------------|------------")
            for col_idx in cols_with_diffs[:num_cols_to_show]:
                for row_idx in range(num_rows_to_show):
                    gen_val = generated[row_idx, col_idx]
                    ref_val = reference[row_idx, col_idx]
                    diff = gen_val - ref_val
                    # Format values nicely
                    gen_str = f"{gen_val:.6g}" if not np.isnan(gen_val) else "NaN"
                    ref_str = f"{ref_val:.6g}" if not np.isnan(ref_val) else "NaN"
                    diff_str = f"{diff:.6g}" if not np.isnan(diff) else "NaN"
                    print(f" {row_idx:3d} | {col_idx:3d} | {gen_str:>14} | {ref_str:>15} | {diff_str:>10}")
        else:
            print(" āœ… All columns match within tolerance")
    else:
        print(" šŸ“Š 1D array - no column-by-column analysis needed")
    # Check for special values - only warn if there's a significant difference
    nan_gen = np.sum(np.isnan(generated))
    nan_ref = np.sum(np.isnan(reference))
    if nan_gen > 1000 or nan_ref > 1000:  # Only show if significant number of NaNs
        # Check if NaN counts are very similar (within 1% difference)
        if nan_gen > 0 and nan_ref > 0:
            nan_ratio = min(nan_gen, nan_ref) / max(nan_gen, nan_ref)
            if nan_ratio > 0.99:
                # NaN counts are essentially identical
                print(" āœ… Data structure consistency: Identical NaN patterns in generated and reference files")
                print(f" - Both files have {nan_gen:,} NaN values (excellent consistency)")
            else:
                print(" āš ļø Special values detected:")
                if nan_gen > 1000:
                    print(f" - NaN in generated: {nan_gen:,}")
                if nan_ref > 1000:
                    print(f" - NaN in reference: {nan_ref:,}")
        else:
            print(" āš ļø Special values detected:")
            if nan_gen > 1000:
                print(f" - NaN in generated: {nan_gen:,}")
            if nan_ref > 1000:
                print(f" - NaN in reference: {nan_ref:,}")


def validate_root_summary(llm_content, ref_content):
    """
    Validate root_summary.txt content by checking that all required branch
    names are present.

    Focus on content (branch names) rather than exact format structure.
    ``ref_content`` is accepted for interface compatibility but not used:
    validation is against the hard-coded required-branch set below.
    """
    try:
        # Extract all branch names from LLM content
        llm_branches = set(extract_branch_names(llm_content))
        # Required branches that must be present
        required_branches = {
            'SumWeights', 'XSection', 'channelNumber', 'ditau_m', 'eventNumber',
            'jet_E', 'jet_MV2c10', 'jet_eta', 'jet_jvt', 'jet_n', 'jet_phi',
            'jet_pt', 'jet_pt_syst', 'jet_trueflav', 'jet_truthMatched',
            'largeRjet_D2', 'largeRjet_E', 'largeRjet_eta', 'largeRjet_m',
            'largeRjet_n', 'largeRjet_phi', 'largeRjet_pt', 'largeRjet_pt_syst',
            'largeRjet_tau32', 'largeRjet_truthMatched', 'lep_E', 'lep_charge',
            'lep_eta', 'lep_etcone20', 'lep_isTightID', 'lep_n', 'lep_phi',
            'lep_pt', 'lep_pt_syst', 'lep_ptcone30', 'lep_trackd0pvunbiased',
            'lep_tracksigd0pvunbiased', 'lep_trigMatched', 'lep_truthMatched',
            'lep_type', 'lep_z0', 'mcWeight', 'met_et', 'met_et_syst', 'met_phi',
            'photon_E', 'photon_convType', 'photon_eta', 'photon_etcone20',
            'photon_isTightID', 'photon_n', 'photon_phi', 'photon_pt',
            'photon_pt_syst', 'photon_ptcone30', 'photon_trigMatched',
            'photon_truthMatched', 'runNumber', 'scaleFactor_BTAG',
            'scaleFactor_ELE', 'scaleFactor_LepTRIGGER', 'scaleFactor_MUON',
            'scaleFactor_PHOTON', 'scaleFactor_PILEUP',
            'scaleFactor_PhotonTRIGGER', 'scaleFactor_TAU', 'tau_BDTid',
            'tau_E', 'tau_charge', 'tau_eta', 'tau_isTightID', 'tau_n',
            'tau_nTracks', 'tau_phi', 'tau_pt', 'tau_pt_syst',
            'tau_trigMatched', 'tau_truthMatched', 'trigE', 'trigM', 'trigP'
        }
        print(f" šŸ“Š LLM output has {len(llm_branches)} unique words, Required: {len(required_branches)} branches")
        # Debug: Show all required branch names found in txt file
        found_required_branches = required_branches & llm_branches
        if found_required_branches:
            sorted_found = sorted(found_required_branches)
            print(f" šŸ” Required branch names found in txt file: {', '.join(sorted_found)}")
        # Check if we have any branches at all
        if len(llm_branches) == 0:
            print(" āŒ No branches found in LLM output")
            return False
        # Check if all required branches are present
        missing_branches = required_branches - llm_branches
        if missing_branches:
            print(f" āŒ Missing {len(missing_branches)} required branches:")
            for branch in sorted(missing_branches):
                print(f" - {branch}")
            return False
        else:
            print(" āœ… All required branches present in LLM output")
            return True
    except Exception as e:
        print(f" āŒ Error parsing root_summary: {e}")
        return False


def extract_branch_names(content):
    """
    Extract all words from root_summary.txt content.

    This approach parses the file into words and checks for branch names as
    tokens. Returns a set for de-duplication and fast membership tests.
    """
    # Split content into words using regex to handle various separators.
    # \w+ keeps underscore-joined names (e.g. jet_pt_syst) as single tokens.
    words = re.findall(r'\b\w+\b', content)
    return set(words)


def parse_root_summary(content):
    """
    Parse root_summary.txt content into structured data.

    Supports both reference format (File 1:, File 2:, etc.)
    and LLM format (single file summary).
    Returns {filename: {total_objects, trees, entries, total_branches, branches}}.
    """
    files = {}
    current_file = None
    lines = content.split('\n')
    i = 0
    while i < len(lines):
        line = lines[i].strip()
        # Look for file headers in reference format
        if line.startswith('File ') and ':' in line:
            # Extract filename
            parts = line.split(': ')
            if len(parts) >= 2:
                filename = parts[1].strip()
                current_file = filename
                files[current_file] = {
                    'total_objects': 0,
                    'trees': 0,
                    'entries': 0,
                    'total_branches': 0,
                    'branches': {}
                }
        # Look for LLM format header (alternative format)
        elif line.startswith('Root file: ') and ':' in line:
            # Extract filename from path
            parts = line.split(': ')
            if len(parts) >= 2:
                full_path = parts[1].strip()
                filename = os.path.basename(full_path)
                current_file = filename
                files[current_file] = {
                    'total_objects': 1,  # Assume 1 tree
                    'trees': 1,
                    'entries': 0,  # Will be set if found
                    'total_branches': 0,
                    'branches': {}
                }
        # Parse file data
        elif current_file and current_file in files:
            if 'Total objects:' in line:
                try:
                    files[current_file]['total_objects'] = int(line.split(':')[1].strip())
                except Exception:
                    pass
            elif 'Trees found:' in line:
                try:
                    files[current_file]['trees'] = int(line.split(':')[1].strip())
                except Exception:
                    pass
            elif 'Entries:' in line:
                try:
                    files[current_file]['entries'] = int(line.split(':')[1].strip())
                except Exception:
                    pass
            elif 'Common branches (' in line and ')' in line:
                # Extract total branch count from common branches section
                try:
                    count_part = line.split('(')[1].split(')')[0]
                    # This sets the total for all files since they're common
                    common_branch_count = int(count_part)
                    for filename in files:
                        files[filename]['total_branches'] = common_branch_count
                except Exception:
                    pass
                # Parse branch categories
                branches = {}
                j = i + 1
                while j < len(lines) and not lines[j].strip().startswith('='):
                    branch_line = lines[j].strip()
                    if ': ' in branch_line:
                        category, branch_list = branch_line.split(': ', 1)
                        category = category.strip().lower()
                        branch_names = [b.strip() for b in branch_list.split(',')]
                        branches[category] = branch_names
                    j += 1
                files[current_file]['branches'] = branches
                i = j - 1  # Skip the lines we already processed
            # Handle LLM format branch parsing (with - prefix)
            elif line == 'TTree: mini':
                # Count branches in LLM format
                branches = {}
                branch_lines = []
                j = i + 1
                while j < len(lines) and lines[j].strip() and not lines[j].strip().startswith('='):
                    branch_line = lines[j].strip()
                    # Fix: branch_line is already stripped, so prefixes must not
                    # carry leading whitespace (the old ' Branches:' / ' - '
                    # prefixes could never match and no branches were collected)
                    if branch_line.startswith('Branches:'):
                        # Skip the "Branches:" header
                        j += 1
                        continue
                    elif branch_line.startswith('- '):
                        # Extract branch name from "- branch_name" format
                        branch_name = branch_line[2:].strip()
                        branch_lines.append(branch_name)
                    j += 1
                # Categorize branches for LLM format
                photon_branches = []
                jet_branches = []
                met_branches = []
                lep_branches = []
                tau_branches = []
                event_branches = []
                weights_branches = []
                for branch in branch_lines:
                    if branch.startswith('photon_'):
                        photon_branches.append(branch)
                    elif branch.startswith('jet_'):
                        jet_branches.append(branch)
                    elif branch.startswith('met_'):
                        met_branches.append(branch)
                    elif branch.startswith('lep_'):
                        lep_branches.append(branch)
                    elif branch.startswith('tau_'):
                        tau_branches.append(branch)
                    elif branch in ['runNumber', 'eventNumber', 'channelNumber',
                                    'mcWeight', 'trigE', 'trigM', 'trigP', 'ditau_m']:
                        event_branches.append(branch)
                    elif branch in ['SumWeights', 'XSection'] or branch.startswith('scaleFactor_') or branch.startswith('largeRjet_'):
                        weights_branches.append(branch)
                if photon_branches:
                    branches['photon'] = photon_branches
                if jet_branches:
                    branches['jet'] = jet_branches
                if met_branches:
                    branches['met'] = met_branches
                if lep_branches:
                    branches['lep'] = lep_branches
                if tau_branches:
                    branches['tau'] = tau_branches
                if event_branches:
                    branches['event'] = event_branches
                if weights_branches:
                    branches['weights'] = weights_branches
                files[current_file]['branches'] = branches
                files[current_file]['total_branches'] = len(branch_lines)
                i = j - 1  # Skip the lines we already processed
        i += 1
    return files


# ---------------------------------------------------------------------------
# Reference solution arrays (steps 3-5) - loaded unconditionally as before
# ---------------------------------------------------------------------------
SOLN_DIR = '/global/cfs/projectdirs/atlas/dwkim/llm4hep/solution/arrays'

signal_soln = np.load(os.path.join(SOLN_DIR, 'signal.npy'))
bkgd_soln = np.load(os.path.join(SOLN_DIR, 'bkgd.npy'))
signal_scores_soln = np.load(os.path.join(SOLN_DIR, 'signal_scores.npy'))
bkgd_scores_soln = np.load(os.path.join(SOLN_DIR, 'bkgd_scores.npy'))
boundaries_soln = np.load(os.path.join(SOLN_DIR, 'boundaries.npy'))
significances_soln = np.load(os.path.join(SOLN_DIR, 'significances.npy'))

base_dir = os.path.join(out_dir, 'arrays')

missing_file_1 = False  # Step 1: summarize_root files
missing_file_2 = False  # Step 2: create_numpy files
missing_file_3 = False  # Step 3: preprocess files
missing_file_4 = False  # Step 4: scores files
missing_file_5 = False  # Step 5: categorization files


def _step_selected(step):
    """True when messages for *step* should be printed (no --step, or matching --step)."""
    return not specific_step or specific_step == step


def _load_llm_array(path, step, missing_msg):
    """Load an LLM-generated .npy if present; otherwise report (for the selected
    step only) and return (None, True) to flag the step's outputs as missing."""
    if os.path.exists(path):
        return np.load(path), False
    if _step_selected(step):
        print(missing_msg)
    return None, True


# Step 1: Check summarize_root outputs (file_list.txt, root_summary.txt)
# Note: create_numpy_modified.txt comes from insert_root_summary rule (no LLM),
# so we don't validate it for step 1
file_list_llm_path = os.path.join(out_dir, 'logs', 'file_list.txt')
root_summary_llm_path = os.path.join(out_dir, 'logs', 'root_summary.txt')
if not (os.path.exists(file_list_llm_path) and os.path.exists(root_summary_llm_path)):
    if _step_selected(1):
        print("Step 1 (summarize_root) outputs missing")
    missing_file_1 = True

# Step 2: Check create_numpy outputs (data_A_raw.npy and signal_WH_raw.npy).
# Previously this check (and steps 3-5 below) ran twice - once gated on --step
# and once unconditionally, loading every array twice; now it runs exactly once.
data_A_raw_llm_path = os.path.join(base_dir, 'data_A_raw.npy')
signal_WH_raw_llm_path = os.path.join(base_dir, 'signal_WH_raw.npy')
if os.path.exists(data_A_raw_llm_path) and os.path.exists(signal_WH_raw_llm_path):
    data_raw_llm = np.load(data_A_raw_llm_path)
    signal_raw_llm = np.load(signal_WH_raw_llm_path)
    if _step_selected(2):
        print("Found required files: data_A_raw.npy and signal_WH_raw.npy")
else:
    if _step_selected(2):
        print("Step 2 (create_numpy) outputs missing - data_A_raw.npy and/or signal_WH_raw.npy not found")
    missing_file_2 = True

# Step 3: Check preprocess outputs (signal.npy, bkgd.npy)
signal_llm, _miss = _load_llm_array(
    os.path.join(base_dir, 'signal.npy'), 3,
    "LLM generated signal sample does not exist (Step 3)")
missing_file_3 = missing_file_3 or _miss
bkgd_llm, _miss = _load_llm_array(
    os.path.join(base_dir, 'bkgd.npy'), 3,
    "LLM generated background sample does not exist (Step 3)")
missing_file_3 = missing_file_3 or _miss

# Step 4: Check scores outputs (signal_scores.npy, bkgd_scores.npy)
signal_scores_llm, _miss = _load_llm_array(
    os.path.join(base_dir, 'signal_scores.npy'), 4,
    "LLM generated signal scores do not exist (Step 4)")
missing_file_4 = missing_file_4 or _miss
bkgd_scores_llm, _miss = _load_llm_array(
    os.path.join(base_dir, 'bkgd_scores.npy'), 4,
    "LLM generated background scores do not exist (Step 4)")
missing_file_4 = missing_file_4 or _miss

# Step 5: Check categorization outputs (boundaries.npy, significances.npy)
boundaries_llm, _miss = _load_llm_array(
    os.path.join(base_dir, 'boundaries.npy'), 5,
    "LLM generated boundaries do not exist (Step 5)")
missing_file_5 = missing_file_5 or _miss
significances_llm, _miss = _load_llm_array(
    os.path.join(base_dir, 'significances.npy'), 5,
    "LLM generated significances do not exist (Step 5)")
missing_file_5 = missing_file_5 or _miss

# Reference files for Step 2 validation: prefer the "selective" names
# (data_A_raw / signal_WH_raw), fall back to the standard ones. Missing
# references are represented as None instead of undefined globals.


def _load_ref_if_exists(name):
    """Load a reference array from SOLN_DIR, or return None when absent."""
    path = os.path.join(SOLN_DIR, name)
    return np.load(path) if os.path.exists(path) else None


data_A_raw_soln = _load_ref_if_exists('data_A_raw.npy')
signal_WH_raw_soln = _load_ref_if_exists('signal_WH_raw.npy')
signal_raw_soln = _load_ref_if_exists('signal_raw.npy')
data_raw_soln = _load_ref_if_exists('data_raw.npy')

# Plotting and derived checks removed per request: validation for steps 2-5
# now does direct array comparisons only (generated vs reference).

step1_success = False
step2_success = False
step3_success = False
step4_success = False
step5_success = False

# Step 1 validation (summarize_root outputs)
if _step_selected(1) and not missing_file_1:
    try:
        print("=== Step 1 Validation (summarize_root) ===")
        # Load reference files for comparison
        ref_file_list_path = os.path.join(SOLN_DIR, 'file_list.txt')
        # ref_root_summary_path no longer needed since we don't compare to reference
        # Load LLM-generated files
        with open(file_list_llm_path, 'r') as f:
            file_list_llm = f.read()
        with open(root_summary_llm_path, 'r') as f:
            root_summary_llm = f.read()
        # Standard mode: compare content with reference
        if os.path.exists(ref_file_list_path):
            with open(ref_file_list_path, 'r') as f:
                ref_file_list = f.read()

            def extract_filenames(content):
                """Extract sorted basenames; handles both full paths and bare names."""
                entries = [ln.strip() for ln in content.strip().split('\n') if ln.strip()]
                filenames = []
                for entry in entries:
                    filename = os.path.basename(entry) if '/' in entry else entry
                    filenames.append(filename)
                return sorted(filenames)

            llm_filenames = extract_filenames(file_list_llm)
            ref_filenames = extract_filenames(ref_file_list)
            file_list_match = llm_filenames == ref_filenames
            if not file_list_match:
                print(f" šŸ“Š LLM files: {len(llm_filenames)} | Reference files: {len(ref_filenames)}")
                if len(llm_filenames) != len(ref_filenames):
                    print(f" āŒ File count mismatch: {len(llm_filenames)} vs {len(ref_filenames)}")
                else:
                    # Show first difference only
                    for idx, (llm_file, ref_file) in enumerate(zip(llm_filenames, ref_filenames)):
                        if llm_file != ref_file:
                            print(f" āŒ File {idx+1} mismatch: '{llm_file}' vs '{ref_file}'")
                            break
        else:
            file_list_match = True  # No reference to compare
        # Only check that required branches are present (no reference comparison needed)
        root_summary_match = validate_root_summary(root_summary_llm, "")
        step1_success = file_list_match and root_summary_match
        # Summary will be shown in VALIDATION SUMMARY section
    except Exception as e:
        print(f"Error in Step 1 validation: {e}")
        step1_success = False

# Step 2 validation (create_numpy outputs) - direct array comparisons
if _step_selected(2) and not missing_file_2:
    print("=== Step 2 Validation (create_numpy) ===")
    # Choose reference arrays: prefer selective names, fallback to standard
    data_ref = data_A_raw_soln if data_A_raw_soln is not None else data_raw_soln
    signal_ref = signal_WH_raw_soln if signal_WH_raw_soln is not None else signal_raw_soln
    ok_data = False
    ok_signal = False
    if data_ref is not None:
        ok_data = arrays_match(data_raw_llm, data_ref, "data_A_raw.npy (or data_raw.npy)")
    else:
        print(" āŒ Missing data reference array (data_A_raw.npy or data_raw.npy)")
    if signal_ref is not None:
        ok_signal = arrays_match(signal_raw_llm, signal_ref, "signal_WH_raw.npy (or signal_raw.npy)")
    else:
        print(" āŒ Missing signal reference array (signal_WH_raw.npy or signal_raw.npy)")
    step2_success = ok_data and ok_signal
    print(f"Step 2 validation: {'PASS' if step2_success else 'FAIL'}")

# Step 3 validation (preprocess outputs) - direct array comparisons
if _step_selected(3) and not missing_file_3:
    print("=== Step 3 Validation (preprocess) ===")
    ok_signal = arrays_match(signal_llm, signal_soln, "signal.npy")
    ok_bkgd = arrays_match(bkgd_llm, bkgd_soln, "bkgd.npy")
    step3_success = ok_signal and ok_bkgd

# Step 4 validation (scores) - direct array comparisons
if _step_selected(4) and not missing_file_4:
    print("=== Step 4 Validation (scores) ===")
    ok_sig_scores = arrays_match(signal_scores_llm, signal_scores_soln, "signal_scores.npy")
    ok_bkg_scores = arrays_match(bkgd_scores_llm, bkgd_scores_soln, "bkgd_scores.npy")
    step4_success = ok_sig_scores and ok_bkg_scores

# Step 5 validation (categorization outputs) - direct array comparisons
if _step_selected(5) and not missing_file_5:
    print("=== Step 5 Validation (categorization) ===")
    ok_boundaries = arrays_match(boundaries_llm, boundaries_soln, "boundaries.npy")
    ok_significances = arrays_match(significances_llm, significances_soln, "significances.npy")
    step5_success = ok_boundaries and ok_significances

# Save results
success_results = [int(step1_success), int(step2_success), int(step3_success),
                   int(step4_success), int(step5_success)]
# np.save('success.npy', success_results)  # Removed - results are already printed to console

print("\n=== VALIDATION SUMMARY ===")
step_names = ["summarize_root", "create_numpy", "preprocess", "scores", "categorization"]
missing_flags = [missing_file_1, missing_file_2, missing_file_3, missing_file_4, missing_file_5]
success_flags = [step1_success, step2_success, step3_success, step4_success, step5_success]

if specific_step:
    step_name = step_names[specific_step - 1]
    print(f"Step: {specific_step} ({step_name})")
    if specific_step == 1:
        print("Files validated:")
        print(" • file_list.txt - List of processed ROOT files")
        print(" • root_summary.txt - Branch structure and file metadata")
    elif specific_step == 2:
        print("Files validated:")
        print(" • data_A_raw.npy - Raw data array (must have 46 columns)")
        print(" • signal_WH_raw.npy - Raw signal array (must have 46 columns)")
    elif specific_step == 3:
        print("Files validated:")
        print(" • signal.npy - Preprocessed signal events")
        print(" • bkgd.npy - Preprocessed background events")
    elif specific_step == 4:
        print("Files validated:")
        print(" • signal_scores.npy - Signal event classification scores")
        print(" • bkgd_scores.npy - Background event classification scores")
    elif specific_step == 5:
        print("Files validated:")
        print(" • boundaries.npy - Category boundary thresholds")
        print(" • significances.npy - Statistical significance values")
else:
    print("All steps validated")

# Mode info removed; direct comparisons are used for all steps
# Show only relevant step status
if specific_step:
    step_name = step_names[specific_step - 1]
    if missing_flags[specific_step - 1]:
        status = "MISSING"
    else:
        status = "PASS" if success_flags[specific_step - 1] else "FAIL"
    print(f"\nStep {specific_step} ({step_name}): {status}")
    if status == "PASS":
        print("āœ… Validation successful")
    elif status == "FAIL":
        print("āŒ Validation failed")
    else:
        print("āš ļø Step outputs missing")
else:
    # Show all steps for full validation
    step_status = []
    for success, missing in zip(success_flags, missing_flags):
        if missing:
            step_status.append("MISSING")
        elif success:
            step_status.append("PASS")
        else:
            step_status.append("FAIL")
    print(f"Step 1 (summarize_root): {step_status[0]}")
    print(f"Step 2 (create_numpy): {step_status[1]}")
    print(f"Step 3 (preprocess): {step_status[2]}")
    print(f"Step 4 (scores): {step_status[3]}")
    print(f"Step 5 (categorization): {step_status[4]}")

# Only count actually validated steps for overall success
if specific_step:
    validated_steps = 1
    passed_steps = 1 if success_results[specific_step - 1] and not missing_flags[specific_step - 1] else 0
    print(f"\nResult: {passed_steps}/{validated_steps} step passed")
else:
    validated_steps = sum(1 for missing in missing_flags if not missing)
    passed_steps = sum(success_results)
    print(f"Overall success: {passed_steps}/{validated_steps} validated steps passed")
    print(f"Success array: {success_results}")

# At the end of main script, ensure validation script exits zero so Run_SMK
# prints PASS/FAIL instead of 'failed to run'
sys.exit(0)