| | import argparse |
| | import os |
| | import re |
| | import sys |
| | import time |
| | import cv2 |
| | import math |
| | import glob |
| | import numpy as np |
| |
|
| | import axengine as axe |
| | from axengine import axclrt_provider_name, axengine_provider_name |
| |
|
def load_model(model_path: str | os.PathLike, selected_provider: str, selected_device_id: int = 0):
    """Create an axengine inference session for ``model_path``.

    Args:
        model_path: Path of the compiled model file.
        selected_provider: ``'AUTO'`` to let axengine choose the provider, or
            one of the provider names exported by axengine.
        selected_device_id: Device index forwarded as the ``device_id``
            option when the axclrt provider is selected.

    Returns:
        The ``axe.InferenceSession`` built for the requested provider.
    """
    if selected_provider != 'AUTO':
        requested = []
        if selected_provider == axclrt_provider_name:
            # The axclrt provider is passed as a (name, options) pair.
            requested.append((axclrt_provider_name, {"device_id": selected_device_id}))
        if selected_provider == axengine_provider_name:
            requested.append(axengine_provider_name)
        return axe.InferenceSession(model_path, providers=requested)

    # 'AUTO': no explicit provider list is passed.
    return axe.InferenceSession(model_path)
| |
|
| |
|
def get_frames(video_name):
    """Yield frames from a camera stream, a video file or an image folder.

    Args:
        video_name: Frame source selector.
            * Falsy value: frames are read from the hard-coded RTSP camera
              URL and resized to 800x600.
            * Path ending in ``avi`` or ``mp4``: frames are read from that
              video file.
            * Anything else: treated as a directory; images matching
              ``<video_name>/img/*.jp*`` are read in sorted order.

    Yields:
        The frames as returned by OpenCV (``cap.read()`` / ``cv2.imread``).
    """
    if not video_name:
        # NOTE(review): the camera address and credentials are hard-coded
        # here; consider moving them to configuration.
        rtsp = "rtsp://%s:%s@%s:554/cam/realmonitor?channel=1&subtype=1" % ("admin", "123456", "192.168.1.108")
        cap = cv2.VideoCapture(rtsp)
        try:
            # Read and discard five frames before yielding anything
            # (presumably to let the stream settle -- confirm).
            for _ in range(5):
                cap.read()
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                yield cv2.resize(frame, (800, 600))
        finally:
            # Runs when the stream ends or the generator is closed early.
            cap.release()
    elif video_name.endswith(('avi', 'mp4')):
        cap = cv2.VideoCapture(video_name)
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                yield frame
        finally:
            cap.release()
    else:
        # ``glob`` is imported as a module, so the function is ``glob.glob``.
        images = sorted(glob.glob(os.path.join(video_name, 'img', '*.jp*')))
        for img in images:
            frame = cv2.imread(img)
            if frame is None:
                # cv2.imread returns None for files it cannot decode.
                continue
            yield frame
| |
|
| |
|
class Preprocessor_wo_mask(object):
    """Turn an HxWx3 image array into a normalized 1x3xHxW float32 array."""

    def __init__(self):
        # Per-channel statistics, shaped (1, 3, 1, 1) so they broadcast over
        # the batch and spatial axes.
        channel_mean = np.array([0.485, 0.456, 0.406])
        channel_std = np.array([0.229, 0.224, 0.225])
        stat_shape = (1, 3, 1, 1)
        self.mean = channel_mean.reshape(stat_shape).astype(np.float32)
        self.std = channel_std.reshape(stat_shape).astype(np.float32)

    def process(self, img_arr: np.ndarray):
        """Scale ``img_arr`` by 1/255 and normalize it channel-wise.

        Args:
            img_arr: Image array of shape (H, W, 3).

        Returns:
            float32 array of shape (1, 3, H, W) equal to
            ``(img / 255 - mean) / std``.
        """
        height, width = img_arr.shape[0], img_arr.shape[1]
        channels_first = np.transpose(img_arr, (2, 0, 1))
        batched = channels_first.reshape((1, 3, height, width)).astype(np.float32)
        scaled = batched / 255.0
        return (scaled - self.mean) / self.std
