| | |
"""
Simple inference script for anime style classification.
"""
| |
|
| | import torch |
| | from torchvision import models, transforms |
| | from PIL import Image |
| | import json |
| | from pathlib import Path |
| |
|
| |
|
def load_model(model_path=None, config_path='config.json'):
    """Load the trained EfficientNet-B0 classifier and its configuration.

    Args:
        model_path: Path (``str`` or ``pathlib.Path``) to the model weights.
            If None, the current working directory is searched for
            ``model.safetensors`` first and ``pytorch_model.pth`` second.
        config_path: Path to the JSON config file. It must provide
            ``num_classes``.

    Returns:
        A ``(model, config)`` tuple: the model with weights loaded and
        switched to eval mode, and the parsed config dict.

    Raises:
        FileNotFoundError: If ``model_path`` is None and neither default
            weight file exists.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    num_classes = config['num_classes']

    # Resolve the weight file before building the network so that a missing
    # file fails fast instead of after the model has been constructed.
    if model_path is None:
        for candidate in ('model.safetensors', 'pytorch_model.pth'):
            if Path(candidate).exists():
                model_path = candidate
                break
        else:
            raise FileNotFoundError("No model weights found (model.safetensors or pytorch_model.pth)")

    # Accept pathlib.Path as well as str; the suffix check below needs a str.
    model_path = str(model_path)

    # Calling without arguments builds the architecture with no pretrained
    # weights and avoids the deprecated ``pretrained=`` keyword.
    model = models.efficientnet_b0()
    head = model.classifier[-1]
    model.classifier[-1] = torch.nn.Linear(head.in_features, num_classes)

    if model_path.endswith('.safetensors'):
        from safetensors.torch import load_file
        state_dict = load_file(model_path)
    else:
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources (or pass weights_only=True where the installed
        # torch version supports it).
        checkpoint = torch.load(model_path, map_location='cpu')
        # Support both a training checkpoint that wraps the weights under
        # 'model_state_dict' and a file that is the bare state dict itself.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint

    model.load_state_dict(state_dict)
    model.eval()
    return model, config
| |
|
| |
|
def preprocess_image(image_path, img_size=224):
    """Load an image file and turn it into a normalized model input batch.

    Args:
        image_path: Path to the image file to open.
        img_size: Side length, in pixels, of the square the image is
            resized to.

    Returns:
        A float tensor of shape ``(1, 3, img_size, img_size)``: the RGB image
        resized, scaled by ``ToTensor`` and normalized per channel, with a
        leading batch dimension added.
    """
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        # Per-channel mean/std; these are the commonly used ImageNet
        # statistics (presumably the same ones used during training).
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Open inside a context manager so the file handle is closed promptly;
    # convert() returns a fully loaded copy that stays valid afterwards.
    with Image.open(image_path) as source:
        image = source.convert('RGB')
    return transform(image).unsqueeze(0)
| |
|
| |
|
def classify_image(model, config, image_path):
    """Run the model on one image and rank every style by confidence.

    Args:
        model: The classifier to call on the preprocessed image.
        config: Config dict whose ``id2label`` maps stringified class
            indices to style names.
        image_path: Path to the image to classify.

    Returns:
        A list of ``{'style': ..., 'confidence': ...}`` dicts, one per class,
        ordered from highest to lowest confidence.
    """
    batch = preprocess_image(image_path)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(batch)
        scores = torch.nn.functional.softmax(logits[0], dim=0)

    id2label = config['id2label']
    ranked = [
        {'style': id2label[str(class_id)], 'confidence': score}
        for class_id, score in enumerate(scores.tolist())
    ]
    return sorted(ranked, key=lambda entry: entry['confidence'], reverse=True)
| |
|
| |
|
def main():
    """Command-line entry point: classify one image and print the top styles.

    Parses the image path plus the optional ``--model``, ``--config`` and
    ``--top-k`` arguments, loads the model, classifies the image and prints
    the ``--top-k`` highest-confidence predictions followed by the best one.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Classify anime style')
    parser.add_argument('image', type=str, help='Path to image')
    parser.add_argument('--model', type=str, default=None,
                        help='Path to model weights (auto-detects .safetensors or .pth if not specified)')
    parser.add_argument('--config', type=str, default='config.json')
    parser.add_argument('--top-k', type=int, default=3, help='Show top-K predictions')

    args = parser.parse_args()

    # A zero or negative K would silently print nothing or drop entries from
    # the end of the ranking, so reject it up front.
    if args.top_k < 1:
        parser.error('--top-k must be a positive integer')

    # Avoid printing "Loading model from None..." when the path is auto-detected.
    weights_source = args.model if args.model is not None else 'auto-detected weights'
    print(f"Loading model from {weights_source}...")
    model, config = load_model(args.model, args.config)

    print(f"Classifying {args.image}...")
    results = classify_image(model, config, args.image)

    print()
    print("=" * 60)
    print("PREDICTIONS")
    print("=" * 60)
    for i, result in enumerate(results[:args.top_k], 1):
        print(f"{i}. {result['style']:12s} {result['confidence']:>7.2%}")
    print("=" * 60)
    print()
    print(f"Top prediction: {results[0]['style']} ({results[0]['confidence']:.2%})")
| |
|
| |
|
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
| |
|