Description
Hello!
I have a model.plan file that does not load on the TRTIS server. It is supposed to be a relatively small model (a modification of ResNet50). It was a .caffemodel file at first, but I converted it into a .plan file with this script.
import pretrainedmodels
import torch
import pretrainedmodels.utils as utils
from torch.nn import DataParallel, Sequential
from utils import ListImagesDataset, append
from torch.utils.data import DataLoader
import argparse
from tqdm import tqdm
import h5py
from torch.autograd import Variable
import torch.onnx
import torchvision
import tensorflow as tf
#import uff
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import numpy as np

datatype = trt.float32

# The Onnx path is used for Onnx models.
def build_engine_onnx(deploy_file, model_file, max_batch_size):
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network, trt.CaffeParser() as parser:
        builder.max_workspace_size = 15 << 20
        builder.max_batch_size = max_batch_size
        # Load the Onnx model and parse it in order to populate the TensorRT network.
        model_tensors = parser.parse(deploy=deploy_file, model=model_file, network=network, dtype=datatype)
        print(network.get_layer(network.num_layers - 1).get_output(0).shape)
        network.mark_output(network.get_layer(network.num_layers - 1).get_output(0))
        return builder.build_cuda_engine(network)

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", default="resnet152Max", type=str)
    parser.add_argument("--deploy_file", default="model", type=str)
    parser.add_argument("--model_file", default="model", type=str)
    parser.add_argument("--output_dir", required=True, type=str)
    parser.add_argument("--num_workers", default=8, type=int)
    parser.add_argument("--batch_size", required=True, type=int)
    args = parser.parse_args()

    deploy_file = args.deploy_file
    model_file = args.model_file

    with build_engine_onnx(args.deploy_file, args.model_file, args.batch_size) as engine:
        with open(args.model_name + '.plan', 'wb') as f:
            print('ok')
            f.write(engine.serialize())
When I try to load it into TRTIS, I get:
===============================
== TensorRT Inference Server ==
NVIDIA Release 18.09 (build 688039)
Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
Copyright 2018 The TensorFlow Authors. All rights reserved.
Various files include modifications (c) NVIDIA CORPORATION. All rights reserved.
NVIDIA modifications are covered by the license terms that apply to the underlying
project or file.
I0205 16:59:32.507128 1 server.cc:631] Initializing TensorRT Inference Server
I0205 16:59:32.507207 1 server.cc:680] Reporting prometheus metrics on port 8002
I0205 16:59:33.342361 1 metrics.cc:129] found 8 GPUs supported power usage metric
I0205 16:59:33.348772 1 metrics.cc:139] GPU 0: Tesla V100-SXM2-16GB
I0205 16:59:33.360104 1 metrics.cc:139] GPU 1: Tesla V100-SXM2-16GB
I0205 16:59:33.366832 1 metrics.cc:139] GPU 2: Tesla V100-SXM2-16GB
I0205 16:59:33.373640 1 metrics.cc:139] GPU 3: Tesla V100-SXM2-16GB
I0205 16:59:33.381472 1 metrics.cc:139] GPU 4: Tesla V100-SXM2-16GB
I0205 16:59:33.388678 1 metrics.cc:139] GPU 5: Tesla V100-SXM2-16GB
I0205 16:59:33.396049 1 metrics.cc:139] GPU 6: Tesla V100-SXM2-16GB
I0205 16:59:33.403472 1 metrics.cc:139] GPU 7: Tesla V100-SXM2-16GB
I0205 16:59:33.404022 1 server.cc:884] Starting server 'inference:0' listening on
I0205 16:59:33.404050 1 server.cc:888] localhost:8001 for gRPC requests
I0205 16:59:33.404723 1 server.cc:898] localhost:8000 for HTTP requests
[warn] getaddrinfo: address family for nodename not supported
[evhttp_server.cc : 235] RAW: Entering the event loop ...
I0205 16:59:33.580886 1 server_core.cc:465] Adding/updating models.
I0205 16:59:33.580913 1 server_core.cc:520] (Re-)adding model: classifyNSFW
I0205 16:59:33.580919 1 server_core.cc:520] (Re-)adding model: resnext101_32x4d
I0205 16:59:33.681215 1 basic_manager.cc:739] Successfully reserved resources to load servable {name: resnext101_32x4d version: 1}
I0205 16:59:33.681268 1 loader_harness.cc:66] Approving load for servable version {name: resnext101_32x4d version: 1}
I0205 16:59:33.681277 1 loader_harness.cc:74] Loading servable version {name: resnext101_32x4d version: 1}
I0205 16:59:33.781259 1 basic_manager.cc:739] Successfully reserved resources to load servable {name: classifyNSFW version: 1}
I0205 16:59:33.781313 1 loader_harness.cc:66] Approving load for servable version {name: classifyNSFW version: 1}
I0205 16:59:33.781326 1 loader_harness.cc:74] Loading servable version {name: classifyNSFW version: 1}
I0205 16:59:33.824357 1 plan_bundle.cc:301] Creating instance classifyNSFW_0_0_gpu0 on GPU 0 (7.0) using model.plan
I0205 16:59:34.841770 1 logging.cc:39] Glob Size is 56 bytes.
I0205 16:59:34.843509 1 logging.cc:39] Added linear block of size 8589934597
I0205 16:59:34.843521 1 logging.cc:39] Added linear block of size 18446532056943951878
I0205 16:59:34.843525 1 logging.cc:39] Added linear block of size 47244640284
I0205 16:59:34.843528 1 logging.cc:39] Added linear block of size 154618822688
I0205 16:59:34.843531 1 logging.cc:39] Added linear block of size 18446744069414584508
I0205 16:59:34.843534 1 logging.cc:39] Added linear block of size 17179869216
I0205 16:59:34.843538 1 logging.cc:39] Added linear block of size 1651470960
I0205 16:59:34.843541 1 logging.cc:39] Added linear block of size 1305670057985
I0205 16:59:34.843544 1 logging.cc:39] Added linear block of size 773094113281
I0205 16:59:34.843547 1 logging.cc:39] Added linear block of size 17179869185
I0205 16:59:34.843550 1 logging.cc:39] Added linear block of size 38630843628
I0205 16:59:34.843554 1 logging.cc:39] Added linear block of size 17179869200
It seems like it tries to add some absurdly large linear blocks, but I have no idea why...
Here is my .plan file:
Thank you!
deadeyegoodwin commented on Feb 6, 2019
Thanks for the report. Your script is a little confusing since it mentions ONNX but is really using Caffe, but I don't see anything obviously wrong with it. I see from the log that you are using a CUDA compute capability 7.0 GPU with TRTIS. Did you generate the TRT plan on a system with a GPU of the same capability? That is required.
Can you try using caffe2plan.cc from src/test to generate your plan file? If that still fails for you, then we will know it is not a problem with your script.
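One quick way to confirm the capability point is to print the compute capability of every visible GPU on both the build machine and the TRTIS machine. This is just a sanity-check sketch using pycuda (which your script already imports), not part of the conversion itself:

import pycuda.driver as cuda
import pycuda.autoinit  # initializes the CUDA driver

# Print the compute capability of every visible GPU; the plan must be built on a GPU
# with the same compute capability as the one TRTIS uses to run it (7.0 in your log).
for i in range(cuda.Device.count()):
    dev = cuda.Device(i)
    major, minor = dev.compute_capability()
    print("GPU %d: %s (compute capability %d.%d)" % (i, dev.name(), major, minor))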
gitvipin commented on Feb 7, 2019
@stygian2a Do we get any extra information after changing the log level to DEBUG?
Basically changing
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
to
TRT_LOGGER = trt.Logger(trt.Logger.DEBUG)
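If DEBUG turns out not to exist in your TensorRT Python build, a minimal sketch like the following (assuming the same tensorrt module your script imports) lists the severity levels that are actually available so you can pick the most verbose one:

import tensorrt as trt

# List the logger severity names exposed by this TensorRT build; the exact set
# varies between versions, so DEBUG may not be among them.
print([name for name in dir(trt.Logger) if name.isupper()])

# Then construct the logger with the most verbose level that exists, for example:
TRT_LOGGER = trt.Logger(trt.Logger.INFO)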
stygian2a commented on Feb 7, 2019
@deadeyegoodwin Yes, the TRT plan file was generated on a GPU with the same capability. I don't know much about C++; what exactly do I need to do?
@gitvipin here is the output with Logger.INFO (DEBUG does not exist): https://pastebin.com/5t6BJaSx
Thank you!
deadeyegoodwin commented on Feb 7, 2019
@stygian2a Can you share the caffe files and the command line you used to run your script, and we will try to reproduce.
stygian2a commented on Feb 7, 2019
@deadeyegoodwin Of course!
nsfw_model.zip
The command line is:
python CreateTensorRTModelfromCaffe.py --model_name model --deploy_file deploy.prototxt --model_file resnet_50_1by2_nsfw.caffemodel --output_dir . --batch_size 1
Thank you again!
deadeyegoodwin commented on Feb 7, 2019
Your model worked OK for me after converting it with caffe2plan.
I0207 17:22:43.158137 309 logging.cc:49] Glob Size is 23962120 bytes.
I0207 17:22:43.161996 309 logging.cc:49] Added linear block of size 3211264
I0207 17:22:43.162029 309 logging.cc:49] Added linear block of size 1605632
I0207 17:22:43.162033 309 logging.cc:49] Added linear block of size 802816
I0207 17:22:43.162036 309 logging.cc:49] Added linear block of size 50176
I0207 17:22:44.287410 309 logging.cc:49] Deserialize required 1129671 microseconds.
I see that you are using the 18.09 version of the inference server. That is quite old; can you update to the 19.01 version and try again? Also, where are you running your script to generate the TensorRT plan? The plan must be generated with the same version of TensorRT that is being used by the inference server. The easiest way to ensure that is to run your script inside the TensorRT container that matches the inference server version. So, use this container to run your script and generate the plan:
docker pull nvcr.io/nvidia/tensorrt:19.01-py3
And this inference server container to run it:
docker pull nvcr.io/nvidia/tensorrtserver:19.01-py3
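As a quick sanity check, a minimal sketch run inside the build container (assuming the same tensorrt Python module your conversion script imports) is simply to print the version and compare it against the TensorRT version shipped in the matching tensorrtserver container:

import tensorrt as trt

# A serialized plan is only guaranteed to deserialize on the same TensorRT version,
# so this should match the TensorRT version bundled with the tensorrtserver container.
print(trt.__version__)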
stygian2a commented on Feb 11, 2019
I updated TensorRT and the TensorRT Inference Server, and it works now! Thank you!!