| |
|
| |
|
class MFTrackerORT:
    """Single-object tracker that runs a template/search model via axengine.

    The tracker keeps a fixed template taken from the first frame plus an
    "online" template that is periodically replaced by the best-scoring crop
    seen since the last update.
    """

    def __init__(self, model_path, fp16=False) -> None:
        """Load the model and set the tracking hyper-parameters.

        Args:
            model_path: Path of the model file handed to ``load_model``.
            fp16: Stored on the instance; not read anywhere in this class.
        """
        # When True, ``track`` draws the predicted box onto the input image.
        self.debug = True
        # gpu_id / providers / provider_options are stored but not read by
        # any method of this class.
        self.gpu_id = 0
        self.providers = ["CUDAExecutionProvider"]
        self.provider_options = [{"device_id": str(self.gpu_id)}]
        self.model_path = model_path
        self.fp16 = fp16

        self.init_track_net()
        self.preprocessor = Preprocessor_wo_mask()
        self.max_score_decay = 1.0
        # Crop side = sqrt(w * h) * factor, resized to the matching size.
        self.search_factor = 4.5
        self.search_size = 224
        self.template_factor = 2.0
        self.template_size = 112
        # The online template is refreshed every `update_interval` frames.
        self.update_interval = 200
        self.online_size = 1

    def init_track_net(self):
        """Create the inference session for ``self.model_path`` (provider 'AUTO')."""
        self.ax_session = load_model(self.model_path, selected_provider="AUTO")

    def track_init(self, frame, target_pos=None, target_sz=None):
        """Initialise the tracker state from the first frame.

        Args:
            frame: First image of the sequence.
            target_pos: ``[x, y]`` of the initial target box.
            target_sz: ``[w, h]`` of the initial target box.

        Raises:
            SystemExit: With exit status 1 when initialisation fails for any
                reason (for example a missing position/size or a box that is
                too small to crop).
        """
        self.trace_list = []
        try:
            init_state = [target_pos[0], target_pos[1], target_sz[0], target_sz[1]]
            z_patch_arr, _, _ = self.sample_target(
                frame, init_state, self.template_factor, output_sz=self.template_size)
            template = self.preprocessor.process(z_patch_arr)
            self.template = template
            self.online_template = template

            self.online_state = init_state
            self.online_image = frame
            self.max_pred_score = -1.0
            self.online_max_template = template
            self.online_forget_id = 0

            self.state = init_state
            self.frame_id = 0
            print("第一帧初始化完毕!")
        except Exception as exc:
            # Only ordinary errors are handled (KeyboardInterrupt is left
            # alone); the cause is reported and the exit status is non-zero.
            print("第一帧初始化异常!")
            print(f"track_init failed: {exc!r}")
            raise SystemExit(1) from exc

    def track(self, image, info: dict = None):
        """Locate the target in ``image`` and update the tracker state.

        Args:
            image: Current frame (H x W x C array). When ``self.debug`` is
                true the predicted box is drawn onto it in place.
            info: Unused.

        Returns:
            dict with ``"target_bbox"`` (the new ``[x, y, w, h]`` state) and
            ``"conf_score"`` (second model output, returned as produced).
        """
        H, W, _ = image.shape
        self.frame_id += 1
        x_patch_arr, resize_factor, _ = self.sample_target(
            image, self.state, self.search_factor, output_sz=self.search_size)
        search = self.preprocessor.process(x_patch_arr)

        ort_inputs = {'img_t': self.template, 'img_ot': self.online_template, 'img_search': search}
        ort_outs = self.ax_session.run(None, ort_inputs)

        pred_boxes = ort_outs[0]
        # NOTE(review): pred_score is used in scalar comparisons below, so it
        # is assumed to hold a single value -- confirm against the model.
        pred_score = ort_outs[1]

        # Average the predictions over axis 0 and rescale the result from
        # normalized crop coordinates to image pixels.
        pred_box = (np.mean(pred_boxes, axis=0) * self.search_size / resize_factor).tolist()
        self.state = self.clip_box(self.map_box_back(pred_box, resize_factor), H, W, margin=10)

        self.max_pred_score = self.max_pred_score * self.max_score_decay

        # Remember the best confident crop seen since the last update.
        if pred_score > 0.5 and pred_score > self.max_pred_score:
            z_patch_arr, _, _ = self.sample_target(
                image, self.state, self.template_factor, output_sz=self.template_size)
            self.online_max_template = self.preprocessor.process(z_patch_arr)
            self.max_pred_score = pred_score

        # Periodically replace the online template and reset the search for
        # the next candidate.
        if self.frame_id % self.update_interval == 0:
            if self.online_size == 1:
                self.online_template = self.online_max_template
            else:
                # NOTE(review): this writes into self.online_template in
                # place; track_init binds it to the same array as
                # self.template -- verify before using online_size > 1.
                self.online_template[self.online_forget_id:self.online_forget_id + 1] = self.online_max_template
                self.online_forget_id = (self.online_forget_id + 1) % self.online_size

            self.max_pred_score = -1
            self.online_max_template = self.template

        if self.debug:
            x1, y1, w, h = self.state
            cv2.rectangle(image, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)),
                          color=(0, 0, 255), thickness=2)

        return {"target_bbox": self.state, "conf_score": pred_score}

    def map_box_back(self, pred_box: list, resize_factor: float):
        """Map a ``[cx, cy, w, h]`` box from search-crop to image coordinates.

        The crop is centred on the centre of the previous state, so the
        crop's top-left corner is that centre minus half the crop side.

        Returns:
            ``[x, y, w, h]`` with ``x, y`` the top-left corner.
        """
        cx_prev = self.state[0] + 0.5 * self.state[2]
        cy_prev = self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box
        half_side = 0.5 * self.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return [cx_real - 0.5 * w, cy_real - 0.5 * h, w, h]

    def map_box_back_batch(self, pred_box: np.ndarray, resize_factor: float):
        """Vectorised ``map_box_back`` for an array of boxes.

        Args:
            pred_box: Array whose last axis holds ``cx, cy, w, h``.
            resize_factor: Scale between crop pixels and image pixels.

        Returns:
            Array of ``[x, y, w, h]`` boxes stacked along the last axis.
        """
        cx_prev = self.state[0] + 0.5 * self.state[2]
        cy_prev = self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box.T
        half_side = 0.5 * self.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return np.stack([cx_real - 0.5 * w, cy_real - 0.5 * h, w, h], axis=-1)

    def sample_target(self, im, target_bb, search_area_factor, output_sz=None, mask=None):
        """Extract a square crop centred on ``target_bb``.

        The crop side is ``ceil(sqrt(w * h) * search_area_factor)``. Parts of
        the crop that fall outside the image are filled with a constant
        border.

        Args:
            im: cv image (H x W x C array).
            target_bb: Target box ``[x, y, w, h]`` (list, or an object with
                ``tolist()``).
            search_area_factor: Ratio of crop side to target size.
            output_sz: Side the crop is resized to (always square). If None,
                no resizing is done and the returned factor is 1.0.
            mask: Optional 2-D mask cropped and padded like the image.

        Returns:
            ``(crop, resize_factor, att_mask)`` or, when ``mask`` is given,
            ``(crop, resize_factor, att_mask, mask_crop)``. ``att_mask`` is a
            boolean array that is True on padded pixels.

        Raises:
            ValueError: If the computed crop side is smaller than 1.
        """
        if not isinstance(target_bb, list):
            x, y, w, h = target_bb.tolist()
        else:
            x, y, w, h = target_bb

        crop_sz = math.ceil(math.sqrt(w * h) * search_area_factor)
        if crop_sz < 1:
            raise ValueError('Too small bounding box.')

        x1 = int(round(x + 0.5 * w - crop_sz * 0.5))
        x2 = int(x1 + crop_sz)
        y1 = int(round(y + 0.5 * h - crop_sz * 0.5))
        y2 = int(y1 + crop_sz)

        # Amount by which the crop window sticks out of the image on each side.
        x1_pad = int(max(0, -x1))
        x2_pad = int(max(x2 - im.shape[1] + 1, 0))
        y1_pad = int(max(0, -y1))
        y2_pad = int(max(y2 - im.shape[0] + 1, 0))

        # Crop the part of the window that lies inside the image ...
        im_crop = im[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad, :]
        if mask is not None:
            mask_crop = mask[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad]

        # ... then pad it back to the full window size.
        im_crop_padded = cv2.copyMakeBorder(im_crop, y1_pad, y2_pad, x1_pad, x2_pad, cv2.BORDER_CONSTANT)

        # Attention mask: 1 on padding, 0 on real image content.
        H, W, _ = im_crop_padded.shape
        att_mask = np.ones((H, W))
        end_y = -y2_pad if y2_pad != 0 else None
        end_x = -x2_pad if x2_pad != 0 else None
        att_mask[y1_pad:end_y, x1_pad:end_x] = 0
        if mask is not None:
            mask_crop_padded = cv2.copyMakeBorder(mask_crop, y1_pad, y2_pad, x1_pad, x2_pad, cv2.BORDER_CONSTANT)

        if output_sz is not None:
            resize_factor = output_sz / crop_sz
            im_crop_padded = cv2.resize(im_crop_padded, (output_sz, output_sz))
            att_mask = cv2.resize(att_mask, (output_sz, output_sz)).astype(np.bool_)
            if mask is None:
                return im_crop_padded, resize_factor, att_mask
            mask_crop_padded = cv2.resize(mask_crop_padded, (output_sz, output_sz))
            return im_crop_padded, resize_factor, att_mask, mask_crop_padded

        # No resizing: the factor is 1.0. The tuple order matches the resized
        # branch (crop, factor, att_mask), which is what callers unpack.
        if mask is None:
            return im_crop_padded, 1.0, att_mask.astype(np.bool_)
        return im_crop_padded, 1.0, att_mask.astype(np.bool_), mask_crop_padded

    def clip_box(self, box: list, H, W, margin=0):
        """Clamp an ``[x, y, w, h]`` box to an ``H`` x ``W`` image.

        The top-left corner is kept at least ``margin`` pixels away from the
        right/bottom edges and the returned width/height are never smaller
        than ``margin``.
        """
        x1, y1, w, h = box
        x2, y2 = x1 + w, y1 + h
        x1 = min(max(0, x1), W - margin)
        x2 = min(max(margin, x2), W)
        y1 = min(max(0, y1), H - margin)
        y2 = min(max(margin, y2), H)
        w = max(margin, x2 - x1)
        h = max(margin, y2 - y1)
        return [x1, y1, w, h]
