import ctypes
import os
import platform
from typing import List, Optional, Sequence, Tuple

import numpy as np

from pyaxdev import _lib, AxDeviceType, AxDevices, check_error
| |
|
class ModelType(ctypes.c_int):
    """ctypes ``int`` used as the type of ``DetInit.model_type``.

    The names below are plain Python ints (not ``ModelType`` instances);
    pass one of them wherever a model type is expected.
    """

    ax_det_model_type_unknown = -1      # presumably "not recognised" - confirm with C header
    ax_det_model_type_yolov5 = 0
    ax_det_model_type_yolov8 = 1
    ax_det_model_type_yolov8_pose = 2   # pose variant (keypoints expected)
    ax_det_model_type_yolo11 = 3
    ax_det_model_type_yolo11_pose = 4   # pose variant (keypoints expected)
| |
|
class DetInit(ctypes.Structure):
    """Initialisation parameters passed by pointer to ``ax_det_init``.

    ctypes lays the struct out from ``_fields_`` in declaration order, so the
    order and types below must mirror the native definition exactly.
    """

    _fields_ = [
        ("dev_type", AxDeviceType),           # device backend type from pyaxdev
        ("devid", ctypes.c_char),             # device index held in a single byte
        ("model_type", ModelType),            # one of the ModelType constants
        ("model_path", ctypes.c_char * 256),  # UTF-8 path bytes, fixed 256-byte buffer
        ("num_classes", ctypes.c_int),
        ("num_kpt", ctypes.c_int),            # keypoint count (0 for box-only models - presumably)
        ("threshold", ctypes.c_float),        # presumably the score cut-off
        ("mean", ctypes.c_float * 3),         # three per-channel values
        ("std", ctypes.c_float * 3),          # three per-channel values
    ]
| |
|
class DetImage(ctypes.Structure):
    """Image descriptor handed to ``ax_det`` by pointer.

    ``AXDet.detect`` fills it from a ``(height, width, channels)`` array and
    sets ``stride`` to ``width * channels``; ``data`` points at the first
    pixel byte.
    """

    _fields_ = [
        ("width", ctypes.c_int),
        ("height", ctypes.c_int),
        ("channels", ctypes.c_int),
        ("stride", ctypes.c_int),
        ("data", ctypes.POINTER(ctypes.c_ubyte)),
    ]
| |
|
| |
|
class ObjectItem(ctypes.Structure):
    """One detection record inside ``ObjectResult``.

    ``kpts`` has room for 32 pairs of ints; ``AXDet.detect`` reads only the
    first ``num_kpt`` of them.
    """

    _fields_ = [
        ("box", ctypes.c_int * 4),
        ("kpts", (ctypes.c_int * 2) * 32),
        ("num_kpt", ctypes.c_int),
        ("score", ctypes.c_float),
        ("label", ctypes.c_int),
    ]
| | |
class ObjectResult(ctypes.Structure):
    """Output buffer that ``ax_det`` fills: up to 64 ``ObjectItem`` records.

    ``AXDet.detect`` reads only the first ``num_objs`` entries of ``objects``.
    """

    _fields_ = [
        ("objects", ObjectItem * 64),  # fixed-capacity result array
        ("num_objs", ctypes.c_int),    # number of entries to read
    ]
| |
|
# Declare the C prototypes of the detector entry points so ctypes converts
# arguments correctly; each one returns a C int status code.
for _name, _argtypes in (
    ("ax_det_init", [ctypes.POINTER(DetInit), ctypes.POINTER(ctypes.c_void_p)]),
    ("ax_det_deinit", [ctypes.c_void_p]),
    ("ax_det", [ctypes.c_void_p, ctypes.POINTER(DetImage), ctypes.POINTER(ObjectResult)]),
):
    _func = getattr(_lib, _name)
    _func.argtypes = _argtypes
    _func.restype = ctypes.c_int
# Do not leak the loop temporaries into the module namespace.
del _name, _argtypes, _func
| |
|
class Object:
    """A single detection as returned by ``AXDet.detect``.

    Attributes:
        box: Four integers describing the bounding box (coordinate
            convention set by the native library - TODO confirm x1,y1,x2,y2).
        score: Detection confidence.
        label: Integer class index.
        kpts: List of integer pairs (presumably ``(x, y)``); empty when the
            model reports no keypoints.
    """

    def __init__(self, box: List[int], score: float, label: int,
                 kpts: Optional[List[Tuple[int, int]]] = None):
        self.box = box
        self.score = score
        self.label = label
        # Create a fresh list per instance; a ``[]`` default would be one
        # list shared by every Object constructed without keypoints.
        self.kpts = [] if kpts is None else kpts

    def __repr__(self):
        return f"Object(box={self.box}, score={self.score:.2f}, label={self.label}, kpts={self.kpts})"
