
Attribute not found: height_scale #77


Closed
shocho3858 opened this issue Nov 22, 2018 · 12 comments · Fixed by #80

Comments

@shocho3858

shocho3858 commented Nov 22, 2018

The ONNX model file is exported from a PyTorch model with torch.onnx.export. There is an error when I then use onnx2trt to do the conversion.

```
Input filename: model.onnx
ONNX IR version: 0.0.3
Opset version: 9
Producer name: pytorch
Producer version: 0.4
Domain:
Model version: 0
Doc string:

Parsing model
terminate called after throwing an instance of 'std::out_of_range'
what(): Attribute not found: height_scale
Aborted
```

I found this is due to torch.nn.Upsample(scale_factor=4, mode='nearest').

In fact, torch's ONNX exporter transforms an Upsample operation like this:
```
%244 : Dynamic = onnx::Constant[value= 1  1  4  4 [ CPUFloatType{4} ]](), scope: East/Upsample[unpool1]

%245 : Float(1, 512, 100, 100) = onnx::Upsample[mode="nearest"](%243, %244), scope: East/Upsample[unpool1]

return (%245);
```

I guess height_scale and width_scale are hidden in the onnx::Constant, and there are no scale attributes on onnx::Upsample. But in your code, height_scale and width_scale are required.

I am using PyTorch 0.4.1 and TensorRT 5.0.2.6.

Any suggestions on how to resolve this?
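
For reference, a minimal sketch that reproduces the export on my side (the toy module, tensor shape, and file name below are placeholders, not the real East model):

```python
import torch
import torch.nn as nn

# Toy module that goes through the same nn.Upsample export path.
class ToyUpsample(nn.Module):
    def __init__(self):
        super(ToyUpsample, self).__init__()
        self.unpool1 = nn.Upsample(scale_factor=4, mode='nearest')

    def forward(self, x):
        return self.unpool1(x)

model = ToyUpsample()
dummy = torch.randn(1, 512, 100, 100)
# Depending on the PyTorch version, the scales end up either as node
# attributes or as a Constant input of onnx::Upsample.
torch.onnx.export(model, dummy, "model.onnx", verbose=True)
```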

@yinghai

yinghai commented Nov 26, 2018

@houseroad Do we have height_scale in UpSample in ONNX?

@m7thon

m7thon commented Nov 26, 2018

In the current ONNX operator specs, Upsample takes the scales as an input, no longer as an attribute. See https://github.com/onnx/onnx/blob/master/docs/Operators.md#upsample
AFAIK, pytorch just recently changed the way upsample is exported to onnx to match the current specs. This allows resizing to a size that is determined dynamically.
@shocho3858 are you sure you are using pytorch 0.4.1, and not a recent nightly build?
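
To illustrate the difference, a rough sketch with onnx.helper (the tensor names are made up; the "old" form is what the parser seems to expect, judging by the error message):

```python
from onnx import helper

# Old style: per-dimension scale attributes on the node.
old_node = helper.make_node(
    "Upsample", inputs=["X"], outputs=["Y"],
    mode="nearest", height_scale=4.0, width_scale=4.0)

# Current spec (opset 9): the scales are a second input tensor,
# typically produced by an onnx::Constant node.
new_node = helper.make_node(
    "Upsample", inputs=["X", "scales"], outputs=["Y"],
    mode="nearest")
```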

@shocho3858
Author

shocho3858 commented Nov 27, 2018

> In the current ONNX operator specs, Upsample takes the scales as an input, no longer as an attribute. See https://github.com/onnx/onnx/blob/master/docs/Operators.md#upsample
> AFAIK, pytorch just recently changed the way upsample is exported to onnx to match the current specs. This allows resizing to a size that is determined dynamically.
> @shocho3858 are you sure you are using pytorch 0.4.1, and not a recent nightly build?

You are right, the ONNX model that PyTorch exports matches the current specs. But then why can't I convert it to a TensorRT engine with the onnx2trt command? Please have a look at the error I mentioned above.

@yinghai

yinghai commented Nov 27, 2018

@m7thon Thanks for pointing this out. We can probably handle this like what we did for Reshape:

```cpp
DEFINE_BUILTIN_OP_IMPORTER(Reshape) {
  auto input = inputs.at(0);
  nvinfer1::Dims new_shape;
  if( ctx->getOpsetVersion() >= 5 ) {
    ASSERT(inputs.size() == 2, ErrorCode::kINVALID_NODE);
    auto new_shape_input = inputs.at(1);
    ASSERT(new_shape_input.is_weights(), ErrorCode::kUNSUPPORTED_NODE);
    ShapedWeights new_shape_weights = new_shape_input.weights();
    ASSERT(new_shape_weights.shape.nbDims == 1, ErrorCode::kINVALID_NODE);
    ASSERT(new_shape_weights.type == ::ONNX_NAMESPACE::TensorProto::INT64,
           ErrorCode::kINVALID_NODE);
    int64_t const* new_shape_ptr =
        static_cast<int64_t const*>(new_shape_weights.values);
    new_shape.nbDims = new_shape_weights.shape.d[0];
    std::copy(new_shape_ptr, new_shape_ptr + new_shape.nbDims, new_shape.d);
  } else {
    OnnxAttrs attrs(node);
    new_shape = attrs.get<nvinfer1::Dims>("shape");
  }
```
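
A quick way to check which form a given model uses (a sketch with the onnx Python API; the file name is a placeholder):

```python
import onnx

model = onnx.load("model.onnx")
print(model.opset_import)  # which opset version the model targets

for node in model.graph.node:
    if node.op_type == "Upsample":
        # opset >= 9: scales arrive as a second input;
        # older opsets: scales / height_scale / width_scale attributes.
        print(node.input, [attr.name for attr in node.attribute])
```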

@houseroad Could you help track down in which opset version we changed the input of Upsample?

@houseroad
Member

It's changed in opset 9, here is the change: onnx/onnx#1467

@yinghai

yinghai commented Nov 27, 2018

@shocho3858 Could you try the patch from #80 and see if it works for you?

@shocho3858
Author

> @shocho3858 Could you try the patch from #80 and see if it works for you?

Yeah, it's OK now. Thanks.

@maiminh1996

How can I fix this issue? I use Python 3.5 and I want to convert the model from PyTorch to ONNX and then to TensorRT.

@xiongzhangdavid

Same issue. I use TensorRT 5, and it looks like the problem lies in backend.py line 80, trt.OnnxParser.
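
For context, roughly how I hit it through the Python backend (a sketch following the onnx-tensorrt README; the file name is a placeholder):

```python
import onnx
import onnx_tensorrt.backend as backend

model = onnx.load("model.onnx")
# Parsing happens inside backend.prepare(), which builds a trt.OnnxParser internally.
engine = backend.prepare(model, device='CUDA:0')
```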

@alexbuyval

I have the same issue.

@xxradon

xxradon commented Mar 25, 2019

@xiongzhangdavid @alexbuyval @maiminh1996 I think I have found how to fix this bug. Actually it is not mine but an NVIDIA engineer's, from this example: retina-example.

```python
import torch.onnx.symbolic

# Override Upsample's ONNX export until new opset is supported
@torch.onnx.symbolic.parse_args('v', 'is')
def upsample_nearest2d(g, input, output_size):
    height_scale = float(output_size[-2]) / input.type().sizes()[-2]
    width_scale = float(output_size[-1]) / input.type().sizes()[-1]
    return g.op("Upsample", input,
                scales_f=(1, 1, height_scale, width_scale),
                mode_s="nearest")

torch.onnx.symbolic.upsample_nearest2d = upsample_nearest2d
```

We need to override the upsample_nearest2d symbolic so that TensorRT 5.0 can parse the upsample operator.
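
For example, apply the override before exporting (just a sketch; model and the input shape are placeholders):

```python
import torch

# The upsample_nearest2d override above must be installed before this call.
dummy = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy, "model.onnx", verbose=True)
```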
By the way, the latest onnx-tensorrt can convert the ONNX model with onnx2trt just fine, but the serialized file cannot be used by TensorRT 5.0; the error is a segmentation fault.
So we have to serialize the model using the API provided by TensorRT 5.0, like this:

```cpp
void onnxToTRTModel(const std::string& modelFile, // name of the onnx model
                    unsigned int maxBatchSize,    // batch size - NB must be at least as large as the batch we want to run with
                    nvinfer1::IHostMemory*& trtModelStream, // output buffer for the TensorRT model
                    nvinfer1::DataType dataType,
                    nvinfer1::IInt8Calibrator* calibrator,
                    std::string save_name)
{
    int verbosity = (int)nvinfer1::ILogger::Severity::kINFO;

    // create the builder
    nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(gLogger);
    nvinfer1::INetworkDefinition* network = builder->createNetwork();

    auto parser = nvonnxparser::createParser(*network, gLogger);

    if (!parser->parseFromFile(modelFile.c_str(), verbosity))
    {
        string msg("failed to parse onnx file");
        gLogger.log(nvinfer1::ILogger::Severity::kERROR, msg.c_str());
        exit(EXIT_FAILURE);
    }
    if ((dataType == nvinfer1::DataType::kINT8 && !builder->platformHasFastInt8()))
        exit(EXIT_FAILURE); // bail out if the platform does not support fast INT8

    // Build the engine
    builder->setMaxBatchSize(maxBatchSize);
    builder->setMaxWorkspaceSize(4_GB); // must not exceed the GPU memory actually available; e.g. my 1060 has 4.98 GB usable, and exceeding that raises an error
    builder->setInt8Mode(dataType == nvinfer1::DataType::kINT8);
    builder->setInt8Calibrator(calibrator);
    samplesCommon::enableDLA(builder, gUseDLACore);
    nvinfer1::ICudaEngine* engine = builder->buildCudaEngine(*network);
    assert(engine);

    // we can destroy the parser
    parser->destroy();

    // serialize the engine, then close everything down
    trtModelStream = engine->serialize();

    // gieModelStream: assumed to be a global std::stringstream declared elsewhere,
    // as in the TensorRT samples
    gieModelStream.write((const char*)trtModelStream->data(), trtModelStream->size());
    std::ofstream SaveFile(save_name, std::ios::out | std::ios::binary);
    SaveFile.seekp(0, std::ios::beg);
    SaveFile << gieModelStream.rdbuf();
    gieModelStream.seekg(0, gieModelStream.beg);

    engine->destroy();
    network->destroy();
    builder->destroy();
}
```

@manhongnie

mark