| | |
| | |
def main(model_path, frame_path, repeat, selected_provider, selected_device_id,
         init_bbox=(1079, 482, 99, 106)):
    """Run the tracker over a frame source and report the tracking speed.

    Every processed frame is written to ``axmodel_output/<index>.png``.

    Args:
        model_path: Model file handed to ``MFTrackerORT``.
        frame_path: Frame source passed to ``get_frames``.
        repeat: Maximum number of frames to process; ``None`` for no limit.
        selected_provider: ``'AUTO'`` or an axengine provider name.
        selected_device_id: Device index used with the selected provider.
        init_bbox: ``(x, y, w, h)`` of the target in the first frame.
    """
    Tracker = MFTrackerORT(model_path=model_path, fp16=False)
    if selected_provider != 'AUTO':
        # MFTrackerORT always loads with 'AUTO'; honour the requested
        # provider/device by replacing its session. The old session is
        # dropped first so two sessions are not held at the same time.
        Tracker.ax_session = None
        Tracker.ax_session = load_model(model_path, selected_provider, selected_device_id)
    Tracker.video_name = frame_path

    # Create the output directory once instead of on every frame.
    os.makedirs('axmodel_output', exist_ok=True)

    first_frame = True
    frame_id = 0
    tracked_frames = 0
    tracking_seconds = 0.0
    tick_frequency = cv2.getTickFrequency()
    for frame in get_frames(Tracker.video_name):
        if repeat is not None and frame_id >= repeat:
            print(f"Reached the maximum number of frames ({repeat}). Exiting loop.")
            break

        is_init_frame = first_frame
        tic = cv2.getTickCount()
        if first_frame:
            x, y, w, h = init_bbox

            target_pos = [x, y]
            target_sz = [w, h]
            print('====================type=================', target_pos, type(target_pos), type(target_sz))
            Tracker.track_init(frame, target_pos, target_sz)
            first_frame = False
        else:
            Tracker.track(frame)
        frame_id += 1

        cv2.imwrite(f'axmodel_output/{str(frame_id)}.png', frame)

        # The measured interval covers tracking plus writing the image.
        elapsed = (cv2.getTickCount() - tic) / tick_frequency
        fps = 1.0 / elapsed if elapsed > 0 else 0.0
        if not is_init_frame:
            # The initialisation frame is not a tracking step, so it is left
            # out of the average.
            tracked_frames += 1
            tracking_seconds += elapsed
        print('Video: {:12s} {:3.1f}fps'.format('tracking', fps))

    # Average = tracked frames / time spent on them; 0.0 when no frame was
    # tracked (avoids dividing by zero on empty or single-frame inputs).
    average_fps = tracked_frames / tracking_seconds if tracking_seconds > 0 else 0.0
    print('video: average {:12s} {:3.1f} fps'.format('finale average tracking fps', average_fps))
