| | import os |
| | import torch |
| | import torch.nn as nn |
| | from torchvision import models, transforms |
| | from PIL import Image |
| | import json |
| | import sys |
| |
|
| | |
def _report_environment():
    """Print start-up diagnostics to stdout.

    Reports the Python and PyTorch versions, the contents of the current
    working directory and, when it exists, the contents of ``/repository``.
    Called once at import time.
    """
    print("Handler module loaded")
    version_lines = (
        ("Python version", sys.version),
        ("PyTorch version", torch.__version__),
    )
    for label, value in version_lines:
        print(f"{label}: {value}")
    print(f"Directory contents: {os.listdir('.')}")
    repository_dir = '/repository'
    if os.path.exists(repository_dir):
        print(f"Repository directory contents: {os.listdir(repository_dir)}")


_report_environment()
| |
|
| | |
class ViTForImageClassification:
    """Stub whose ``from_pretrained`` always fails.

    The handler in this module builds an EfficientNet-V2-S model. Presumably
    this stub exists so that an accidental ViT-style load fails loudly
    instead of silently succeeding; confirm with the deployment configuration.
    """

    @staticmethod
    def from_pretrained(model_dir):
        """Log the offending call, then raise ``ValueError`` unconditionally."""
        message = "ViTForImageClassification is not the correct model for this application"
        print(f"ERROR: ViTForImageClassification.from_pretrained was called with {model_dir}")
        raise ValueError(message)
