-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathexport_trt_from_directory.py
62 lines (50 loc) · 2.26 KB
/
export_trt_from_directory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import torch
import time
from utilities import Engine
def _list_onnx_files(onnx_dir):
    """Return sorted paths of the regular ``*.onnx`` files directly inside *onnx_dir*."""
    # Sorting makes the build order deterministic (os.listdir order is arbitrary);
    # the suffix test is case-insensitive and directories named "*.onnx" are skipped.
    candidates = (os.path.join(onnx_dir, name) for name in sorted(os.listdir(onnx_dir)))
    return [path for path in candidates if path.lower().endswith('.onnx') and os.path.isfile(path)]


def _prompt_for_sources():
    """Ask the user which ONNX model(s) to convert and where to save engines.

    Returns:
        A ``(onnx_files, trt_dir)`` tuple: the list of ONNX paths to convert
        and the output directory string exactly as entered.

    Raises:
        ValueError: on an invalid menu choice or a directory with no ``.onnx`` files.
        FileNotFoundError: if the single ONNX file or the ONNX directory does not exist.
    """
    option = input("Choose an option:\n1. Convert a single ONNX file\n2. Convert all ONNX files in a directory\nEnter your choice (1 or 2): ")
    option = option.strip()
    if option == '1':
        onnx_path = input("Enter the path to the ONNX model (e.g ./realesrgan.onnx): ")
        # Fail fast instead of letting the engine build fail later on a bad path.
        if not os.path.isfile(onnx_path):
            raise FileNotFoundError(f"ONNX file not found: {onnx_path}")
        onnx_files = [onnx_path]
        trt_dir = input("Enter the path to save the TensorRT engine (e.g ./trt_engine/): ")
    elif option == '2':
        onnx_dir = input("Enter the directory path containing ONNX models (e.g ./onnx_models/): ")
        onnx_files = _list_onnx_files(onnx_dir)
        if not onnx_files:
            raise ValueError(f"No .onnx files found in directory: {onnx_dir}")
        trt_dir = input("Enter the directory path to save the TensorRT engines (e.g ./trt_engine/): ")
    else:
        raise ValueError("Invalid option. Please choose either 1 or 2.")
    return onnx_files, trt_dir


def export_trt(trt_path=None, onnx_path=None, use_fp16=True, *, delay_seconds=10):
    """Interactively build TensorRT engines from one or more ONNX models.

    Prompts on stdin for either a single ONNX file or a directory of ONNX
    files, plus an output directory. For each model it then builds
    ``<output dir>/<model name>.engine`` with ``Engine.build``. The build uses
    a dynamic profile for the ``"input"`` tensor, ranging from (1, 3, 256, 256)
    through (1, 3, 512, 512) to (1, 3, 1280, 1280).

    Args:
        trt_path: Ignored; kept for backward compatibility. Each engine path is
            derived from the chosen output directory and the ONNX file name.
        onnx_path: Ignored; kept for backward compatibility. The ONNX source
            always comes from the interactive prompt.
        use_fp16: Passed through to ``Engine.build``.
        delay_seconds: Pause between consecutive builds (never after the last
            one). Defaults to the previously hard-coded 10 seconds; ``<= 0``
            disables the pause.

    Raises:
        ValueError: on an invalid menu choice or an ONNX directory without
            ``.onnx`` files.
        FileNotFoundError: if the chosen ONNX file or directory does not exist.
    """
    onnx_files, trt_dir = _prompt_for_sources()

    # exist_ok avoids the check-then-create race (and raises if trt_dir is a file);
    # an empty answer means "current directory", which os.makedirs("") would reject.
    if trt_dir:
        os.makedirs(trt_dir, exist_ok=True)

    total_files = len(onnx_files)
    for index, onnx_file in enumerate(onnx_files):
        base_name = os.path.splitext(os.path.basename(onnx_file))[0]
        engine_path = os.path.join(trt_dir, f"{base_name}.engine")
        print(f"Converting {onnx_file} to {engine_path}")
        start = time.perf_counter()
        # One Engine per model; clear PyTorch's cached CUDA allocations before building.
        engine = Engine(engine_path)
        torch.cuda.empty_cache()
        engine.build(
            onnx_file,
            use_fp16,
            enable_preview=True,
            input_profile=[
                {"input": [(1, 3, 256, 256), (1, 3, 512, 512), (1, 3, 1280, 1280)]},  # any sizes from 256x256 to 1280x1280
            ],
        )
        elapsed = time.perf_counter() - start
        print(f"Time taken to build: {elapsed} seconds")
        if delay_seconds > 0 and index < total_files - 1:
            print(f"Delaying for {delay_seconds} seconds...")
            time.sleep(delay_seconds)
            print("Resuming operations after delay...")
# Run the interactive conversion only when executed as a script, not on import.
if __name__ == "__main__":
    export_trt()