"""
DEIMv2: Real-Time Object Detection Meets DINOv3
Copyright (c) 2025 The DEIMv2 Authors. All Rights Reserved.
---------------------------------------------------------------------------------
Modified from D-FINE (https://github.com/Peterande/D-FINE)
Copyright (c) 2024 The D-FINE Authors. All Rights Reserved.
"""
import cv2
import numpy as np
import axengine as ort
import torch
import torchvision
import torchvision.transforms as T
from PIL import Image, ImageDraw
from torch import nn
import torch.nn.functional as F
def mod(a, b):
    """Export-friendly modulo: computes a % b without the `%` operator."""
    out = a - a // b * b
    return out
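# Quick sanity check of mod() against Python's `%` (illustrative only; assumes the
# non-negative integer tensors that the top-k indices below actually are):
#   >>> idx = torch.tensor([0, 5, 79, 80, 161])
#   >>> torch.equal(mod(idx, 80), idx % 80)
#   True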
class PostProcessor(nn.Module):
__share__ = [
'num_classes',
'use_focal_loss',
'num_top_queries',
'remap_mscoco_category'
]
def __init__(
self,
num_classes=80,
use_focal_loss=True,
num_top_queries=300,
remap_mscoco_category=False
) -> None:
super().__init__()
self.use_focal_loss = use_focal_loss
self.num_top_queries = num_top_queries
self.num_classes = int(num_classes)
self.remap_mscoco_category = remap_mscoco_category
self.deploy_mode = False
def extra_repr(self) -> str:
return f'use_focal_loss={self.use_focal_loss}, num_classes={self.num_classes}, num_top_queries={self.num_top_queries}'
# def forward(self, outputs, orig_target_sizes):
def forward(self, outputs, orig_target_sizes: torch.Tensor):
logits, boxes = outputs['pred_logits'], outputs['pred_boxes']
# orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
bbox_pred = torchvision.ops.box_convert(boxes, in_fmt='cxcywh', out_fmt='xyxy')
bbox_pred *= orig_target_sizes.repeat(1, 2).unsqueeze(1)
if self.use_focal_loss:
            scores = torch.sigmoid(logits)
scores, index = torch.topk(scores.flatten(1), self.num_top_queries, dim=-1)
# labels = index % self.num_classes
labels = mod(index, self.num_classes)
index = index // self.num_classes
boxes = bbox_pred.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, bbox_pred.shape[-1]))
        else:
            scores = F.softmax(logits, dim=-1)[:, :, :-1]
            scores, labels = scores.max(dim=-1)
            # Use the scaled xyxy boxes so both branches return the same format
            boxes = bbox_pred
            if scores.shape[1] > self.num_top_queries:
                scores, index = torch.topk(scores, self.num_top_queries, dim=-1)
                labels = torch.gather(labels, dim=1, index=index)
                boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]))
if self.deploy_mode:
return labels, boxes, scores
if self.remap_mscoco_category:
from ..data.dataset import mscoco_label2category
labels = torch.tensor([mscoco_label2category[int(x.item())] for x in labels.flatten()])\
.to(boxes.device).reshape(labels.shape)
results = []
for lab, box, sco in zip(labels, boxes, scores):
result = dict(labels=lab, boxes=box, scores=sco)
results.append(result)
return results
    def deploy(self):
self.eval()
self.deploy_mode = True
return self
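# Minimal sketch of the deploy-mode contract (the dummy shapes below are
# assumptions for illustration, not the model's actual output sizes): raw head
# outputs go in, top-k (labels, boxes, scores) come out, with boxes scaled to
# absolute xyxy pixels.
#   >>> pp = PostProcessor(num_classes=80, num_top_queries=300).deploy()
#   >>> out = {'pred_logits': torch.rand(1, 300, 80), 'pred_boxes': torch.rand(1, 300, 4)}
#   >>> labels, boxes, scores = pp(out, torch.tensor([[640, 640]]))
#   >>> labels.shape, boxes.shape, scores.shape
#   (torch.Size([1, 300]), torch.Size([1, 300, 4]), torch.Size([1, 300]))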
def resize_with_aspect_ratio(image, size, interpolation=Image.BILINEAR):
    """Resize an image to fit in a size x size square, preserving aspect ratio.

    Returns the padded square image, the scale ratio, and the (left, top) padding offsets.
    """
original_width, original_height = image.size
ratio = min(size / original_width, size / original_height)
new_width = int(original_width * ratio)
new_height = int(original_height * ratio)
image = image.resize((new_width, new_height), interpolation)
# Create a new image with the desired size and paste the resized image onto it
new_image = Image.new("RGB", (size, size))
new_image.paste(image, ((size - new_width) // 2, (size - new_height) // 2))
return new_image, ratio, (size - new_width) // 2, (size - new_height) // 2
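# Example of the letterbox transform (hypothetical 1280x720 input, 640 target):
# the scale ratio is min(640/1280, 640/720) = 0.5, giving a 640x360 image pasted
# with 0 px of horizontal and 140 px of vertical padding.
#   >>> im = Image.new('RGB', (1280, 720))
#   >>> padded, ratio, pad_w, pad_h = resize_with_aspect_ratio(im, 640)
#   >>> padded.size, ratio, pad_w, pad_h
#   ((640, 640), 0.5, 0, 140)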
def draw(images, labels, boxes, scores, ratios, paddings, thrh=0.4):
    """Draw detections above the score threshold, mapping boxes back to original image coordinates."""
result_images = []
for i, im in enumerate(images):
draw = ImageDraw.Draw(im)
scr = scores[i]
lab = labels[i][scr > thrh]
box = boxes[i][scr > thrh]
scr = scr[scr > thrh]
ratio = ratios[i]
pad_w, pad_h = paddings[i]
for lbl, bb in zip(lab, box):
# Adjust bounding boxes according to the resizing and padding
bb = [
(bb[0] - pad_w) / ratio,
(bb[1] - pad_h) / ratio,
(bb[2] - pad_w) / ratio,
(bb[3] - pad_h) / ratio,
]
draw.rectangle(bb, outline='red')
draw.text((bb[0], bb[1]), text=str(lbl), fill='blue')
result_images.append(im)
return result_images
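# The unpadding above inverts resize_with_aspect_ratio: with the hypothetical
# 1280x720 example (ratio=0.5, pad_w=0, pad_h=140), a detection at
# (0, 140, 640, 500) in padded coordinates maps back to (0, 0, 1280, 720) in the
# original image.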
def process_image(sess, im_pil, size=640, model_size='s'):
post_processor = PostProcessor().deploy()
# Resize image while preserving aspect ratio
resized_im_pil, ratio, pad_w, pad_h = resize_with_aspect_ratio(im_pil, size)
    # (h, w) of the padded input; it is square, so the (h, w)/(w, h) ordering is interchangeable here
    orig_size = torch.tensor([[resized_im_pil.size[1], resized_im_pil.size[0]]])
    # The smallest variants ('atto'..'n') run without ImageNet normalization
    transforms = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        if model_size not in ['atto', 'femto', 'pico', 'n']
        else T.Lambda(lambda x: x)
    ])
im_data = transforms(resized_im_pil).unsqueeze(0)
output = sess.run(
output_names=None,
input_feed={'images': im_data.numpy()}
)
output = {"pred_logits": torch.from_numpy(output[0]), "pred_boxes": torch.from_numpy(output[1])}
output = post_processor(output, orig_size)
labels, boxes, scores = output[0].numpy(), output[1].numpy(), output[2].numpy()
result_images = draw(
[im_pil], labels, boxes, scores,
[ratio], [(pad_w, pad_h)]
)
result_images[0].save('result.jpg')
print("Image processing complete. Result saved as 'result.jpg'.")
def process_video(sess, video_path, size=640, model_size='s'):
    post_processor = PostProcessor().deploy()
    cap = cv2.VideoCapture(video_path)
# Get video properties
fps = cap.get(cv2.CAP_PROP_FPS)
orig_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
orig_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
# Define the codec and create VideoWriter object
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter('onnx_result.mp4', fourcc, fps, (orig_w, orig_h))
    # Build the preprocessing pipeline once; the smallest variants ('atto'..'n') skip ImageNet normalization
    transforms = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        if model_size not in ['atto', 'femto', 'pico', 'n']
        else T.Lambda(lambda x: x)
    ])
    frame_count = 0
    print("Processing video frames...")
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
# Convert frame to PIL image
frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
# Resize frame while preserving aspect ratio
resized_frame_pil, ratio, pad_w, pad_h = resize_with_aspect_ratio(frame_pil, size)
        orig_size = torch.tensor([[resized_frame_pil.size[1], resized_frame_pil.size[0]]])
        im_data = transforms(resized_frame_pil).unsqueeze(0)
        # The compiled axmodel takes only the image tensor; post-processing runs
        # on the host, matching the image path above
        output = sess.run(
            output_names=None,
            input_feed={'images': im_data.numpy()}
        )
        output = {"pred_logits": torch.from_numpy(output[0]), "pred_boxes": torch.from_numpy(output[1])}
        labels, boxes, scores = post_processor(output, orig_size)
        labels, boxes, scores = labels.numpy(), boxes.numpy(), scores.numpy()
# Draw detections on the original frame
result_images = draw(
[frame_pil], labels, boxes, scores,
[ratio], [(pad_w, pad_h)]
)
frame_with_detections = result_images[0]
# Convert back to OpenCV image
frame = cv2.cvtColor(np.array(frame_with_detections), cv2.COLOR_RGB2BGR)
# Write the frame
out.write(frame)
frame_count += 1
if frame_count % 10 == 0:
print(f"Processed {frame_count} frames...")
cap.release()
out.release()
print("Video processing complete. Result saved as 'result.mp4'.")
def main(args):
"""Main function."""
    # Load the compiled axmodel with axengine
    sess = ort.InferenceSession(args.axmodel)
    size = sess.get_inputs()[0].shape[2]  # NCHW input: shape[2] is the (square) spatial size
input_path = args.input
try:
# Try to open the input as an image
im_pil = Image.open(input_path).convert('RGB')
process_image(sess, im_pil, size, args.model_size)
except IOError:
# Not an image, process as video
process_video(sess, input_path, size, args.model_size)
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--axmodel', type=str, default="compiled.axmodel", help='Path to the axmodel model file.')
parser.add_argument('--input', type=str, required=True, help='Path to the input image or video file.')
parser.add_argument('-ms', '--model-size', type=str, required=True, choices=['atto', 'femto', 'pico', 'n', 's', 'm', 'l', 'x'],
help='Model size')
args = parser.parse_args()
main(args)
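# Example invocations (paths are placeholders):
#   python axmodel_inf.py --axmodel compiled.axmodel --input demo.jpg -ms s
#   python axmodel_inf.py --axmodel compiled.axmodel --input demo.mp4 -ms s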