| |
|
class EndpointHandler:
    """Classify an image as a real photograph or an AI-generated image.

    Construct it with a model directory, then call the instance with a request
    payload. See ``__call__`` for the accepted formats and the response shape.
    """

    def __init__(self, model_dir):
        """
        Initialize the model for AI image detection.

        Args:
            model_dir: Directory searched first for the checkpoint. Fallback
                locations are listed in ``_candidate_paths``.

        If loading the checkpoint fails for any reason, an ImageNet-pretrained
        EfficientNet-V2-S gets a freshly initialized 2-class head, so the
        handler can still start. Predictions from that fallback model are
        NOT meaningful.
        """
        print(f"Initializing EndpointHandler with model_dir: {model_dir}")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Fixed 224x224 input, normalized with the standard ImageNet mean/std.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        # Model output index i is reported under label self.classes[i].
        self.classes = ["Real Image", "AI-Generated Image"]

        try:
            self.model = self._load_model(model_dir)
            print("Model loaded successfully")
        except Exception as e:
            print(f"Error loading model: {e}")
            # Deliberate best-effort fallback so the handler still starts.
            # The replacement 2-way head is randomly initialized.
            print("Creating a dummy model as fallback")
            self.model = models.efficientnet_v2_s(
                weights=models.EfficientNet_V2_S_Weights.DEFAULT
            )
            self.model.classifier[-1] = nn.Linear(
                self.model.classifier[-1].in_features, 2
            )
        # The fallback model is created on the CPU. It must live on the same
        # device as the tensors __call__ builds, otherwise every request fails
        # on a CUDA host. For the loaded model this is a no-op.
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _candidate_paths(model_dir):
        """Return the checkpoint locations to probe, in priority order.

        Relative entries resolve against the current working directory.
        """
        return [
            os.path.join(model_dir, "best_model_improved.pth"),
            os.path.join(model_dir, "pytorch_model.bin"),
            "best_model_improved.pth",
            "/repository/best_model_improved.pth",
        ]

    def _load_model(self, model_dir):
        """Locate a checkpoint, build the matching architecture and load it.

        Args:
            model_dir: Directory searched first. It need not exist, because
                the fallback locations from ``_candidate_paths`` are probed too.

        Returns:
            The model on ``self.device``, in eval mode.

        Raises:
            FileNotFoundError: If none of the candidate paths is a file.
        """
        print(f"Loading model from directory: {model_dir}")
        # The listing is diagnostic only. A missing model_dir must not stop
        # the cwd and /repository fallbacks from being tried.
        if os.path.isdir(model_dir):
            print(f"Directory contents: {os.listdir(model_dir)}")
        else:
            print(f"Model directory does not exist: {model_dir}")

        possible_paths = self._candidate_paths(model_dir)
        model_path = None
        for candidate in possible_paths:
            print(f"Trying model path: {candidate}")
            if os.path.isfile(candidate):
                print(f"Found model at: {candidate}")
                model_path = candidate
                break

        if model_path is None:
            raise FileNotFoundError(
                f"Model file not found in any of these locations: {possible_paths}"
            )

        model = models.efficientnet_v2_s(weights=None)
        # load_state_dict is strict by default, so this head must mirror the
        # classifier stored in the checkpoint.
        model.classifier = nn.Sequential(
            nn.Linear(model.classifier[1].in_features, 1024),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(512, 2)
        )

        # NOTE(review): torch.load unpickles the file. Only load trusted
        # checkpoints; consider weights_only=True once the minimum supported
        # torch version allows it.
        state_dict = torch.load(model_path, map_location=self.device)
        model.load_state_dict(state_dict)
        model.to(self.device)
        model.eval()
        return model

    def _to_image(self, input_data):
        """Decode a supported payload into an RGB PIL image.

        Supported payloads:
          * str: base64 data, optionally prefixed by a data-URL header
            (everything up to the first comma is dropped)
          * bytes / bytearray: raw encoded image bytes
          * a file-like object with ``read``
          * a PIL ``Image``

        Returns None when the payload type is not supported.
        """
        import base64
        from io import BytesIO

        if isinstance(input_data, str):
            print("Processing base64 string image")
            if ',' in input_data:
                input_data = input_data.split(",", 1)[1]
            image_bytes = base64.b64decode(input_data)
            return Image.open(BytesIO(image_bytes)).convert("RGB")
        if isinstance(input_data, (bytes, bytearray)):
            print("Processing raw image bytes")
            return Image.open(BytesIO(input_data)).convert("RGB")
        if hasattr(input_data, "read"):
            print("Processing file-like object image")
            return Image.open(input_data).convert("RGB")
        if isinstance(input_data, Image.Image):
            print("Processing PIL Image")
            # Normalize the mode like the other branches. RGBA, grayscale or
            # palette images would otherwise give a channel count that does
            # not match the 3-channel Normalize step.
            return input_data.convert("RGB")
        return None

    def _format_scores(self, probabilities):
        """Pair each class label with its probability, expressed in percent."""
        return [
            {"label": label, "score": float(prob.item() * 100)}
            for label, prob in zip(self.classes, probabilities)
        ]

    def __call__(self, data):
        """
        Run prediction on the input data.

        Args:
            data: Either a dict whose ``"inputs"`` entry holds the image
                payload, or the payload itself. See ``_to_image`` for the
                supported payload types.

        Returns:
            On success, a list with one ``{"label", "score"}`` dict per class,
            in ``self.classes`` order. Scores are softmax probabilities
            scaled to 0-100.
            For an unsupported payload type, ``{"error": ...}``.
            For any exception, ``{"error": ..., "traceback": ...}``.
        """
        try:
            print(f"Received prediction request with data type: {type(data)}")

            if isinstance(data, dict) and "inputs" in data:
                input_data = data["inputs"]
                print(f"Extracted input data from API format, type: {type(input_data)}")
            else:
                input_data = data

            image = self._to_image(input_data)
            if image is None:
                print(f"Unsupported input type: {type(input_data)}")
                return {"error": f"Unsupported input type: {type(input_data)}"}

            image_tensor = self.transform(image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                outputs = self.model(image_tensor)
                # Batch size is 1, so take the single row of probabilities.
                probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]

            return self._format_scores(probabilities)

        except Exception as e:
            import traceback
            print(f"Error during prediction: {e}")
            traceback.print_exc()
            return {"error": str(e), "traceback": traceback.format_exc()}
| |
|