robzjgman committed
Commit c175766 · verified · 1 Parent(s): ee1c2ba

Upload validate_multiGen.py

4 _ LLM (Gemini)/aspect-identification/validate_multiGen.py ADDED
@@ -0,0 +1,339 @@
+import pandas as pd
+import numpy as np
+import os
+from sklearn.metrics import accuracy_score, hamming_loss, f1_score
+
+def validate_single_aspect(pred_df, gt_df, aspect):
+    """Validate a single aspect column"""
+    y_pred = pred_df[aspect].fillna('0').astype(str)
+    y_true = gt_df[aspect].fillna('0').astype(str)
+
+    accuracy = accuracy_score(y_true, y_pred)
+
+    print(f"\n=== {aspect.upper()} ASPECT ===")
+    print(f"Accuracy: {accuracy:.4f}")
+
+    return {
+        'aspect': aspect,
+        'accuracy': accuracy
+    }
+
+def calculate_exact_match_metrics(pred_df, gt_df, aspects):
+    """Calculate exact-match accuracy and Hamming loss across all aspects."""
+    correct_samples = 0
+    total_samples = len(pred_df)
+
+    # Per-sample binary outcome (all aspects correct vs. not all correct);
+    # collected but not currently used in the returned metrics
+    y_true_binary = []
+    y_pred_binary = []
+
+    # Per-aspect binary matrices for Hamming loss / F1
+    y_true_matrix = []
+    y_pred_matrix = []
+
+    for i in range(total_samples):
+        # Check whether all aspects match for this sample
+        all_correct = True
+        sample_true = []
+        sample_pred = []
+
+        for aspect in aspects:
+            pred_val = str(pred_df.loc[i, aspect]) if pd.notna(pred_df.loc[i, aspect]) else '0'
+            true_val = str(gt_df.loc[i, aspect]) if pd.notna(gt_df.loc[i, aspect]) else '0'
+
+            # Convert to binary (aspect present vs. absent) for Hamming loss
+            sample_true.append(1 if true_val != '0' else 0)
+            sample_pred.append(1 if pred_val != '0' else 0)
+
+            if pred_val != true_val:
+                all_correct = False
+
+        if all_correct:
+            correct_samples += 1
+
+        # Add to matrices for Hamming loss
+        y_true_matrix.append(sample_true)
+        y_pred_matrix.append(sample_pred)
+
+        # Binary classification outcome (1 = all correct, 0 = not all correct)
+        y_true_binary.append(1)  # Ground truth is always "all should be correct"
+        y_pred_binary.append(1 if all_correct else 0)  # Prediction success
+
+    # Calculate metrics
+    exact_match_accuracy = correct_samples / total_samples
+
+    # Calculate Hamming loss
+    h_loss = hamming_loss(y_true_matrix, y_pred_matrix)
+
+    # Return matrices in (true, pred) order so the caller's unpacking is correct
+    return exact_match_accuracy, correct_samples, total_samples, h_loss, y_true_matrix, y_pred_matrix
+
+def get_true_pred_aspects(pred_df: pd.DataFrame, gt_df: pd.DataFrame, aspect: str) -> list:
+    """Collect per-sample predicted/actual binary labels (plus review text, if available) for one aspect."""
+    result = []
+    has_text = 'Review' in gt_df.columns
+
+    for i in range(len(pred_df)):
+        pred_val = str(pred_df.loc[i, aspect]).strip().lower() if pd.notna(pred_df.loc[i, aspect]) else '0'
+        true_val = str(gt_df.loc[i, aspect]).strip().lower() if pd.notna(gt_df.loc[i, aspect]) else '0'
+
+        predicted_binary = 1 if pred_val != '0' else 0
+        actual_binary = 1 if true_val != '0' else 0
+
+        sample_data = {
+            'predicted': predicted_binary,
+            'actual': actual_binary,
+            'predicted_value': pred_val,
+            'actual_value': true_val,
+            'index': i
+        }
+
+        if has_text:
+            # Attach the 'Review' text from gt_df for error-analysis examples
+            sample_data['Review'] = str(gt_df.loc[i, 'Review'])
+
+        result.append(sample_data)
+
+    return result
+
+def identification_error_analysis(pred_df: pd.DataFrame, gt_df: pd.DataFrame, aspects: list) -> dict:
+    """Analyze common identification errors for all aspects."""
+    analysis = {
+        'aspect': {}
+    }
+
+    for aspect in aspects:
+        if aspect not in pred_df.columns or aspect not in gt_df.columns:
+            continue
+
+        results = get_true_pred_aspects(pred_df, gt_df, aspect)
+
+        fp = [r for r in results if r['predicted'] == 1 and r['actual'] == 0]  # False Positives: aspect wrongly identified
+        fn = [r for r in results if r['predicted'] == 0 and r['actual'] == 1]  # False Negatives: aspect missed
+        tp = [r for r in results if r['predicted'] == 1 and r['actual'] == 1]  # True Positives
+        tn = [r for r in results if r['predicted'] == 0 and r['actual'] == 0]  # True Negatives
+
+        # Precision = TP / (TP + FP), Recall = TP / (TP + FN), F1 = harmonic mean of the two
+        precision = len(tp) / (len(tp) + len(fp)) if (len(tp) + len(fp)) > 0 else 0.0
+        recall = len(tp) / (len(tp) + len(fn)) if (len(tp) + len(fn)) > 0 else 0.0
+        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
+
+        analysis['aspect'][aspect] = {
+            'true_positives': len(tp),
+            'true_negatives': len(tn),
+            'false_positives': len(fp),
+            'false_negatives': len(fn),
+            'precision': precision,
+            'recall': recall,
+            'f1_score': f1,
+            'fp_examples': fp[:5],  # First 5 examples
+            'fn_examples': fn[:5]   # First 5 examples
+        }
+
+    return analysis
+
+def save_error_analysis(analysis: dict, analysis_file: str):
+    """Save error analysis results to a file."""
+    results_text = ["Error Analysis: Aspect Identification\n" + "="*50 + "\n"]
+
+    for aspect, data in analysis['aspect'].items():
+        results_text.append(f"\n--- {aspect.upper()} ASPECT ---\n")
+        results_text.append(f"Precision: {data['precision']:.4f}")
+        results_text.append(f"Recall: {data['recall']:.4f}")
+        results_text.append(f"F1: {data['f1_score']:.4f}")
+        results_text.append(f"True Positives (TP): {data['true_positives']}")
+        results_text.append(f"False Positives (FP - Aspect *wrongly* identified): {data['false_positives']}")
+        results_text.append(f"False Negatives (FN - Aspect *missed*): {data['false_negatives']}")
+        results_text.append(f"True Negatives (TN): {data['true_negatives']}")
+
+        # False-positive examples
+        results_text.append("\nTOP 5 FALSE POSITIVE EXAMPLES (Model identified the aspect, but Ground Truth says '0'):")
+        for i, fp_ex in enumerate(data['fp_examples']):
+            text = fp_ex.get('Review', f"[Review text not available, index: {fp_ex['index']}]")
+            results_text.append(f"  {i+1}. Pred Val: '{fp_ex['predicted_value']}'. Text: \"{text[:100]}...\"")
+
+        # False-negative examples
+        results_text.append("\nTOP 5 FALSE NEGATIVE EXAMPLES (Model missed the aspect, but Ground Truth has a value):")
+        for i, fn_ex in enumerate(data['fn_examples']):
+            text = fn_ex.get('Review', f"[Review text not available, index: {fn_ex['index']}]")
+            results_text.append(f"  {i+1}. Actual Val: '{fn_ex['actual_value']}'. Text: \"{text[:100]}...\"")
+
+    # Save results to text file
+    with open(analysis_file, 'w', encoding='utf-8') as f:
+        f.write('\n'.join(results_text))
+    print(f"\nError analysis has been saved to {analysis_file}")
+
+def save_result_txt(results: dict, results_file: str):
+    """Save the accumulated results text to a file."""
+    with open(results_file, 'w', encoding='utf-8') as f:
+        f.write('\n'.join(results['results_text']))
+    print(f"\nResults saved to {results_file}")
+
+def validate_all_aspects(predicted_file: str, ground_truth_file: str, aspects: list,
+                         results_file: str, error_analysis_file: str) -> dict:
+    """Main validation function: per-aspect accuracy, exact-match metrics, and error analysis."""
+    # Load data
+    pred_df = pd.read_csv(predicted_file)
+    gt_df = pd.read_csv(ground_truth_file)
+
+    print(f"Predicted data shape: {pred_df.shape}")
+    print(f"Ground truth data shape: {gt_df.shape}")
+
+    # Check that both dataframes have the same length before proceeding
+    if len(pred_df) != len(gt_df):
+        print("ERROR: Predicted and Ground Truth files have different numbers of rows.")
+        return {}
+
+    # Store results for the text report
+    results_text = []
+    results_text.append(f"Validation Results\n{'='*50}\n")
+
+    # Validate each aspect
+    aspect_results = []
+
+    for aspect in aspects:
+        if aspect in pred_df.columns and aspect in gt_df.columns:
+            result = validate_single_aspect(pred_df, gt_df, aspect)
+            aspect_results.append(result)
+            results_text.append(f"\n{aspect.upper()} ASPECT")
+            results_text.append(f"Accuracy: {result['accuracy']:.4f}")
+        else:
+            print(f"WARNING: '{aspect}' column not found in one or both files")
+            results_text.append(f"\nWARNING: '{aspect}' column not found in one or both files")
+
+    # Combined metrics
+    valid_aspects = [aspect for aspect in aspects
+                     if aspect in pred_df.columns and aspect in gt_df.columns]
+
+    # Defaults so the return dict is well-defined even when no aspect column is usable
+    combined_accuracy, correct_count, total_count = 0.0, 0, len(pred_df)
+    hamming_loss_score, micro_f1, macro_f1 = 0.0, 0.0, 0.0
+
+    if valid_aspects:
+        combined_accuracy, correct_count, total_count, hamming_loss_score, y_true_matrix, y_pred_matrix = \
+            calculate_exact_match_metrics(pred_df, gt_df, valid_aspects)
+
+        if y_true_matrix:
+            # Calculate micro and macro F1 scores over the binary aspect matrices
+            micro_f1 = f1_score(y_true_matrix, y_pred_matrix, average='micro')
+            macro_f1 = f1_score(y_true_matrix, y_pred_matrix, average='macro')
+
+            results_text.append(f"\n{'='*50}")
+            results_text.append("EXACT MATCH (ALL ASPECTS)")
+            results_text.append(f"{'='*50}")
+            results_text.append(f"Samples with ALL aspects correct: {correct_count}/{total_count}")
+            results_text.append(f"Accuracy: {combined_accuracy:.4f}")
+            results_text.append(f"Hamming Loss: {hamming_loss_score:.4f}")
+            results_text.append(f"Micro F1 Score (Multi-Aspect): {micro_f1:.4f}")
+            results_text.append(f"Macro F1 Score (Multi-Aspect): {macro_f1:.4f}")
+
+    save_result_txt({'results_text': results_text}, results_file)
+
+    # --- Error Analysis ---
+    if valid_aspects:
+        error_analysis_results = identification_error_analysis(pred_df, gt_df, valid_aspects)
+        save_error_analysis(error_analysis_results, error_analysis_file)
+
+    return {
+        'results_text': results_text,
+        'aspect_results': aspect_results,
+        'combined_accuracy': combined_accuracy,
+        'correct_count': correct_count,
+        'total_count': total_count,
+        'hamming_loss': hamming_loss_score,
+        'micro_f1': micro_f1,
+        'macro_f1': macro_f1
+    }
+
+def calculate_overall_performance(general_aspect_mapping: dict, error_analysis_files: dict) -> dict:
+    """Calculate overall performance metrics for each general aspect group.
+
+    Args:
+        general_aspect_mapping: Dictionary mapping general aspects to their specific aspects
+        error_analysis_files: Dictionary mapping general aspects to their error analysis file paths
+
+    Returns:
+        Dictionary containing aggregated metrics for each general aspect
+    """
+    overall_results = {}
+
+    for general_aspect, specific_aspects in general_aspect_mapping.items():
+        specific_aspects = [aspect.lower() for aspect in specific_aspects]
+        # Initialize counters for this general aspect
+        total_tp = 0
+        total_fp = 0
+        total_tn = 0
+        total_fn = 0
+
+        # Load the error analysis file for this general aspect
+        error_analysis_file = error_analysis_files[general_aspect]
+        with open(error_analysis_file, 'r', encoding='utf-8') as f:
+            lines = f.readlines()
+
+        # Process each specific aspect's metrics
+        current_aspect = None
+        for line in lines:
+            line = line.strip()
+            if line.startswith('---') and line.endswith('---') and 'ASPECT' in line:
+                # Extract the aspect name, removing 'ASPECT' and the surrounding dashes
+                current_aspect = line.replace('-', '').replace('ASPECT', '').strip().lower()
+                continue
+
+            if current_aspect in specific_aspects:
+                if 'True Positives (TP):' in line:
+                    total_tp += int(line.split(':')[1].strip())
+                elif 'False Positives (FP' in line:
+                    total_fp += int(line.split(':')[1].strip())
+                elif 'False Negatives (FN' in line:
+                    total_fn += int(line.split(':')[1].strip())
+                elif 'True Negatives (TN):' in line:
+                    total_tn += int(line.split(':')[1].strip())
+
+        # Calculate overall metrics for this general aspect
+        precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0.0
+        recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0.0
+        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
+        accuracy = (total_tp + total_tn) / (total_tp + total_tn + total_fp + total_fn) if (total_tp + total_tn + total_fp + total_fn) > 0 else 0.0
+
+        overall_results[general_aspect] = {
+            'true_positives': total_tp,
+            'false_positives': total_fp,
+            'true_negatives': total_tn,
+            'false_negatives': total_fn,
+            'precision': precision,
+            'recall': recall,
+            'f1_score': f1,
+            'accuracy': accuracy,
+            'specific_aspects': specific_aspects
+        }
+
+    return overall_results
+
+def save_overall_results(results: dict, output_file: str):
+    """Save overall performance results to a file."""
+    with open(output_file, 'w', encoding='utf-8') as f:
+        f.write("Overall Performance by General Aspect\n")
+        f.write("=" * 50 + "\n\n")
+
+        for general_aspect, metrics in results.items():
+            f.write(f"=== {general_aspect.upper()} ===\n")
+            f.write(f"Specific aspects included: {', '.join(metrics['specific_aspects'])}\n\n")
+            f.write("Aggregated Metrics:\n")
+            f.write(f"True Positives (TP): {metrics['true_positives']}\n")
+            f.write(f"False Positives (FP): {metrics['false_positives']}\n")
+            f.write(f"True Negatives (TN): {metrics['true_negatives']}\n")
+            f.write(f"False Negatives (FN): {metrics['false_negatives']}\n")
+            f.write(f"Accuracy: {metrics['accuracy']:.4f}\n")
+            f.write(f"Precision: {metrics['precision']:.4f}\n")
+            f.write(f"Recall: {metrics['recall']:.4f}\n")
+            f.write(f"F1 Score: {metrics['f1_score']:.4f}\n\n")
+
+    print(f"Overall results saved to {output_file}")
+
+# Example usage:
+# general_aspect_mapping = {
+#     'price': ['price_value', 'price_comparison', 'price_discount'],
+#     'quality': ['quality_material', 'quality_durability', 'quality_defects']
+# }
+#
+# error_analysis_files = {
+#     'price': 'results/price_error_analysis.txt',
+#     'quality': 'results/quality_error_analysis.txt'
+# }
+#
+# results = calculate_overall_performance(general_aspect_mapping, error_analysis_files)
+# save_overall_results(results, 'results/overall_performance.txt')
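For reference, a minimal usage sketch of the main entry point, validate_all_aspects. The import path, file paths, and aspect names below are illustrative assumptions and are not part of the uploaded script; only the function signature and return keys come from the file itself.

# Hypothetical driver: assumes validate_multiGen.py is importable and the CSV paths exist.
from validate_multiGen import validate_all_aspects

aspects = ['price', 'quality', 'shipping']  # placeholder aspect column names
summary = validate_all_aspects(
    predicted_file='results/predicted_aspects.csv',   # placeholder path to model output
    ground_truth_file='data/ground_truth.csv',        # placeholder path to labeled data
    aspects=aspects,
    results_file='results/validation_results.txt',
    error_analysis_file='results/error_analysis.txt'
)

# validate_all_aspects returns {} when the two CSVs have different row counts.
if summary:
    print(f"Exact-match accuracy: {summary['combined_accuracy']:.4f}")
    print(f"Hamming loss: {summary['hamming_loss']:.4f}")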