| |
|
| | |
| |
|
class ExampleParser(argparse.ArgumentParser):
    """Argument parser that shows example invocations when parsing fails."""

    def error(self, message):
        """Print usage, the error text and example commands, then exit(1).

        Usage goes to stderr; the error message and the examples are printed
        with plain ``print``.
        """
        self.print_usage(sys.stderr)
        report_lines = (
            f"\nError: {message}",
            "\nExample usage:",
            " python3 run_mixformer2_axmodel.py -m <model_file> -f <frame_file>",
            " python3 run_mixformer2_axmodel.py -m compiled.axmodel -f car.avi",
            f" python3 run_mixformer2_axmodel.py -m compiled.axmodel -f car.avi -p {axengine_provider_name}",
            f" python3 run_mixformer2_axmodel.py -m compiled.axmodel -f car.avi -p {axclrt_provider_name}",
        )
        for text in report_lines:
            print(text)
        sys.exit(1)
| |
|
| |
|
if __name__ == "__main__":
    # Command-line entry point: parse the options, validate the paths and
    # hand everything to main().
    ap = ExampleParser()
    ap.add_argument('-m', '--model-path', type=str, help='model path', required=True)
    ap.add_argument('-f', '--frame-path', type=str, help='frame path', required=True)
    ap.add_argument('-r', '--repeat', type=int, help='repeat times', default=100)
    ap.add_argument(
        '-p',
        '--provider',
        type=str,
        choices=["AUTO", f"{axclrt_provider_name}", f"{axengine_provider_name}"],
        help=f'"AUTO", "{axclrt_provider_name}", "{axengine_provider_name}"',
        default='AUTO'
    )
    ap.add_argument(
        '-d',
        '--device-id',
        type=int,
        help=R'axclrt device index, depends on how many cards inserted',
        default=0
    )
    args = ap.parse_args()

    model_file = args.model_path
    frame_file = args.frame_path

    # Validate with explicit checks rather than `assert`, which is removed
    # when Python runs with -O. ap.error() prints the usage and exits.
    if not os.path.exists(model_file):
        ap.error(f"model file path {model_file} does not exist")
    if not os.path.exists(frame_file):
        ap.error(f"image file path {frame_file} does not exist")

    repeat = args.repeat

    provider = args.provider
    device_id = args.device_id

    main(model_file, frame_file, repeat, provider, device_id)
| |
|