# IGEV-plusplus / infer_ax650.py
# lihongjie
# add ax637 models
# 0b80cd8
import sys
import argparse
import glob
import numpy as np
import torch
from tqdm import tqdm
from pathlib import Path
from PIL import Image
from matplotlib import pyplot as plt
import os
import onnxruntime as ort
import axengine as axe
def load_image(imfile, size=(512, 384)):
    """Load an image file as a ``(1, 3, H, W)`` float32 tensor of raw pixel values.

    The image is resized to ``size`` and then converted to three RGB channels,
    so grayscale or palette images are expanded to RGB instead of failing on
    the channel permute, and an alpha channel is dropped.

    Args:
        imfile: Path or file object accepted by ``PIL.Image.open``.
        size: Target ``(width, height)`` handed to ``Image.resize``. Defaults
            to ``(512, 384)``, the value that used to be hard-coded.

    Returns:
        A float tensor of shape ``(1, 3, size[1], size[0])`` holding the 8-bit
        pixel intensities (0-255); no normalisation is applied here.
    """
    # The context manager releases the underlying file handle; the resized
    # image is a separate, fully loaded object, so it stays usable afterwards.
    with Image.open(imfile) as src:
        resized = src.resize(size)

    # Converting after the resize keeps RGB/RGBA inputs pixel-identical to the
    # previous ``[..., :3]`` slice while making 2-D (grayscale/palette) inputs
    # produce a proper H x W x 3 array.
    rgb = resized.convert("RGB")

    # np.array (not asarray) yields a writable copy, which torch.from_numpy
    # expects.
    pixels = np.array(rgb, dtype=np.uint8)

    # HWC -> CHW, cast to float32, then add the leading batch dimension.
    chw = torch.from_numpy(pixels).permute(2, 0, 1).float()
    return chw[None]
def visualize_disparity(disparity_map, title):
    """Render a disparity map with the ``jet`` colormap and save it to disk.

    The figure is written to ``"{title}-rt.png"`` relative to the current
    working directory and is closed afterwards, so repeated calls do not
    accumulate open figures.

    Args:
        disparity_map: 2-D array-like accepted by ``plt.imshow``.
        title: Text shown above the plot; also used as the output file prefix.
    """
    fig = plt.figure(figsize=(10, 6))
    try:
        plt.imshow(disparity_map, cmap='jet')
        plt.colorbar(label="Disparity")
        plt.title(title)
        plt.axis('off')
        fig.savefig(f"{title}-rt.png")
    finally:
        # Always release the figure, even if rendering or saving fails;
        # pyplot otherwise keeps every figure alive until the process exits.
        plt.close(fig)
def demo(args):
    """Compare ONNX Runtime and AXModel disparity outputs on stereo image pairs.

    Left and right images are collected from the two glob patterns, sorted,
    and paired in order. Each pair is run through both models; the resulting
    disparity arrays are printed and saved via ``visualize_disparity``.

    With a single pair the output files are ``ONNX Disparity Map-rt.png`` and
    ``AXModel Disparity Map-rt.png``. With several pairs an index/stem suffix
    is appended to the title so later pairs do not overwrite earlier ones.

    Args:
        args: Namespace providing ``left_imgs`` and ``right_imgs`` glob
            patterns.
    """
    left_images = sorted(glob.glob(args.left_imgs, recursive=True))
    right_images = sorted(glob.glob(args.right_imgs, recursive=True))

    # zip() silently truncates to the shorter list, so make a mismatch visible.
    if len(left_images) != len(right_images):
        print(
            f"warning: {len(left_images)} left image(s) vs "
            f"{len(right_images)} right image(s); unmatched extras are ignored",
            file=sys.stderr,
        )

    pairs = list(zip(left_images, right_images))
    if not pairs:
        print(
            f"no stereo pairs matched {args.left_imgs!r} / {args.right_imgs!r}",
            file=sys.stderr,
        )
        return

    # The sessions do not depend on the image pair: load each model once
    # instead of once per iteration.
    ort_session = ort.InferenceSession("models/rt_sceneflow.onnx")
    ax_session = axe.InferenceSession("models/rt_sceneflow.axmodel")

    multiple_pairs = len(pairs) > 1

    for index, (imfile1, imfile2) in enumerate(tqdm(pairs)):
        image1 = load_image(imfile1)
        image2 = load_image(imfile2)

        input_l_np = image1.cpu().numpy()
        input_r_np = image2.cpu().numpy()

        # The AXModel is fed the raw pixel values as uint8 in NHWC layout.
        ax_inputs = {
            "left": input_l_np.transpose(0, 2, 3, 1).astype(np.uint8),
            "right": input_r_np.transpose(0, 2, 3, 1).astype(np.uint8),
        }

        # The ONNX model is fed the same pixels rescaled from [0, 255] to
        # [-1, 1], kept in NCHW layout.
        input_l_np = (2 * (input_l_np / 255.0) - 1.0)
        input_r_np = (2 * (input_r_np / 255.0) - 1.0)
        onnx_inputs = {"left": input_l_np, "right": input_r_np}

        # ONNX inference
        onnx_outputs = ort_session.run(None, onnx_inputs)
        disp_onnx = onnx_outputs[0].squeeze()

        # AXModel inference
        ax_outputs = ax_session.run(None, ax_inputs)
        disp_ax = ax_outputs[0].squeeze()

        print("disp_onnx",disp_onnx)
        print("disp_ax",disp_ax)

        # Only tag the title when several pairs are processed, so the
        # single-pair output filenames stay exactly as before.
        suffix = f" {index:04d}-{Path(imfile1).stem}" if multiple_pairs else ""
        visualize_disparity(disp_onnx, title=f"ONNX Disparity Map{suffix}")
        visualize_disparity(disp_ax, title=f"AXModel Disparity Map{suffix}")
if __name__ == '__main__':
    # Command-line entry point: both options take a glob pattern and default
    # to the bundled demo image pair.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        '-l', '--left_imgs',
        default="demo-imgs/im0.png",
        help="path to all first (left) frames",
    )
    arg_parser.add_argument(
        '-r', '--right_imgs',
        default="demo-imgs/im1.png",
        help="path to all second (right) frames",
    )
    demo(arg_parser.parse_args())