namanpenguin commited on
Commit
b7a2caf
·
verified ·
1 Parent(s): b4d7e60

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -19
app.py CHANGED
@@ -83,19 +83,19 @@ training_status = {
83
  model_path = MODEL_SAVE_DIR / "DEBERTA_model.pth"
84
  tokenizer = get_tokenizer(DEBERTA_MODEL_NAME)
85
 
86
- # Load label encoders with error handling
87
  try:
88
  label_encoders = load_label_encoders()
89
- num_labels_list = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
 
 
 
 
 
90
  except Exception as e:
91
- logger.warning(f"Could not load label encoders: {str(e)}")
92
- # Use default values if label encoders can't be loaded
93
- num_labels_list = [2] * len(LABEL_COLUMNS) # Default to binary classification
94
-
95
- model = DebertaMultiOutputModel(num_labels_list).to(DEVICE)
96
- if os.path.exists(model_path):
97
- model.load_state_dict(torch.load(model_path, map_location=DEVICE))
98
- model.eval()
99
 
100
  class TrainingConfig(BaseModel):
101
  model_name: str = DEBERTA_MODEL_NAME
@@ -264,7 +264,7 @@ async def validate_model(
264
 
265
  data_df, label_encoders = load_and_preprocess_data(str(file_path))
266
 
267
- model_path = MODEL_SAVE_DIR / f"{model_name}.pth"
268
  if not model_path.exists():
269
  raise HTTPException(status_code=404, detail="DeBERTa model file not found")
270
 
@@ -353,18 +353,16 @@ async def predict(
353
  """
354
  try:
355
  # Load the model
356
- model_path = MODEL_SAVE_DIR / f"{model_name}.pth"
357
  if not model_path.exists():
358
  raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
359
 
360
- # Load label encoders with error handling
361
  try:
362
  label_encoders = load_label_encoders()
363
  num_labels_list = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
364
  except Exception as e:
365
- logger.warning(f"Could not load label encoders: {str(e)}")
366
- # Use default values if label encoders can't be loaded
367
- num_labels_list = [2] * len(LABEL_COLUMNS) # Default to binary classification
368
 
369
  model = DebertaMultiOutputModel(num_labels_list).to(DEVICE)
370
  model.load_state_dict(torch.load(model_path, map_location=DEVICE))
@@ -492,7 +490,7 @@ async def predict(
492
  @app.get("/v1/deberta/download-model/{model_id}")
493
  async def download_model(model_id: str):
494
  """Download a trained model"""
495
- model_path = MODEL_SAVE_DIR / f"{model_id}.pth"
496
  if not model_path.exists():
497
  raise HTTPException(status_code=404, detail="Model not found")
498
 
@@ -524,7 +522,7 @@ async def train_model_task(config: TrainingConfig, file_path: str, training_id:
524
  tokenizer,
525
  config.max_length
526
  )
527
- model = DebertaMultiOutputModel(num_labels_list).to(DEVICE)
528
  else:
529
  dataset = ComplianceDataset(
530
  texts.tolist(),
@@ -546,7 +544,7 @@ async def train_model_task(config: TrainingConfig, file_path: str, training_id:
546
  training_status["current_loss"] = train_loss
547
 
548
  # Save model after each epoch
549
- save_model(model, training_id)
550
 
551
  training_status.update({
552
  "is_training": False,
 
83
  model_path = MODEL_SAVE_DIR / "DEBERTA_model.pth"
84
  tokenizer = get_tokenizer(DEBERTA_MODEL_NAME)
85
 
86
+ # Initialize model and label encoders with error handling
87
  try:
88
  label_encoders = load_label_encoders()
89
+ model = DebertaMultiOutputModel([len(label_encoders[col].classes_) for col in LABEL_COLUMNS]).to(DEVICE)
90
+ if model_path.exists():
91
+ model.load_state_dict(torch.load(model_path, map_location=DEVICE))
92
+ model.eval()
93
+ else:
94
+ print(f"Warning: Model file {model_path} not found. Model will be initialized but not loaded.")
95
  except Exception as e:
96
+ print(f"Warning: Could not load label encoders or model: {str(e)}")
97
+ print("Model will be initialized when training starts.")
98
+ model = None
 
 
 
 
 
99
 
100
  class TrainingConfig(BaseModel):
101
  model_name: str = DEBERTA_MODEL_NAME
 
264
 
265
  data_df, label_encoders = load_and_preprocess_data(str(file_path))
266
 
267
+ model_path = MODEL_SAVE_DIR / f"{model_name}_model.pth"
268
  if not model_path.exists():
269
  raise HTTPException(status_code=404, detail="DeBERTa model file not found")
270
 
 
353
  """
354
  try:
355
  # Load the model
356
+ model_path = MODEL_SAVE_DIR / f"{model_name}_model.pth"
357
  if not model_path.exists():
358
  raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
359
 
360
+ # Load label encoders
361
  try:
362
  label_encoders = load_label_encoders()
363
  num_labels_list = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
364
  except Exception as e:
365
+ raise HTTPException(status_code=500, detail=f"Could not load label encoders: {str(e)}")
 
 
366
 
367
  model = DebertaMultiOutputModel(num_labels_list).to(DEVICE)
368
  model.load_state_dict(torch.load(model_path, map_location=DEVICE))
 
490
  @app.get("/v1/deberta/download-model/{model_id}")
491
  async def download_model(model_id: str):
492
  """Download a trained model"""
493
+ model_path = MODEL_SAVE_DIR / f"{model_id}_model.pth"
494
  if not model_path.exists():
495
  raise HTTPException(status_code=404, detail="Model not found")
496
 
 
522
  tokenizer,
523
  config.max_length
524
  )
525
+ model = DebertaMultiOutputModel(num_labels_list, metadata_dim=metadata_dim).to(DEVICE)
526
  else:
527
  dataset = ComplianceDataset(
528
  texts.tolist(),
 
544
  training_status["current_loss"] = train_loss
545
 
546
  # Save model after each epoch
547
+ save_model(model, training_id, 'pth')
548
 
549
  training_status.update({
550
  "is_training": False,