dynamic shape #328

Closed
OValery16 opened this issue Nov 12, 2019 · 33 comments
Labels
question: Further information is requested
triaged: Issue has been triaged by maintainers

Comments

@OValery16 commented Nov 12, 2019

My goal is to export the resnet18 model from PyTorch to TensorRT. For the sake of experimentation, I use the resnet18 from torchvision, and I define the input/output data as variables of dynamic shape (batch_size, 3, 224, 224).

After exporting my model to ONNX, I used onnx-tensorrt to convert it to TensorRT, and I got the following error: TensorRT failed to convert it and stated that the network has dynamic or shape inputs, but no optimization profile has been defined.

I exported the resnet18 model from PyTorch to ONNX with the following code:

        from torch.autograd import Variable
        import torch.onnx
        import torchvision

        # Export model to onnx format
        dummy_input = Variable(torch.randn(1, 3, 224, 224))
        model = torchvision.models.resnet18(pretrained=True)

        input_names = ["data"]
        output_names = ["output"]

        torch.onnx.export(
            model,
            dummy_input,
            "resnet18.onnx",
            dynamic_axes={"data": {0: "batch"}, "output": {0: "batch"}},
            input_names=input_names,
            output_names=output_names,
        )
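A quick way to confirm the dynamic axis survived the export (a minimal sketch, assuming the onnx Python package is installed):

        import onnx

        # The batch axis should show up as a symbolic dimension named "batch".
        model = onnx.load("resnet18.onnx")
        print(model.graph.input[0].type.tensor_type.shape)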


root@6d1fc8dca772:/# onnx2trt resnet18.onnx -o my_engine.trt -vvv
----------------------------------------------------------------
Input filename:   resnet18.onnx
ONNX IR version:  0.0.4
Opset version:    9
Producer name:    pytorch
Producer version: 1.3
Domain:           
Model version:    0
Doc string:       
----------------------------------------------------------------
Parsing model
[2019-11-08 17:20:17 UNKNOWN] Plugin Creator registration succeeded - GridAnchor_TRT
[2019-11-08 17:20:17 UNKNOWN] Plugin Creator registration succeeded - NMS_TRT
[2019-11-08 17:20:17 UNKNOWN] Plugin Creator registration succeeded - Reorg_TRT
[2019-11-08 17:20:17 UNKNOWN] Plugin Creator registration succeeded - Region_TRT
[2019-11-08 17:20:17 UNKNOWN] Plugin Creator registration succeeded - Clip_TRT
[2019-11-08 17:20:17 UNKNOWN] Plugin Creator registration succeeded - LReLU_TRT
[2019-11-08 17:20:17 UNKNOWN] Plugin Creator registration succeeded - PriorBox_TRT
[2019-11-08 17:20:17 UNKNOWN] Plugin Creator registration succeeded - Normalize_TRT
[2019-11-08 17:20:17 UNKNOWN] Plugin Creator registration succeeded - RPROI_TRT
[2019-11-08 17:20:17 UNKNOWN] Plugin Creator registration succeeded - BatchedNMS_TRT
[2019-11-08 17:20:17 UNKNOWN] Plugin Creator registration succeeded - FlattenConcat_TRT
[2019-11-08 17:20:17    INFO] 123:Conv -> (-1, 64, 112, 112)
[2019-11-08 17:20:17    INFO] 124:BatchNormalization -> (-1, 64, 112, 112)
[2019-11-08 17:20:17    INFO] 125:Relu -> (-1, 64, 112, 112)
[2019-11-08 17:20:17    INFO] 126:MaxPool -> (-1, 64, 56, 56)
[2019-11-08 17:20:17    INFO] 127:Conv -> (-1, 64, 56, 56)
[2019-11-08 17:20:17    INFO] 128:BatchNormalization -> (-1, 64, 56, 56)
[2019-11-08 17:20:17    INFO] 129:Relu -> (-1, 64, 56, 56)
[2019-11-08 17:20:17    INFO] 130:Conv -> (-1, 64, 56, 56)
[2019-11-08 17:20:17    INFO] 131:BatchNormalization -> (-1, 64, 56, 56)
[2019-11-08 17:20:17    INFO] 132:Add -> (-1, 64, 56, 56)
[2019-11-08 17:20:17    INFO] 133:Relu -> (-1, 64, 56, 56)
[2019-11-08 17:20:17    INFO] 134:Conv -> (-1, 64, 56, 56)
[2019-11-08 17:20:17    INFO] 135:BatchNormalization -> (-1, 64, 56, 56)
[2019-11-08 17:20:17    INFO] 136:Relu -> (-1, 64, 56, 56)
[2019-11-08 17:20:17    INFO] 137:Conv -> (-1, 64, 56, 56)
[2019-11-08 17:20:17    INFO] 138:BatchNormalization -> (-1, 64, 56, 56)
[2019-11-08 17:20:17    INFO] 139:Add -> (-1, 64, 56, 56)
[2019-11-08 17:20:17    INFO] 140:Relu -> (-1, 64, 56, 56)
[2019-11-08 17:20:17    INFO] 141:Conv -> (-1, 128, 28, 28)
[2019-11-08 17:20:17    INFO] 142:BatchNormalization -> (-1, 128, 28, 28)
[2019-11-08 17:20:17    INFO] 143:Relu -> (-1, 128, 28, 28)
[2019-11-08 17:20:17    INFO] 144:Conv -> (-1, 128, 28, 28)
[2019-11-08 17:20:17    INFO] 145:BatchNormalization -> (-1, 128, 28, 28)
[2019-11-08 17:20:17    INFO] 146:Conv -> (-1, 128, 28, 28)
[2019-11-08 17:20:17    INFO] 147:BatchNormalization -> (-1, 128, 28, 28)
[2019-11-08 17:20:17    INFO] 148:Add -> (-1, 128, 28, 28)
[2019-11-08 17:20:17    INFO] 149:Relu -> (-1, 128, 28, 28)
[2019-11-08 17:20:17    INFO] 150:Conv -> (-1, 128, 28, 28)
[2019-11-08 17:20:17    INFO] 151:BatchNormalization -> (-1, 128, 28, 28)
[2019-11-08 17:20:17    INFO] 152:Relu -> (-1, 128, 28, 28)
[2019-11-08 17:20:17    INFO] 153:Conv -> (-1, 128, 28, 28)
[2019-11-08 17:20:17    INFO] 154:BatchNormalization -> (-1, 128, 28, 28)
[2019-11-08 17:20:17    INFO] 155:Add -> (-1, 128, 28, 28)
[2019-11-08 17:20:17    INFO] 156:Relu -> (-1, 128, 28, 28)
[2019-11-08 17:20:17    INFO] 157:Conv -> (-1, 256, 14, 14)
[2019-11-08 17:20:17    INFO] 158:BatchNormalization -> (-1, 256, 14, 14)
[2019-11-08 17:20:17    INFO] 159:Relu -> (-1, 256, 14, 14)
[2019-11-08 17:20:17    INFO] 160:Conv -> (-1, 256, 14, 14)
[2019-11-08 17:20:17    INFO] 161:BatchNormalization -> (-1, 256, 14, 14)
[2019-11-08 17:20:17    INFO] 162:Conv -> (-1, 256, 14, 14)
[2019-11-08 17:20:17    INFO] 163:BatchNormalization -> (-1, 256, 14, 14)
[2019-11-08 17:20:17    INFO] 164:Add -> (-1, 256, 14, 14)
[2019-11-08 17:20:17    INFO] 165:Relu -> (-1, 256, 14, 14)
[2019-11-08 17:20:17    INFO] 166:Conv -> (-1, 256, 14, 14)
[2019-11-08 17:20:17    INFO] 167:BatchNormalization -> (-1, 256, 14, 14)
[2019-11-08 17:20:17    INFO] 168:Relu -> (-1, 256, 14, 14)
[2019-11-08 17:20:17    INFO] 169:Conv -> (-1, 256, 14, 14)
[2019-11-08 17:20:17    INFO] 170:BatchNormalization -> (-1, 256, 14, 14)
[2019-11-08 17:20:17    INFO] 171:Add -> (-1, 256, 14, 14)
[2019-11-08 17:20:17    INFO] 172:Relu -> (-1, 256, 14, 14)
[2019-11-08 17:20:17    INFO] 173:Conv -> (-1, 512, 7, 7)
[2019-11-08 17:20:17    INFO] 174:BatchNormalization -> (-1, 512, 7, 7)
[2019-11-08 17:20:17    INFO] 175:Relu -> (-1, 512, 7, 7)
[2019-11-08 17:20:17    INFO] 176:Conv -> (-1, 512, 7, 7)
[2019-11-08 17:20:17    INFO] 177:BatchNormalization -> (-1, 512, 7, 7)
[2019-11-08 17:20:17    INFO] 178:Conv -> (-1, 512, 7, 7)
[2019-11-08 17:20:17    INFO] 179:BatchNormalization -> (-1, 512, 7, 7)
[2019-11-08 17:20:17    INFO] 180:Add -> (-1, 512, 7, 7)
[2019-11-08 17:20:17    INFO] 181:Relu -> (-1, 512, 7, 7)
[2019-11-08 17:20:17    INFO] 182:Conv -> (-1, 512, 7, 7)
[2019-11-08 17:20:17    INFO] 183:BatchNormalization -> (-1, 512, 7, 7)
[2019-11-08 17:20:17    INFO] 184:Relu -> (-1, 512, 7, 7)
[2019-11-08 17:20:17    INFO] 185:Conv -> (-1, 512, 7, 7)
[2019-11-08 17:20:17    INFO] 186:BatchNormalization -> (-1, 512, 7, 7)
[2019-11-08 17:20:17    INFO] 187:Add -> (-1, 512, 7, 7)
[2019-11-08 17:20:17    INFO] 188:Relu -> (-1, 512, 7, 7)
[2019-11-08 17:20:17    INFO] 189:GlobalAveragePool -> (-1, 512, 1, 1)
[2019-11-08 17:20:17    INFO] 190:Flatten -> (-1, -1)
[2019-11-08 17:20:17    INFO] output:Gemm -> (-1, 1000)
Building TensorRT engine, FP16 available:1
    Max batch size:     32
    Max workspace size: 1024 MiB
[2019-11-08 17:20:17   ERROR] Network has dynamic or shape inputs, but no optimization profile has been defined.
terminate called after throwing an instance of 'std::runtime_error'
  what():  Failed to create object
Aborted (core dumped)

@lucasjinreal (Contributor)

@OValery16 Did you use onnx2trt to convert it, or parse it with your own plugin?

@OValery16 (Author)

I used the following command, onnx2trt resnet18.onnx -o my_engine.trt -vvv, in the onnx-tensorrt Docker image.

@uefall commented Dec 5, 2019

Same problem with PyTorch 1.3 and TRT 6.0.1.5.
So what is the right way to generate a batched-inference TRT model with onnx-tensorrt?
If I don't use dynamic shapes, the TRT model can be generated, but during inference get_binding_shape(binding) shows (1, 3, w, h) and this warning occurs:

[TensorRT] WARNING: Explicit batch network detected and batch size specified, use enqueue without batch size instead.

@cocoyen1995

Hi @OValery16, did you solve the problem? I have a similar issue here...
My model is trained with a U-Net in Keras and converted to ONNX using keras2onnx.
The first dimension of my input layer is batch_size too -> (?, 512, 960, 3), and the conversion using the sample code "sampleOnnxMNIST.cpp" failed with:

----------------------------------------------------------------
Input filename:   B235_preR_v8.onnx
ONNX IR version:  0.0.6
Opset version:    11
Producer name:    keras2onnx
Producer version: 1.6.0
Domain:           onnx
Model version:    0
Doc string:
----------------------------------------------------------------
before builder
initialize profile
after setDimensions Input2
after setDimensions conv2d_38
add profile to config
[E] [TRT] Network has dynamic or shape inputs, but no optimization profile has been defined.
[E] [TRT] Network validation failed.

I've looked into the sample code "sampleDynamicReshape.cpp" to see how to add an optimization profile during the conversion.
But the TensorRT 7.0.0 documentation suggests this can't be done with this kind of format(?):
1. The network definition must not have an implicit batch dimension.
So I don't know how to deal with my problem now...
(I've tried to remove the batch_size dimension of my model in Python using Keras, but it doesn't work...)
Any help or advice is appreciated!

@cwentland0

I'm having the same issue as @cocoyen1995: using keras2onnx to convert a tf.keras model to ONNX, then attempting to use onnx2trt to create the inference engine. The result is the same dynamic-inputs error.

It seems that the general ONNX parser cannot handle dynamic batch sizes. From the TensorRT C++ API documentation: "Note: In TensorRT 7.0, the ONNX parser only supports full-dimensions mode, meaning that your network definition must be created with the explicitBatch flag set."

In the Working with Dynamic Shapes section, there is no explicit mention of the ONNX parser. I believe the dynamic shape specification is only for non-batch shapes (e.g. H, W), in which case one would need to build the optimization profile following the given instructions. This somewhat defeats the purpose of onnx2trt (easy construction of the engine from an ONNX network), forcing one to go through the C++/Python API.

I find this all a little strange, as batching is critical for good inference performance, and the UFF format has been deprecated as of TensorRT 7. What other options do we have besides importing from Caffe, or (God forbid) building TRT network definitions from scratch and manually loading in weights?

@jianyin2016

Any updates?

@rmccorm4 commented Feb 27, 2020

Hi, @OValery16 ,

You can peek at the code here: https://github.com/rmccorm4/tensorrt-utils/blob/master/classification/imagenet/onnx_to_tensorrt.py as a rough reference for converting ONNX models to TensorRT, taking dynamic batch into consideration.

For example, you could try something like

./onnx_to_tensorrt.py --explicit-batch --onnx resnet18.onnx

And that should create some default optimization profiles with various batch sizes. You can tweak these numbers manually in the script or write your own script based on it.


@cwentland0 Dynamic shape refers to any dimension with a value of -1, including the batch dimension (when considering an explicit batch network). UFF models do not support explicit batch; they are implicit batch only. I'm not sure about UFF support for dynamic shapes off the top of my head, but I don't think it's supported. You could try my script mentioned above as a reference.

All of the logic in that script can be applied to a C++ version if necessary as well, as it's the same API.

Additionally, I believe TensorRT's trtexec comes with a little extra logic on top of onnx2trt to create a default optimization profile for you if none was specified (which is the error you're getting). I would generally stick with trtexec over onnx2trt for simplicity.
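For reference, a minimal Python sketch of what such a script does under the hood (assuming the TensorRT 7 Python API; the input name "data" comes from the export code at the top of this issue, and the min/opt/max batch sizes are illustrative):

    import tensorrt as trt

    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
    # ONNX models must be parsed into an explicit-batch network in TRT 7.
    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(explicit_batch)
    parser = trt.OnnxParser(network, TRT_LOGGER)
    with open("resnet18.onnx", "rb") as f:
        parser.parse(f.read())

    config = builder.create_builder_config()
    config.max_workspace_size = 1 << 30  # 1 GiB
    # The optimization profile is exactly what the error says is missing.
    profile = builder.create_optimization_profile()
    profile.set_shape("data", (1, 3, 224, 224), (8, 3, 224, 224), (32, 3, 224, 224))
    config.add_optimization_profile(profile)

    engine = builder.build_engine(network, config)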

@OValery16 (Author)

The link is wrong.

@rmccorm4

Fixed, thanks.

@cwentland0

@rmccorm4 I attempted to use your code, as I am at my wits' end trying to get trtexec to produce an engine with a max batch size greater than 1 from an ONNX model with a dynamic batch size. Every specification of min/opt/maxShapes simply produces an engine which, when deserialized with the C++ API, has only one optimization profile and a getMaxBatchSize() output of 1. The verbose output of trtexec implies that an optimization profile has been created for the specified dimensions, but attempting to access any optimization profile beyond the zeroth, which has a dynamic batch size (via setOptimizationProfile()), throws an unsurprising parameter-check failure.

However, your script produced the same result, so I'm 100% sure I'm just being an idiot. Even if I generate an engine with entirely fixed dimensions, getMaxBatchSize() is still equal to one (despite the batch dimension being greater than one). Before I proceed (since I'll need to make some significant code changes): can TensorRT still run batched inference on many samples if the batch size is greater than one but getMaxBatchSize() is equal to one? The documentation seems to imply this is not the case (there are many mentions of defining the max batch size when building the engine).

@rmccorm4 commented Mar 5, 2020

Hi @cwentland0

I am at my wits' end trying to get trtexec to produce an engine with a max batch size greater than 1 from an ONNX model with a dynamic batch size.

I believe trtexec is only capable of creating an engine with one optimization profile (profile index 0) at the moment; I don't think it can create multiple profiles for the same engine. It's mostly used for quick parsing/performance testing. To create an engine with multiple profiles (profile indices 0, 1, 2, ..., N), you'll need to use the API; my script would be a very rough example of that.

getMaxBatchSize() output of 1

I believe getMaxBatchSize() only makes sense in the context of implicit batch networks/engines. ONNX networks in TRT 7 are required to be explicit batch, so getMaxBatchSize() should be irrelevant. The real "max batch size" will be the highest batch dimension allowed by your optimization profiles (the batch dimension of the kMAX dims of a single profile) in explicit batch networks/engines.

I believe something like this for an explicit batch ONNX model with a dynamic batch dimension (-1):

trtexec --explicitBatch --onnx=model.onnx --minShapes=input:1x3x224x224 --optShapes=input:32x3x224x224 --maxShapes=input:32x3x224x224 --saveEngine=model.batch1-32.engine

Would be roughly equivalent to setting builder.maxBatchSize = 32 for an implicit batch model: implicit batch engines support batch sizes from 1 to maxBatchSize and optimize for maxBatchSize, and in the example above our optimization profile supports batch sizes 1-32 and sets kOPT (the shape to optimize for) to batch size 32.

At least, that's my understanding. I hope this clears some things up.
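One way to sanity-check the resulting engine from the Python API (a sketch, assuming a deserialized TRT 7 engine object and an input binding named "input"):

    # get_profile_shape returns the (min, opt, max) shapes of a profile;
    # the batch entry of the max shape is the effective "max batch size".
    min_shape, opt_shape, max_shape = engine.get_profile_shape(0, "input")
    print(min_shape, opt_shape, max_shape)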

@zheng-xing

I am dealing with an ONNX model used for segmentation. All the discussion above is about the batch axis; I'm facing a problem of dynamic size on the other axes. Specifically, can I use ONNX and onnx2trt to handle a tensor shape of (?, 1, ?, ?)? Yes, the width and height need to be dynamic in this case. Thanks very much.

@rmccorm4 commented Mar 19, 2020

Hi @zheng-xing,

As long as you adhere to the restrictions here: https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#rest_dynamic_shapes, I think it should work.

You could probably quickly test this with something like:

# "input" should be your input layer name
trtexec --explicitBatch --onnx=model.onnx \
--minShapes=input:1x1x1x1 \
--optShapes=input:8x1x200x200 \
--maxShapes=input:16x1x500x500 \
--shapes=input:8x1x250x250 \      # Actual inference input shape
--saveEngine=model.engine

I chose arbitrary dimensions above - be sure to choose min,opt,max shape values for batch/height/width that are needed for your use case.

@zheng-xing commented Mar 23, 2020

Hi @rmccorm4 ,

Thanks! I really appreciate your suggestion. I tried this, and the command does run without any problem; an engine is generated. However, the segmentation results are wrong for any shape that is within the range of minShapes and maxShapes but different from the input shape specified in the ONNX model file.

I think the reason is related to how I use this inference engine. My code sample is like the following:

    const int dim1 = 256, dim2 = 512;    // dims must be compile-time constants for the array below
    static float data[dim1][dim2];       // some gray-scale image
    float *mydata = &(data[0][0]);

    float* predictions = new float[2 * dim1 * dim2];    // Two class segmentation: background and foreground

    IExecutionContext* context = engine->createExecutionContext();

    assert(engine->getNbBindings() == 2);
    void* buffers[2];

    int inputIndex, outputIndex;

    printf("Bindings after deserializing:\n");
    for (int bi = 0; bi < engine->getNbBindings(); bi++)
    {
        if (engine->bindingIsInput(bi) == true)
        {
            inputIndex = bi;
            printf("Binding %d (%s): Input.\n", bi, engine->getBindingName(bi));
        }
        else
        {
            outputIndex = bi;
            printf("Binding %d (%s): Output.\n", bi, engine->getBindingName(bi));
        }
    }

    // Create GPU buffers on device
    cudaMalloc(&buffers[inputIndex], dim1 * dim2 * sizeof(float));
    cudaMalloc(&buffers[outputIndex], 2 * dim1 * dim2 * sizeof(float));

    for (int i = 0; i < 1; i++) {
        // Create stream
        cudaStream_t stream;
        cudaStreamCreate(&stream);

        // DMA input batch data to device, infer on the batch asynchronously, and DMA output back to host
        cudaMemcpyAsync(buffers[inputIndex], mydata, dim1 * dim2 * sizeof(float), cudaMemcpyHostToDevice, stream);
        context->enqueueV2(buffers, stream, nullptr);
        cudaMemcpyAsync(predictions, buffers[outputIndex], 2 * dim1 * dim2 * sizeof(float), cudaMemcpyDeviceToHost, stream);
        cudaStreamSynchronize(stream);
    }

Because the data in GPU memory is a 1-D array, an input with shape 256 x 512 cannot be distinguished from an input with shape 512 x 256, which causes problems.

Do you have any idea how to use TensorRT in my scenario? Thanks.

@rmccorm4 commented Mar 23, 2020

Hi @zheng-xing,

but different from the input shape specified in the ONNX model file.

Shouldn't the input shape in the ONNX model file be (-1, 1, -1, -1)? Otherwise, I don't think the optimization profiles would work correctly / make sense.

Generally, flattening to 1-D and re-expanding later shouldn't cause any issues.

An input with shape 256 x 512 cannot be distinguished from an input with shape 512 x 256, which causes problems.

I think that's the point of setting the binding dimensions at runtime.

You should specify something like this:

context->setBindingDimensions(index, Dims4(1, 1, 256, 512));
context->enqueueV2(buffers, stream, nullptr);

or

context->setBindingDimensions(index, Dims4(1, 1, 512, 256));
context->enqueueV2(buffers, stream, nullptr);

depending on the input.

See this section: https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#runtime_dimensions


If you have a small number of possible input shapes, and enough memory, I believe it's more performant to create a context for each possible input shape and set the binding dimensions once on each context, then select your context to execute on based on the input shape rather than setting the binding dimension before every inference. Though this is certainly something you have to trade-off depending on resources and use case.
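A rough sketch of that pattern (Python for brevity; the C++ calls are analogous, and the sketch assumes an engine whose profile covers both shapes, with the input at binding index 0):

    # One context per expected input shape, binding dimensions set once up front.
    # Note: contexts sharing the same profile should not execute concurrently.
    ctx_a = engine.create_execution_context()
    ctx_a.set_binding_shape(0, (1, 1, 256, 512))

    ctx_b = engine.create_execution_context()
    ctx_b.set_binding_shape(0, (1, 1, 512, 256))

    # At inference time, pick the context matching the incoming image:
    # ctx_a.execute_async_v2(buffers, stream_handle)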

@zheng-xing

Thanks @rmccorm4 for your detailed explanations! I believe setBindingDimensions is the key here.

Just one more question: I tried it with a 3D segmentation case, and the trtexec command gives me the following error messages:

Setting layouts of network and plugin input/output tensors to linear, as 3D operators are found and 3D non-linear IO formats are not supported yet.

Assertion failed: x must be built-time constant.

Does that mean dynamic shapes are not supported for 3D segmentation yet? Thanks.

@rmccorm4

I don't know the specific ops a 3D segmentation model is comprised of, but assuming your channel dimension is constant (per the discussion above), it sounds like a FullyConnected layer might be an issue for your shapes:

IConvolutionLayer and IDeconvolutionLayer require that the channel dimension be a build-time constant.
IFullyConnectedLayer requires that the last three dimensions be build-time constants.
Int8 requires that the channel dimension be a build-time constant.

https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#rest_dynamic_shapes


Is there more to the error that says which op the assertion failed for?

@zheng-xing

Hi @rmccorm4 ,

Yes, the input has a fixed channel size of 1, and I only use batch size 1. The only things I want to change are the input width, height, and depth.

The network has no FullyConnected layer; it has convolution layers with different dilations.
The error I got does not indicate which op. It only says:

[W] [TRT] Setting layouts of network and plugin input/output tensors to linear, as 3D operators are found and 3D non-linear IO formats are not supported yet.
[F] [TRT] Assertion failed: x must be built-time constant.
C:\source\builder\symbolicDims.cpp:571
Aborting...

[E] [TRT] C:\source\builder\symbolicDims.cpp (571) - Assertion Error in nvinfer::builder::fromSymbolic:0 (x must be built-time constant)
[E] [TRT] Engine creation failed
[E] [TRT] Engine setup failed

@ganler commented May 19, 2020

@rmccorm4 I ran into an ONNX version problem when trying to run my model in dynamic shape mode.

ONNX IR version:  0.0.6
Opset version:    11
Producer name:    tf2onnx
Producer version: 1.5.6
Domain:           
Model version:    0
Doc string:       
----------------------------------------------------------------
ERROR: ModelImporter.cpp:457 In function importModel:
[4] Assertion failed: !_importer_ctx.network()->hasImplicitBatchDimension() && "This version of the ONNX parser only supports TensorRT INetworkDefinitions with an explicit batch dimension. Please ensure the network was created using the EXPLICIT_BATCH NetworkDefinitionCreationFlag."
ERROR: Failed to parse ONNX in data type: float32

I have the following questions:

  1. If version 0.0.6 cannot work, which version should I use?
  2. If there's a possibility to let this version work, what should I do?

Here are my environment settings:

  • Ubuntu 18.04
  • CUDA Version: 10.2
  • TRT: 7.0.0.11
  • Build code:
    static nvinfer1::ICudaEngine*
    create_onnx_engine(const std::string& model_file, int max_batch_size, nvinfer1::DataType dtype)
    {
        destroy_ptr<nvinfer1::IBuilder> builder(nvinfer1::createInferBuilder(gLogger));

        destroy_ptr<nvinfer1::INetworkDefinition> network(builder->createNetworkV2(0));
        destroy_ptr<nvonnxparser::IParser> parser(nvonnxparser::createParser(*network, gLogger));

        if (!parser->parseFromFile(model_file.c_str(), static_cast<int>(nvinfer1::ILogger::Severity::kWARNING))) {
            gLogger.log(
                nvinfer1::ILogger::Severity::kERROR,
                ("Failed to parse ONNX in data type: " + to_string(dtype)).c_str());
            exit(1);
        }

        builder->setMaxBatchSize(max_batch_size);
        destroy_ptr<nvinfer1::IBuilderConfig> config(builder->createBuilderConfig());
        config->setMaxWorkspaceSize((1 << 20) * 512); // TODO: A better way to set the workspace.

        auto engine = builder->buildEngineWithConfig(*network, *config);

        if (nullptr == engine) {
            gLogger.log(nvinfer1::ILogger::Severity::kERROR, "Failed to created engine");
            exit(1);
        }
        return engine;
    }

Thanks for your patience in advance.

@rmccorm4

Hi @ganler ,

[4] Assertion failed: !_importer_ctx.network()->hasImplicitBatchDimension() && "This version of the ONNX parser only supports TensorRT INetworkDefinitions with an explicit batch dimension. Please ensure the network was created using the EXPLICIT_BATCH NetworkDefinitionCreationFlag."
ERROR: Failed to parse ONNX in data type: float32

Probably from this line:

destroy_ptr<nvinfer1::INetworkDefinition> network(builder->createNetworkV2(0));

Try setting the EXPLICIT_BATCH flag here instead:

const auto explicitBatch = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);  
INetworkDefinition* network = builder->createNetworkV2(explicitBatch);

@ganler commented May 19, 2020

@rmccorm4 Thanks for your help. I still want to know how to get the output shape for a dynamic-shape input. Is there any material I can reference?

@rmccorm4 commented May 19, 2020

Hi @ganler ,

  1. You can reference the sampleDynamicReshape here: https://github.com/NVIDIA/TensorRT/blob/master/samples/opensource/sampleDynamicReshape/sampleDynamicReshape.cpp#L239. After you've specified the binding dimensions for all input bindings, TensorRT internally calculates the output binding dimensions, and you can query them with context->getBindingDimensions(output_binding_index).

  2. This blog post might help, and is written in C++: https://devblogs.nvidia.com/speeding-up-deep-learning-inference-using-tensorrt/

  3. I also wrote a minimal end-to-end example here: https://gist.github.com/rmccorm4/dabccb1f31dbdcf1019a4df431067e52 (it's in Python, but the same flow/order of API calls can be translated to C++). See specifically the function there for setting input binding dimensions and getting output binding dimensions, sketched below.
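The core of that flow, as a minimal Python sketch (assuming an execution context from an engine with one input at binding index 0 and one output at binding index 1):

    # Set the concrete input dimensions for this inference...
    context.set_binding_shape(0, (1, 3, 224, 224))
    assert context.all_binding_shapes_specified
    # ...after which TensorRT can report the concrete output dimensions.
    output_shape = context.get_binding_shape(1)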

@ChauncyJin

Hi @rmccorm4,
Is there any C++ example of running an ONNX model with TensorRT using dynamic shapes?
I exported a PyTorch model to ONNX with the dynamic_axes argument:
dynamic_axes = {'input': {2: "height", 3: 'width'}}
then ran the ONNX model on TensorRT, but I get these errors:
While parsing node number 66 [GlobalAveragePool]:
ERROR: /TensorRT/parsers/onnx/builtin_op_importers.cpp:1086 In function importGlobalAveragePool:
[8] Assertion failed: !isDynamic(kernelSize) && "Cannot run global average pool on an input with dynamic spatial dimensions!"

@thancaocuong

Hi @rmccorm4, I tried your code to convert a model from PyTorch >> ONNX >> TensorRT. But when running inference with a batch size smaller than max_batchsize, it throws an exception:

[TensorRT] ERROR: Parameter check failed at: engine.cpp::enqueue::387, condition: batchSize > 0 && batchSize <= mEngine.getMaxBatchSize(). Note: Batch size was: 16, but engine max batch size was: 1

I am trying to convert a dynamic-shape ONNX model (in the batch dimension) to TensorRT. Could you tell me what the problem is?

@chen-san

@rmccorm4 I have the same question as @ChauncyJin above.

@lucasjinreal (Contributor)

@rmccorm4 Is there a way to export a model in which one layer has a dynamic channel count, determined by the number of instances detected in the image, which can only be known at runtime?

@rmccorm4 commented Jul 7, 2020

@ChauncyJin @chen-san https://github.com/NVIDIA/TensorRT/tree/master/samples/opensource/sampleDynamicReshape is a C++ sample showing the Dynamic Shape concepts.

@thancaocuong It looks like you're using enqueue(), which is meant for implicit batch engines. For explicit batch engines, please use enqueueV2() instead.

@jinfagang you can generally export any of the dimensions to be dynamic as long as the layers in the model support that dimension being dynamic. To satisfy the behavior you describe, you would need to define the min/max instances you expect to see in your optimization profile dimensions, and it's up to your application code to set the IExecutionContext's dimensions at runtime based on the number of instances detected.
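A sketch of that runtime pattern in Python (the binding name "features" and its dimensions are hypothetical):

    # num_instances is only known at runtime, e.g. from a detector.
    index = engine.get_binding_index("features")
    context.set_binding_shape(index, (1, num_instances, 28, 28))
    context.execute_async_v2(buffers, stream_handle)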

@thancaocuong

@rmccorm4 Thanks for your answer. But I have one more question: the input of my model is a 5D tensor, so how can I call setBindingDimensions for it?

@zahidzqj

Error: [TensorRT] ERROR: Parameter check failed at: engine.cpp::enqueue::393, condition: bindings[x] != nullptr

How can I solve this problem?

@kevinch-nv (Collaborator)

I would recommend that folks who are having trouble importing ONNX networks with dynamic shapes into TensorRT first use our CLI binary trtexec with the latest TensorRT version, to rule out misuse of the TRT APIs.

If your model still cannot be parsed / run, please open a new issue and attach information about your model for our team to look at.

I'll be closing this issue for now, as it has branched off into many individual discussions.

@kevinch-nv added the "question" and "triaged" labels on Dec 1, 2020

@serser commented Aug 7, 2021

I am using the onnx2trt command to convert an ONNX model with a dynamic input. What arguments should I pass?

onnx2trt alt_gvt_small_192_dynamic.onnx -o alt_gvt_small_192_dynamic.onnx -b 192

Which gives,

----------------------------------------------------------------
Input filename:   alt_gvt_small_192_dynamic.onnx
ONNX IR version:  0.0.6
Opset version:    11
Producer name:    pytorch
Producer version: 1.7
Domain:           
Model version:    0
Doc string:       
----------------------------------------------------------------
Parsing model
[2021-08-07 09:01:19 WARNING] /workdir/zhangbo97/onnx-tensorrt/onnx2trt_utils.cpp:235: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[2021-08-07 09:01:20 WARNING] Calling isShapeTensor before the entire network is constructed may result in an inaccurate result.
Building TensorRT engine, FP16 available:1
    Max batch size:     192
    Max workspace size: 1024 MiB
[2021-08-07 09:01:20   ERROR] Network has dynamic or shape inputs, but no optimization profile has been defined.
[2021-08-07 09:01:20   ERROR] Network validation failed.
terminate called after throwing an instance of 'std::runtime_error'
  what():  Failed to create object
Aborted (core dumped)

To recap, the model was generated with:

img = torch.zeros((args.batch_size, 3, 224, 224))
torch.onnx.export(
    model,
    img,
    "{}_{}_dynamic.onnx".format(args.model, args.batch_size),
    input_names=['input'],
    output_names=['output'],
    verbose=True,
    opset_version=11,
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
)
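As discussed above, onnx2trt by itself does not define an optimization profile, which is why the build fails. A trtexec invocation along the lines of the earlier examples should work instead (a sketch: the input name "input" and the 3x224x224 shape come from the export call above; the batch range is illustrative):

    trtexec --explicitBatch --onnx=alt_gvt_small_192_dynamic.onnx \
    --minShapes=input:1x3x224x224 \
    --optShapes=input:192x3x224x224 \
    --maxShapes=input:192x3x224x224 \
    --saveEngine=alt_gvt_small_192_dynamic.engine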