| | |
class AXDet:
    """Detector backed by a native ``ax_det`` handle.

    The handle is created in ``__init__`` and released by ``close()``, which
    also runs on leaving a ``with`` block and from ``__del__``::

        with AXDet(path, ModelType.ax_det_model_type_yolov8, 80) as det:
            objs = det.detect(img)
    """

    def __init__(self, model_path: str, model_type: ModelType, num_classes: int,
                 num_kpt: int = 0,
                 threshold: float = 0.25,
                 mean: Sequence[float] = (0.0, 0.0, 0.0),
                 std: Sequence[float] = (1.0, 1.0, 1.0),
                 dev_type: AxDeviceType = AxDeviceType.axcl_device,
                 devid: int = 0):
        """Fill a ``DetInit`` struct and create the native detector.

        Args:
            model_path: Model file path (``str`` or ``os.PathLike``). It is
                encoded as UTF-8 and must leave room for a NUL terminator in
                the fixed ``model_path`` buffer.
            model_type: One of the ``ModelType`` constants.
            num_classes: Forwarded as ``DetInit.num_classes``.
            num_kpt: Forwarded as ``DetInit.num_kpt``.
            threshold: Forwarded as ``DetInit.threshold``.
            mean, std: Per-channel values; each must have exactly as many
                elements as the corresponding ``DetInit`` array (3).
            dev_type: Forwarded as ``DetInit.dev_type``.
            devid: Stored in a one-byte ``c_char`` field, so it must be 0-255.

        Raises:
            ValueError: ``model_path`` is too long, or ``mean``/``std`` have
                the wrong length.
            Whatever ``check_error`` raises when ``ax_det_init`` fails.
        """
        # Set first so close()/__del__ are safe even if init fails below.
        self.handle = None
        self.init_info = DetInit()

        self.init_info.dev_type = dev_type
        self.init_info.devid = devid

        path = os.fspath(model_path)
        if isinstance(path, str):
            path = path.encode('utf-8')
        # ctypes accepts a value exactly as long as the buffer but then stores
        # no NUL terminator, and the native side would read past the end.
        capacity = ctypes.sizeof(dict(DetInit._fields_)["model_path"])
        if len(path) >= capacity:
            raise ValueError(
                f"model_path is {len(path)} bytes; at most {capacity - 1} are supported")

        self.init_info.model_type = model_type
        self.init_info.model_path = path
        self.init_info.num_classes = num_classes
        self.init_info.num_kpt = num_kpt
        self.init_info.threshold = threshold

        for field, values in (("mean", mean), ("std", std)):
            target = getattr(self.init_info, field)  # view into init_info's memory
            values = list(values)
            if len(values) != len(target):
                raise ValueError(
                    f"{field} must have {len(target)} elements, got {len(values)}")
            target[:] = values

        handle = ctypes.c_void_p()
        check_error(_lib.ax_det_init(ctypes.byref(self.init_info), ctypes.byref(handle)))
        self.handle = handle

    def close(self):
        """Release the native handle; later calls are no-ops.

        Teardown is best-effort: the status returned by ``ax_det_deinit`` is
        not checked, so this is safe to call from ``__del__``.
        """
        handle = getattr(self, "handle", None)
        # Clear before deinit so the handle can never be released twice.
        self.handle = None
        if handle:
            _lib.ax_det_deinit(handle)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()

    def __del__(self):
        self.close()

    def detect(self, image_data: np.ndarray) -> List[Object]:
        """Run detection on one image.

        Args:
            image_data: ``uint8`` array-like of shape ``(height, width,
                channels)``. Non-contiguous views (e.g. ``img[:, :, ::-1]``)
                are copied to a contiguous buffer first, because the native
                side receives only a base pointer and a packed row stride.

        Returns:
            One ``Object`` per detection reported by the native library.

        Raises:
            RuntimeError: the detector has been closed.
            ValueError: the image is not 3-dimensional.
            TypeError: the image dtype is not ``uint8``.
            Whatever ``check_error`` raises when ``ax_det`` fails.
        """
        if not self.handle:
            raise RuntimeError("AXDet has been closed")

        # `pixels` must stay alive until the native call returns, since
        # image.data points into its buffer.
        pixels = np.ascontiguousarray(image_data)
        if pixels.ndim != 3:
            raise ValueError(
                f"expected an image of shape (height, width, channels), got shape {pixels.shape}")
        if pixels.dtype != np.uint8:
            raise TypeError(f"expected a uint8 image, got dtype {pixels.dtype}")
        height, width, channels = pixels.shape

        image = DetImage()
        image.data = pixels.ctypes.data_as(ctypes.POINTER(ctypes.c_ubyte))
        image.width = width
        image.height = height
        image.channels = channels
        image.stride = width * channels
        result = ObjectResult()
        check_error(_lib.ax_det(self.handle, ctypes.byref(image), ctypes.byref(result)))

        # Clamp native counts to the fixed array capacities so a bad count
        # cannot index outside the ctypes arrays.
        num_objs = max(0, min(result.num_objs, len(result.objects)))
        objects = []
        for item in result.objects[:num_objs]:
            num_kpt = max(0, min(item.num_kpt, len(item.kpts)))
            objects.append(Object(
                box=list(item.box),
                score=item.score,
                label=item.label,
                kpts=[(kp[0], kp[1]) for kp in item.kpts[:num_kpt]],
            ))
        return objects
| |
|
| |
|