# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script exports multimodal model to TensorRT and do a local inference test.
For multimodal model, it supports the following models:
- NEVA
- Video-NEVA
- LITA
- VILA
- VITA
- SALM
"""
import argparse
import os

from nemo.export.tensorrt_mm_exporter import TensorRTMMExporter


def parse_args():
    parser = argparse.ArgumentParser(description='Export a multimodal model to TensorRT')
    parser.add_argument('--output_dir', required=True, help='Directory to save the exported model')
    parser.add_argument(
        '--visual_checkpoint_path',
        required=True,
        help='Path to the visual model checkpoint or perception model checkpoint',
    )
    parser.add_argument('--llm_checkpoint_path', required=True, help='Source .nemo file for the LLM')
    parser.add_argument(
        '--modality',
        default="vision",
        choices=["vision", "audio"],
        help="Modality of the model",
    )
    parser.add_argument(
        '--model_type',
        type=str,
        required=True,
        choices=["neva", "video-neva", "lita", "vila", "vita", "salm"],
        help="Type of the multimodal model to export.",
    )
    parser.add_argument(
        '--llm_model_type',
        type=str,
        required=True,
        choices=["gptnext", "gpt", "llama", "falcon", "starcoder", "mixtral", "gemma"],
        help="Type of the LLM. Only the listed choices are supported."
        " gptnext and gpt are identical; both are kept for backward compatibility.",
    )
    parser.add_argument('--tensor_parallel_size', type=int, default=1, help='Tensor parallelism size')
    parser.add_argument('--max_input_len', type=int, default=4096, help='Maximum input length')
    parser.add_argument('--max_output_len', type=int, default=256, help='Maximum output length')
    parser.add_argument('--max_batch_size', type=int, default=1, help='Maximum batch size')
    parser.add_argument(
        '--vision_max_batch_size',
        type=int,
        default=1,
        help='Maximum batch size of the visual inputs. For LITA/VITA models with video inference,'
        ' this should be set to 256.',
    )
    parser.add_argument('--max_multimodal_len', type=int, default=3072, help='Maximum multimodal length')
    parser.add_argument(
        "--dtype",
        choices=["bfloat16", "float16"],
        default="bfloat16",
        type=str,
        help="dtype of the model on TensorRT",
    )
    parser.add_argument(
        '--delete_existing_files', action='store_true', help='Delete existing files in the output directory'
    )
    parser.add_argument(
        '--test_export_only', action='store_true', help='Only test the export without saving the model'
    )
    parser.add_argument('--input_text', help='Input text for inference')
    parser.add_argument('--input_media', default=None, help='Input media file for inference')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size for inference')
    parser.add_argument('--max_output', type=int, default=128, help='Maximum output length for inference')
    parser.add_argument('--top_k', type=int, default=1, help='Top-k for sampling')
    parser.add_argument('--top_p', type=float, default=0.0, help='Top-p for sampling')
    parser.add_argument("--temperature", default=1.0, type=float, help="Sampling temperature")
    parser.add_argument("--repetition_penalty", default=1.0, type=float, help="Repetition penalty for sampling")
    parser.add_argument("--num_beams", default=1, type=int, help="Number of beams for beam search")
    args = parser.parse_args()
    return args


def main(args):
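    # Build the exporter shell first (load_model=False) so that export() below
    # can write fresh engines into model_dir; the modality flag selects the
    # vision or audio (SALM) perception frontend.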
    exporter = TensorRTMMExporter(model_dir=args.output_dir, load_model=False, modality=args.modality)
    exporter.export(
        visual_checkpoint_path=args.visual_checkpoint_path,
        llm_checkpoint_path=args.llm_checkpoint_path,
        model_type=args.model_type,
        llm_model_type=args.llm_model_type,
        tensor_parallel_size=args.tensor_parallel_size,
        max_input_len=args.max_input_len,
        max_output_len=args.max_output_len,
        max_batch_size=args.max_batch_size,
        vision_max_batch_size=args.vision_max_batch_size,
        max_multimodal_len=args.max_multimodal_len,
        dtype=args.dtype,
        delete_existing_files=args.delete_existing_files,
        load_model=not args.test_export_only,
    )
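    # With --test_export_only the engines are not loaded after the build
    # (load_model=False above), so the generation smoke test below only runs
    # for a real export.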
    test_inference = not args.test_export_only
    if test_inference:
        assert args.input_media is not None, "Input media file is required for inference"
        assert os.path.exists(args.input_media), f"Input media file {args.input_media} does not exist"
        output = exporter.forward(
            input_text=args.input_text,
            input_media=args.input_media,
            batch_size=args.batch_size,
            max_output_len=args.max_output,
            top_k=args.top_k,
            top_p=args.top_p,
            temperature=args.temperature,
            repetition_penalty=args.repetition_penalty,
            num_beams=args.num_beams,
        )
        print(output)


if __name__ == '__main__':
    args = parse_args()
    main(args)
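

# A minimal sketch of reusing an already-exported engine directory without
# rebuilding, based only on the TensorRTMMExporter constructor and forward()
# call used above (the directory and media paths are illustrative):
#
#   exporter = TensorRTMMExporter(model_dir='/tmp/neva_trt_engine', load_model=True, modality='vision')
#   output = exporter.forward(
#       input_text='Describe this image.',
#       input_media='/data/sample.jpg',
#       batch_size=1,
#       max_output_len=128,
#       top_k=1,
#       top_p=0.0,
#       temperature=1.0,
#       repetition_penalty=1.0,
#       num_beams=1,
#   )
#   print(output)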