Respair committed on
Commit 7e5c8d4 · verified · 1 Parent(s): acd9652

Create batch_infer.py

Files changed (1):
  1. batch_infer.py +488 -0

batch_infer.py ADDED
# from concurrent.futures import ProcessPoolExecutor, as_completed
# import time
# from datetime import timedelta
# import pandas as pd
# import torch
# import warnings
# import logging
# import os
# import traceback

# # --- Load and filter dataframe ---
# df = pd.read_csv("/home/ubuntu/ttsar/ASR_DATA/train_large.csv")
# print('before filtering: ')
# print(df.shape)

# df = df[~df['filename'].str.contains("Sakura, Moyu")]
# print('after filtering: ')
# print(df.shape)

# total_samples = len(df)

# # --- PyTorch settings ---
# torch.set_float32_matmul_precision('high')
# torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cudnn.allow_tf32 = True

# def process_batch(batch_data):
#     """Process a batch of audio files"""
#     batch_id, start_idx, audio_files, config_path, checkpoint_path = batch_data

#     model = None  # Initialize model to None for the finally block
#     try:
#         # Import and configure libraries within the worker process
#         import torch
#         import nemo.collections.asr as nemo_asr
#         from omegaconf import OmegaConf, open_dict
#         import warnings
#         import logging

#         # Suppress logs within the worker process to keep the main output clean
#         logging.getLogger('nemo_logger').setLevel(logging.ERROR)
#         logging.disable(logging.CRITICAL)
#         warnings.filterwarnings('ignore')

#         # Load model for this worker
#         config = OmegaConf.load(config_path)
#         with open_dict(config.cfg):
#             for ds in ['train_ds', 'validation_ds', 'test_ds']:
#                 if ds in config.cfg:
#                     config.cfg[ds].defer_setup = True

#         model = nemo_asr.models.EncDecMultiTaskModel(cfg=config.cfg)
#         checkpoint = torch.load(checkpoint_path, map_location='cuda', weights_only=False)
#         model.load_state_dict(checkpoint['state_dict'], strict=False)
#         model = model.eval().cuda()

#         decode_cfg = model.cfg.decoding
#         decode_cfg.beam.beam_size = 4
#         model.change_decoding_strategy(decode_cfg)

#         # Transcribe
#         start = time.time()
#         hypotheses = model.transcribe(
#             audio=audio_files,
#             batch_size=64,
#             source_lang='ja',
#             target_lang='ja',
#             task='asr',
#             pnc='no',
#             verbose=False,
#             num_workers=0,
#             channel_selector=0
#         )

#         results = [hyp.text for hyp in hypotheses]

#         return batch_id, start_idx, results, len(audio_files), time.time() - start
#     finally:
#         # NEW: Ensure GPU memory is cleared in the worker process
#         if model is not None:
#             del model
#         import torch
#         torch.cuda.empty_cache()

# # --- Parameters ---
# chunk_size = 512 * 4
# n_workers = 4
# checkpoint_interval = 250_000

# config_path = "/home/ubuntu/NeMo_Canary/canary_results/Higurashi_ASR_v.02/version_4/hparams.yaml"
# checkpoint_path = "/home/ubuntu/NeMo_Canary/canary_results/Higurashi_ASR_v.02_plus/checkpoints/Higurashi_ASR_v.02_plus--step=174650.0000-epoch=8-last.ckpt"

# # --- Prepare data chunks ---
# audio_files = df['filename'].tolist()
# chunks = []
# for i in range(0, total_samples, chunk_size):
#     end_idx = min(i + chunk_size, total_samples)
#     chunk_files = audio_files[i:end_idx]
#     chunks.append({
#         'batch_id': len(chunks),
#         'start_idx': i,
#         'files': chunk_files,
#         'config_path': config_path,
#         'checkpoint_path': checkpoint_path
#     })

# print(f"Processing {total_samples:,} samples")
# print(f"Chunks: {len(chunks)} × ~{chunk_size} samples")
# print(f"Workers: {n_workers}")
# print(f"Checkpoint interval: every {checkpoint_interval:,} samples")
# print("-" * 50)

# # --- Initialize tracking variables ---
# all_results = {}
# failed_chunks = []
# start_time = time.time()
# samples_done = 0
# last_checkpoint = 0
# interrupted = False

# # Initialize 'text' column with a placeholder
# df['text'] = pd.NA

# # --- Main Processing Loop with Graceful Shutdown ---
# try:
#     with ProcessPoolExecutor(max_workers=n_workers) as executor:
#         future_to_chunk = {
#             executor.submit(process_batch,
#                             (chunk['batch_id'], chunk['start_idx'], chunk['files'], chunk['config_path'], chunk['checkpoint_path'])): chunk
#             for chunk in chunks
#         }

#         for future in as_completed(future_to_chunk):
#             original_chunk = future_to_chunk[future]
#             batch_id = original_chunk['batch_id']

#             try:
#                 _batch_id, start_idx, results, count, batch_time = future.result()

#                 all_results[start_idx] = results
#                 samples_done += count

#                 end_idx = start_idx + len(results)
#                 if len(df.iloc[start_idx:end_idx]) == len(results):
#                     df.loc[start_idx:end_idx - 1, 'text'] = results
#                 else:
#                     raise ValueError("Length mismatch: DataFrame slice vs results")

#                 elapsed = time.time() - start_time
#                 speed = samples_done / elapsed if elapsed > 0 else 0
#                 remaining = total_samples - samples_done
#                 eta = remaining / speed if speed > 0 else 0

#                 print(f"✓ Batch {batch_id}/{len(chunks)-1} done ({count} samples in {batch_time:.1f}s) | "
#                       f"Total: {samples_done:,}/{total_samples:,} ({100*samples_done/total_samples:.1f}%) | "
#                       f"Speed: {speed:.1f} samples/s | "
#                       f"ETA: {timedelta(seconds=int(eta))}")

#                 if samples_done - last_checkpoint >= checkpoint_interval or samples_done == total_samples:
#                     checkpoint_file = f"/home/ubuntu/ttsar/ASR_DATA/transcribed_checkpoint_{samples_done}.csv"
#                     df.to_csv(checkpoint_file, index=False)
#                     print(f" ✓ Checkpoint saved: {checkpoint_file}")
#                     last_checkpoint = samples_done

#             except Exception:
#                 failed_chunks.append(original_chunk)
#                 print("-" * 20 + " ERROR " + "-" * 20)
#                 print(f"✗ Batch {batch_id} FAILED. Start index: {original_chunk['start_idx']}. Files: {len(original_chunk['files'])}")
#                 traceback.print_exc()
#                 print("-" * 47)

# except KeyboardInterrupt:
#     interrupted = True
#     print("\n\n" + "=" * 50)
#     print("! KEYBOARD INTERRUPT DETECTED !")
#     print("Stopping workers and saving all completed progress...")
#     print("The script will exit shortly.")
#     print("=" * 50 + "\n")
#     # The `with ProcessPoolExecutor` context manager will automatically
#     # handle shutting down the worker processes when we exit this block.

# # --- Finalization and Reporting (this block now runs on completion OR interruption) ---
# total_time = time.time() - start_time
# print("-" * 50)
# if interrupted:
#     print("PROCESS INTERRUPTED")
# else:
#     print("TRANSCRIPTION COMPLETE!")

# print(f"Total time elapsed: {timedelta(seconds=int(total_time))}")
# if total_time > 0 and samples_done > 0:
#     print(f"Average speed (on completed work): {samples_done/total_time:.1f} samples/second")

# # Save final result
# final_output = "/home/ubuntu/ttsar/ASR_DATA/transcribed_manifest_final.csv"
# df.to_csv(final_output, index=False)
# print(f"Final progress saved to: {final_output}")
# print("-" * 50)

# # --- Summary and Verification ---
# successful_transcriptions = df['text'].notna().sum()
# print("Final Run Summary:")
# print(f" - Successfully transcribed: {successful_transcriptions:,} samples")
# print(f" - Failed batches: {len(failed_chunks)}")
# print(f" - Total samples in failed batches: {sum(len(c['files']) for c in failed_chunks):,}")

# if failed_chunks:
#     failed_files_path = "/home/ubuntu/ttsar/ASR_DATA/failed_transcription_files.txt"
#     with open(failed_files_path, 'w') as f:
#         for chunk in failed_chunks:
#             for file_path in chunk['files']:
#                 f.write(f"{file_path}\n")
#     print(f"\nList of files from failed batches saved to: {failed_files_path}")

# print("-" * 50)


# NOTE: active rewrite below. Unlike the commented-out first pass above, this
# version resumes from an existing checkpoint CSV and only transcribes the
# rows whose 'text' is still missing.

from concurrent.futures import ProcessPoolExecutor, as_completed
import time
from datetime import timedelta
import pandas as pd
import torch
import warnings
import logging
import os
import traceback

# --- LOAD CHECKPOINT ---
checkpoint_file = "/home/ubuntu/ttsar/csv_kanad/sing/cg_shani_sing.csv"
print(f"Loading checkpoint from: {checkpoint_file}")
df = pd.read_csv(checkpoint_file)
print(f"Checkpoint loaded. Shape: {df.shape}")

# Create the 'text' column if it does not exist yet
if 'text' not in df.columns:
    df['text'] = pd.NA

# --- FIND ALL MISSING TRANSCRIPTIONS ---
missing_mask = df['text'].isna()
missing_indices = df[missing_mask].index.tolist()
already_done = (~missing_mask).sum()

print(f"Already transcribed: {already_done:,} samples")
print(f"Missing transcriptions: {len(missing_indices):,} samples")
print("-" * 50)

if len(missing_indices) == 0:
    print("All samples already transcribed!")
    exit(0)
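
# For example, if df['text'] were [NaN, "done", NaN], missing_mask would select
# rows 0 and 2, missing_indices would be [0, 2], and already_done would be 1.
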
# --- PyTorch settings ---
torch.set_float32_matmul_precision('high')
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
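# (These flags let matmul and cuDNN kernels use TF32 on Ampere-and-newer GPUs,
# trading a little precision for speed, which is generally safe for inference.)
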
def process_batch(batch_data):
    """Process a batch of audio files"""
    batch_id, indices, audio_files, config_path, checkpoint_path = batch_data

    model = None  # so the finally block can tell whether loading succeeded
    try:
        # Import and configure libraries within the worker process
        import torch
        import nemo.collections.asr as nemo_asr
        from omegaconf import OmegaConf, open_dict
        import warnings
        import logging

        # Suppress logs within the worker process
        logging.getLogger('nemo_logger').setLevel(logging.ERROR)
        logging.disable(logging.CRITICAL)
        warnings.filterwarnings('ignore')

        # Load model for this worker; defer dataset setup since we only run inference
        config = OmegaConf.load(config_path)
        with open_dict(config.cfg):
            for ds in ['train_ds', 'validation_ds', 'test_ds']:
                if ds in config.cfg:
                    config.cfg[ds].defer_setup = True

        model = nemo_asr.models.EncDecMultiTaskModel(cfg=config.cfg)
        checkpoint = torch.load(checkpoint_path, map_location='cuda', weights_only=False)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        model = model.eval().cuda().bfloat16()

        # Greedy decoding (beam_size=1) for throughput
        decode_cfg = model.cfg.decoding
        decode_cfg.beam.beam_size = 1
        model.change_decoding_strategy(decode_cfg)

        # Transcribe
        start = time.time()
        try:
            hypotheses = model.transcribe(
                audio=audio_files,
                batch_size=64,
                source_lang='ja',
                target_lang='ja',
                task='asr',
                pnc='no',
                verbose=False,
                num_workers=0,
                channel_selector=0
            )
            results = [hyp.text for hyp in hypotheses]
        except Exception as e:
            print(f"Transcription error in batch {batch_id}: {e}")
            # Return an empty results list on transcription failure
            results = []

        # Pad results with None if we got fewer results than expected
        while len(results) < len(audio_files):
            results.append(None)

        # Count successful transcriptions
        success_count = len([r for r in results if r is not None])

        # Return indices and results paired up so the caller can write them back
        return batch_id, list(zip(indices, results)), success_count, time.time() - start

    finally:
        # Ensure GPU memory is cleared in the worker process
        if model is not None:
            del model
        import torch
        torch.cuda.empty_cache()
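
# A minimal sketch of the tuple process_batch expects (it mirrors the
# executor.submit call below; the indices and paths here are illustrative
# placeholders, not real data):
#
#   example = (0, [12, 57], ["/tmp/a.wav", "/tmp/b.wav"], config_path, checkpoint_path)
#   batch_id, index_result_pairs, ok_count, secs = process_batch(example)
#   # index_result_pairs -> [(12, "text or None"), (57, "text or None")]
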
# --- Parameters ---
chunk_size = 512 * 4  # 2048
n_workers = 6
checkpoint_interval = 250_000

config_path = "/home/ubuntu/NeMo_Canary/canary_results/Higurashi_ASR_v.02/version_4/hparams.yaml"
checkpoint_path = "/home/ubuntu/NeMo_Canary/canary_results/Higurashi_ASR_v.02_plus/checkpoints/Higurashi_ASR_v.02_plus--step=174650.0000-epoch=8-last.ckpt"

# --- Create batches from missing indices ---
chunks = []
for i in range(0, len(missing_indices), chunk_size):
    batch_indices = missing_indices[i:i + chunk_size]
    batch_files = df.loc[batch_indices, 'filename'].tolist()

    chunks.append({
        'batch_id': len(chunks),
        'indices': batch_indices,
        'files': batch_files,
        'config_path': config_path,
        'checkpoint_path': checkpoint_path
    })
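
# Note: batch indices are DataFrame index labels, not positions, and they can
# be non-contiguous; that is why results are later written back row-by-row
# with df.loc instead of slice assignment.
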
print(f"Total batches to process: {len(chunks)}")
print(f"Batch size: ~{chunk_size} samples")
print(f"Workers: {n_workers}")
print(f"Checkpoint interval: every {checkpoint_interval:,} samples")
print("-" * 50)

# --- Initialize tracking variables ---
all_results = {}  # kept from the earlier version; unused below
failed_chunks = []
failed_files_list = []
start_time = time.time()
samples_done = 0
samples_failed = 0
last_checkpoint = 0
interrupted = False
total_to_process = len(missing_indices)

# --- Main Processing Loop ---
try:
    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        future_to_chunk = {
            executor.submit(process_batch,
                            (chunk['batch_id'], chunk['indices'], chunk['files'],
                             chunk['config_path'], chunk['checkpoint_path'])): chunk
            for chunk in chunks
        }

        for future in as_completed(future_to_chunk):
            original_chunk = future_to_chunk[future]
            batch_id = original_chunk['batch_id']

            try:
                _batch_id, index_result_pairs, success_count, batch_time = future.result()

                # Update DataFrame with results
                failed_in_batch = 0
                for idx, result in index_result_pairs:
                    if result is not None:
                        df.loc[idx, 'text'] = result
                    else:
                        df.loc[idx, 'text'] = "[FAILED]"
                        failed_in_batch += 1
                        failed_files_list.append(df.loc[idx, 'filename'])

                samples_done += success_count
                samples_failed += failed_in_batch

                elapsed = time.time() - start_time
                speed = samples_done / elapsed if elapsed > 0 else 0
                remaining = total_to_process - samples_done - samples_failed
                eta = remaining / speed if speed > 0 else 0

                current_total = already_done + samples_done

                status = f"✓ Batch {batch_id}/{len(chunks)-1} done ({success_count} success"
                if failed_in_batch > 0:
                    status += f", {failed_in_batch} failed"
                status += f" in {batch_time:.1f}s)"

                print(f"{status} | "
                      f"Processed: {samples_done:,}/{total_to_process:,} | "
                      f"Total: {current_total:,}/{len(df):,} ({100*current_total/len(df):.1f}%) | "
                      f"Speed: {speed:.1f} samples/s | "
                      f"ETA: {timedelta(seconds=int(eta))}")

                # Save checkpoint (distinct name so the `checkpoint_file` input
                # path defined at the top is not clobbered)
                if samples_done - last_checkpoint >= checkpoint_interval or (samples_done + samples_failed) >= total_to_process:
                    ckpt_out = f"/home/ubuntu/ttsar/ASR_DATA/transcribed_checkpoint_{current_total}.csv"
                    df.to_csv(ckpt_out, index=False)
                    print(f" ✓ Checkpoint saved: {ckpt_out}")
                    last_checkpoint = samples_done

            except Exception as e:
                failed_chunks.append(original_chunk)
                print("-" * 20 + " ERROR " + "-" * 20)
                print(f"✗ Batch {batch_id} FAILED. Indices count: {len(original_chunk['indices'])}")
                print(f"Error: {e}")
                traceback.print_exc()
                print("-" * 47)

except KeyboardInterrupt:
    interrupted = True
    print("\n\n" + "=" * 50)
    print("! KEYBOARD INTERRUPT DETECTED !")
    print("Stopping workers and saving progress...")
    print("=" * 50 + "\n")
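
# Exiting the `with ProcessPoolExecutor` block triggers the executor's own
# shutdown handling, and the finalization code below runs whether the loop
# finished normally or was interrupted, so completed work is always saved.
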
# --- Finalization ---
total_time = time.time() - start_time
print("-" * 50)
if interrupted:
    print("PROCESS INTERRUPTED")
else:
    print("PROCESSING COMPLETE!")

print(f"Session time: {timedelta(seconds=int(total_time))}")
print(f"Samples successfully processed: {samples_done:,}")
print(f"Samples failed: {samples_failed:,}")
if total_time > 0 and samples_done > 0:
    print(f"Average speed: {samples_done/total_time:.1f} samples/second")

# Save final result
final_output = "/home/ubuntu/ttsar/ASR_DATA/transcribed_manifest_final.csv"
df.to_csv(final_output, index=False)
print(f"Final output saved to: {final_output}")
print("-" * 50)

# --- Summary ---
failed_transcriptions = (df['text'] == "[FAILED]").sum()
successful_transcriptions = df['text'].notna().sum() - failed_transcriptions
remaining_missing = df['text'].isna().sum()

print("Summary:")
print(f" - Total dataset size: {len(df):,} samples")
print(f" - Successfully transcribed: {successful_transcriptions:,} samples")
print(f" - Failed transcriptions: {failed_transcriptions:,} samples")
print(f" - Still missing (NaN): {remaining_missing:,} samples")
print(f" - Processed this session: {samples_done:,} successful, {samples_failed:,} failed")
print(f" - Failed batches (entire batch): {len(failed_chunks)}")

# Save list of failed files
if failed_files_list:
    failed_files_path = "/home/ubuntu/ttsar/ASR_DATA/failed_transcription_files.txt"
    with open(failed_files_path, 'w') as f:
        for file_path in failed_files_list:
            f.write(f"{file_path}\n")
    print(f"\nFailed files saved to: {failed_files_path}")

if failed_chunks:
    failed_batches_path = "/home/ubuntu/ttsar/ASR_DATA/failed_batches.txt"
    with open(failed_batches_path, 'w') as f:
        for chunk in failed_chunks:
            f.write(f"Batch {chunk['batch_id']}: indices {chunk['indices'][:5]}... ({len(chunk['indices'])} total)\n")
    print(f"Failed batch info saved to: {failed_batches_path}")

print("-" * 50)
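
# A possible follow-up pass (hypothetical, not part of this script): clear the
# "[FAILED]" placeholders back to NA and point `checkpoint_file` at the saved
# manifest so the next run's missing_mask picks those rows up again. The retry
# CSV path below is illustrative only.
#
#   import pandas as pd
#   df = pd.read_csv("/home/ubuntu/ttsar/ASR_DATA/transcribed_manifest_final.csv")
#   df.loc[df['text'] == "[FAILED]", 'text'] = pd.NA
#   df.to_csv("/home/ubuntu/ttsar/ASR_DATA/transcribed_manifest_retry.csv", index=False)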