|
|
import sys |
|
|
import argparse |
|
|
import glob |
|
|
import numpy as np |
|
|
import torch |
|
|
from tqdm import tqdm |
|
|
from pathlib import Path |
|
|
from PIL import Image |
|
|
from matplotlib import pyplot as plt |
|
|
import os |
|
|
import onnxruntime as ort |
|
|
import axengine as axe |
|
|
|
|
|
|
|
|
|
|
|
def load_image(imfile, size=(512, 384)):
    """Load an image file as a (1, 3, H, W) float32 tensor with values in [0, 255].

    Args:
        imfile: Path of the image file to read.
        size: Target ``(width, height)`` passed to ``PIL.Image.resize``.
            Defaults to 512x384 (presumably the input resolution the models
            in this script were exported for — confirm).

    Returns:
        A ``torch.float32`` tensor of shape ``(1, 3, size[1], size[0])``.
    """
    # Use a context manager so the underlying file handle is released
    # promptly instead of waiting for garbage collection.
    with Image.open(imfile) as pil_img:
        # convert("RGB") normalises grayscale ("L"), palette ("P") and RGBA
        # inputs to exactly three channels. A bare [..., :3] slice only works
        # for H x W x C arrays; for 2-D (L/P) arrays it would cut the width.
        rgb = pil_img.convert("RGB").resize(size)
    # np.array (not asarray) yields a writable copy, which torch.from_numpy
    # requires to avoid a non-writable-array warning.
    img = np.array(rgb, dtype=np.uint8)
    # HWC uint8 -> CHW float32, then add a leading batch dimension.
    return torch.from_numpy(img).permute(2, 0, 1).float()[None]
|
|
|
|
|
|
|
|
def visualize_disparity(disparity_map, title):
    """Save a 'jet'-colour-mapped plot of a disparity map to ``f"{title}-rt.png"``.

    Args:
        disparity_map: Array accepted by ``plt.imshow``; expected to be a 2-D
            disparity map (assumed (H, W) — confirm against the model output).
        title: Used both as the plot title and as the output file-name stem.
            The file is written relative to the current working directory.
    """
    fig = plt.figure(figsize=(10, 6))
    try:
        plt.imshow(disparity_map, cmap='jet')
        plt.colorbar(label="Disparity")
        plt.title(title)
        plt.axis('off')
        plt.savefig(f"{title}-rt.png")
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        # Without this, repeated calls (two per image pair in demo) leak
        # memory and trigger matplotlib's "too many open figures" warning.
        plt.close(fig)
|
|
|
|
|
|
|
|
def _to_ax_input(img):
    """Return *img* as an NHWC uint8 array, the layout fed to the AXModel session.

    Assumes *img* is a (1, 3, H, W) tensor holding 0-255 values, as produced
    by ``load_image``.
    """
    return img.cpu().numpy().transpose(0, 2, 3, 1).astype(np.uint8)


def _to_onnx_input(img):
    """Return *img* as an NCHW float array rescaled from [0, 255] to [-1, 1].

    The ONNX model receives this normalised tensor, whereas the AXModel gets
    raw uint8 pixels — presumably because normalisation is built into the
    compiled AXModel (TODO confirm against the model export).
    """
    return 2 * (img.cpu().numpy() / 255.0) - 1.0


def demo(args):
    """Run the ONNX and AXModel disparity models on matched stereo image pairs.

    Left and right images are matched by sorted glob order. For each pair,
    both backends' disparity outputs are printed and saved as colour-mapped
    PNGs via ``visualize_disparity``.

    Args:
        args: Namespace with ``left_imgs`` and ``right_imgs`` glob patterns
            (``**`` is supported). Optional ``onnx_model`` / ``ax_model``
            attributes override the default model paths.

    Raises:
        FileNotFoundError: If either glob matches no files.
        ValueError: If the globs match different numbers of files, which
            would make sorted-order pairing meaningless.
    """
    left_images = sorted(glob.glob(args.left_imgs, recursive=True))
    right_images = sorted(glob.glob(args.right_imgs, recursive=True))

    if not left_images or not right_images:
        raise FileNotFoundError(
            f"no images matched: left={args.left_imgs!r} ({len(left_images)}), "
            f"right={args.right_imgs!r} ({len(right_images)})"
        )
    if len(left_images) != len(right_images):
        raise ValueError(
            f"left/right image counts differ ({len(left_images)} vs "
            f"{len(right_images)}); pairs are matched by sorted order"
        )

    # Load each model once. Creating the sessions inside the loop would
    # reload the models for every image pair.
    ort_session = ort.InferenceSession(
        getattr(args, "onnx_model", "models/rt_sceneflow.onnx"))
    ax_session = axe.InferenceSession(
        getattr(args, "ax_model", "models/rt_sceneflow.axmodel"))

    multiple_pairs = len(left_images) > 1
    pairs = list(zip(left_images, right_images))
    for idx, (imfile1, imfile2) in enumerate(tqdm(pairs)):
        image1 = load_image(imfile1)
        image2 = load_image(imfile2)

        onnx_inputs = {"left": _to_onnx_input(image1), "right": _to_onnx_input(image2)}
        ax_inputs = {"left": _to_ax_input(image1), "right": _to_ax_input(image2)}

        disp_onnx = ort_session.run(None, onnx_inputs)[0].squeeze()
        disp_ax = ax_session.run(None, ax_inputs)[0].squeeze()

        print("disp_onnx", disp_onnx)
        print("disp_ax", disp_ax)

        # The saved PNG is named after the title. With several pairs, tag each
        # title with its index so later pairs don't overwrite earlier plots.
        # The index is used rather than the file stem because recursive globs
        # often match identically-named files, e.g. */im0.png.
        suffix = f" ({idx})" if multiple_pairs else ""
        visualize_disparity(disp_onnx, title=f"ONNX Disparity Map{suffix}")
        visualize_disparity(disp_ax, title=f"AXModel Disparity Map{suffix}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point. Each option is a glob pattern for one stereo
    # view; demo() sorts both match lists and pairs them in order.
    cli = argparse.ArgumentParser()
    for short_flag, long_flag, help_text, default_path in (
        ('-l', '--left_imgs', "path to all first (left) frames", "demo-imgs/im0.png"),
        ('-r', '--right_imgs', "path to all second (right) frames", "demo-imgs/im1.png"),
    ):
        cli.add_argument(short_flag, long_flag, help=help_text, default=default_path)

    demo(cli.parse_args())
|
|
|