Update app.py
Browse files
app.py
CHANGED
|
@@ -83,19 +83,19 @@ training_status = {
|
|
| 83 |
model_path = MODEL_SAVE_DIR / "DEBERTA_model.pth"
|
| 84 |
tokenizer = get_tokenizer(DEBERTA_MODEL_NAME)
|
| 85 |
|
| 86 |
-
#
|
| 87 |
try:
|
| 88 |
label_encoders = load_label_encoders()
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
except Exception as e:
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
model = DebertaMultiOutputModel(num_labels_list).to(DEVICE)
|
| 96 |
-
if os.path.exists(model_path):
|
| 97 |
-
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
|
| 98 |
-
model.eval()
|
| 99 |
|
| 100 |
class TrainingConfig(BaseModel):
|
| 101 |
model_name: str = DEBERTA_MODEL_NAME
|
|
@@ -264,7 +264,7 @@ async def validate_model(
|
|
| 264 |
|
| 265 |
data_df, label_encoders = load_and_preprocess_data(str(file_path))
|
| 266 |
|
| 267 |
-
model_path = MODEL_SAVE_DIR / f"{model_name}.pth"
|
| 268 |
if not model_path.exists():
|
| 269 |
raise HTTPException(status_code=404, detail="DeBERTa model file not found")
|
| 270 |
|
|
@@ -353,18 +353,16 @@ async def predict(
|
|
| 353 |
"""
|
| 354 |
try:
|
| 355 |
# Load the model
|
| 356 |
-
model_path = MODEL_SAVE_DIR / f"{model_name}.pth"
|
| 357 |
if not model_path.exists():
|
| 358 |
raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
|
| 359 |
|
| 360 |
-
# Load label encoders
|
| 361 |
try:
|
| 362 |
label_encoders = load_label_encoders()
|
| 363 |
num_labels_list = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
|
| 364 |
except Exception as e:
|
| 365 |
-
|
| 366 |
-
# Use default values if label encoders can't be loaded
|
| 367 |
-
num_labels_list = [2] * len(LABEL_COLUMNS) # Default to binary classification
|
| 368 |
|
| 369 |
model = DebertaMultiOutputModel(num_labels_list).to(DEVICE)
|
| 370 |
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
|
|
@@ -492,7 +490,7 @@ async def predict(
|
|
| 492 |
@app.get("/v1/deberta/download-model/{model_id}")
|
| 493 |
async def download_model(model_id: str):
|
| 494 |
"""Download a trained model"""
|
| 495 |
-
model_path = MODEL_SAVE_DIR / f"{model_id}.pth"
|
| 496 |
if not model_path.exists():
|
| 497 |
raise HTTPException(status_code=404, detail="Model not found")
|
| 498 |
|
|
@@ -524,7 +522,7 @@ async def train_model_task(config: TrainingConfig, file_path: str, training_id:
|
|
| 524 |
tokenizer,
|
| 525 |
config.max_length
|
| 526 |
)
|
| 527 |
-
model = DebertaMultiOutputModel(num_labels_list).to(DEVICE)
|
| 528 |
else:
|
| 529 |
dataset = ComplianceDataset(
|
| 530 |
texts.tolist(),
|
|
@@ -546,7 +544,7 @@ async def train_model_task(config: TrainingConfig, file_path: str, training_id:
|
|
| 546 |
training_status["current_loss"] = train_loss
|
| 547 |
|
| 548 |
# Save model after each epoch
|
| 549 |
-
save_model(model, training_id)
|
| 550 |
|
| 551 |
training_status.update({
|
| 552 |
"is_training": False,
|
|
|
|
| 83 |
model_path = MODEL_SAVE_DIR / "DEBERTA_model.pth"
|
| 84 |
tokenizer = get_tokenizer(DEBERTA_MODEL_NAME)
|
| 85 |
|
| 86 |
+
# Initialize model and label encoders with error handling
|
| 87 |
try:
|
| 88 |
label_encoders = load_label_encoders()
|
| 89 |
+
model = DebertaMultiOutputModel([len(label_encoders[col].classes_) for col in LABEL_COLUMNS]).to(DEVICE)
|
| 90 |
+
if model_path.exists():
|
| 91 |
+
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
|
| 92 |
+
model.eval()
|
| 93 |
+
else:
|
| 94 |
+
print(f"Warning: Model file {model_path} not found. Model will be initialized but not loaded.")
|
| 95 |
except Exception as e:
|
| 96 |
+
print(f"Warning: Could not load label encoders or model: {str(e)}")
|
| 97 |
+
print("Model will be initialized when training starts.")
|
| 98 |
+
model = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
class TrainingConfig(BaseModel):
|
| 101 |
model_name: str = DEBERTA_MODEL_NAME
|
|
|
|
| 264 |
|
| 265 |
data_df, label_encoders = load_and_preprocess_data(str(file_path))
|
| 266 |
|
| 267 |
+
model_path = MODEL_SAVE_DIR / f"{model_name}_model.pth"
|
| 268 |
if not model_path.exists():
|
| 269 |
raise HTTPException(status_code=404, detail="DeBERTa model file not found")
|
| 270 |
|
|
|
|
| 353 |
"""
|
| 354 |
try:
|
| 355 |
# Load the model
|
| 356 |
+
model_path = MODEL_SAVE_DIR / f"{model_name}_model.pth"
|
| 357 |
if not model_path.exists():
|
| 358 |
raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
|
| 359 |
|
| 360 |
+
# Load label encoders
|
| 361 |
try:
|
| 362 |
label_encoders = load_label_encoders()
|
| 363 |
num_labels_list = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
|
| 364 |
except Exception as e:
|
| 365 |
+
raise HTTPException(status_code=500, detail=f"Could not load label encoders: {str(e)}")
|
|
|
|
|
|
|
| 366 |
|
| 367 |
model = DebertaMultiOutputModel(num_labels_list).to(DEVICE)
|
| 368 |
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
|
|
|
|
| 490 |
@app.get("/v1/deberta/download-model/{model_id}")
|
| 491 |
async def download_model(model_id: str):
|
| 492 |
"""Download a trained model"""
|
| 493 |
+
model_path = MODEL_SAVE_DIR / f"{model_id}_model.pth"
|
| 494 |
if not model_path.exists():
|
| 495 |
raise HTTPException(status_code=404, detail="Model not found")
|
| 496 |
|
|
|
|
| 522 |
tokenizer,
|
| 523 |
config.max_length
|
| 524 |
)
|
| 525 |
+
model = DebertaMultiOutputModel(num_labels_list, metadata_dim=metadata_dim).to(DEVICE)
|
| 526 |
else:
|
| 527 |
dataset = ComplianceDataset(
|
| 528 |
texts.tolist(),
|
|
|
|
| 544 |
training_status["current_loss"] = train_loss
|
| 545 |
|
| 546 |
# Save model after each epoch
|
| 547 |
+
save_model(model, training_id, 'pth')
|
| 548 |
|
| 549 |
training_status.update({
|
| 550 |
"is_training": False,
|