
Export LayoutLMv2 to onnx  #14368

Open

fadi212

Description

I am trying to export the LayoutLMv2 model to ONNX, but there is no support for it in the transformers library.
I have tried to follow the approach available for LayoutLM, but that is not working.
Here is my ONNX config class for LayoutLMv2:

class LayoutLMv2OnnxConfig(OnnxConfig):
    def __init__(
        self,
        config: PretrainedConfig,
        task: str = "default",
        patching_specs: List[PatchingSpec] = None,
    ):
        super().__init__(config, task=task, patching_specs=patching_specs)
        self.max_2d_positions = config.max_2d_position_embeddings - 1

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("bbox", {0: "batch", 1: "sequence"}),
                ("image", {0: "batch", 1: "sequence"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
                ("token_type_ids", {0: "batch", 1: "sequence"}),
            ]
        )

    def generate_dummy_inputs(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        """
        Generate inputs to provide to the ONNX exporter for the specific framework
        Args:
            tokenizer: The tokenizer associated with this model configuration
            batch_size: The batch size (int) to export the model for (-1 means dynamic axis)
            seq_length: The sequence length (int) to export the model for (-1 means dynamic axis)
            is_pair: Indicate if the input is a pair (sentence 1, sentence 2)
            framework: The framework (optional) the tokenizer will generate tensor for
        Returns:
            Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
        """

        input_dict = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)

        # Generate a dummy bbox
        box = [48, 84, 73, 128]

        if not framework == TensorType.PYTORCH:
            raise NotImplementedError("Exporting LayoutLM to ONNX is currently only supported for PyTorch.")

        if not is_torch_available():
            raise ValueError("Cannot generate dummy inputs without PyTorch installed.")
        import torch

        batch_size, seq_length = input_dict["input_ids"].shape
        input_dict["bbox"] = torch.tensor([*[box] * seq_length]).tile(batch_size, 1, 1)
        return input_dict

onnx_config = LayoutLMv2OnnxConfig(model.config)


export(tokenizer=tokenizer, model=model, config=onnx_config, opset=12, output=Path('onnx/layoutlmv2.onnx'))

Running the export line raises this error:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-25-99a1f167e396> in <module>()
----> 1 export(tokenizer=tokenizer, model=model, config=onnx_config, opset=12, output=Path('onnx/layoutlmv2.onnx'))

3 frames
/usr/local/lib/python3.7/dist-packages/transformers/models/layoutlmv2/tokenization_layoutlmv2.py in __call__(self, text, text_pair, boxes, word_labels, add_special_tokens, padding, truncation, max_length, stride, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, **kwargs)
    449 
    450         words = text if text_pair is None else text_pair
--> 451         assert boxes is not None, "You must provide corresponding bounding boxes"
    452         if is_batched:
    453             assert len(words) == len(boxes), "You must provide words and boxes for an equal amount of examples"

AssertionError: You must provide corresponding bounding boxes

Activity

LysandreJik (Member) commented on Nov 11, 2021

I believe @NielsRogge can help out here

NielsRogge (Contributor) commented on Nov 11, 2021

I'm not an ONNX expert, however. Pinging @michaelbenayoun for this.

fadi212 (Author) commented on Nov 14, 2021

@michaelbenayoun, can you please help here?

wilbry commented on Nov 15, 2021

I think it might have to do with the fact that your dummy inputs don't have the image field, so the inputs might be off?

michaelbenayoun (Member) commented on Nov 15, 2021

It seems to come from the LayoutLMv2Tokenizer, which takes boxes (bbox) as inputs.
Here you are calling super().generate_dummy_inputs, which uses the tokenizer to create dummy inputs, but it does not provide the boxes to the tokenizer, hence the error.

There are two ways of solving this issue:

  1. Make this supported in the base class, which could somehow take extra keyword arguments for these kinds of cases.
  2. Do not use the super method, and implement everything in the LayoutLMv2 OnnxConfig (see the sketch below).
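
For illustration, here is a minimal sketch of option 2, assuming the LayoutLMv2 tokenizer is called directly with dummy words and boxes instead of going through super().generate_dummy_inputs (the helper name, the dummy values, and the 224x224 zero image are assumptions, not library code):

from typing import Any, Mapping, Optional

import torch
from transformers import LayoutLMv2Tokenizer, TensorType


def generate_layoutlmv2_dummy_inputs(
    tokenizer: LayoutLMv2Tokenizer,
    batch_size: int = 2,
    seq_length: int = 8,
    framework: Optional[TensorType] = TensorType.PYTORCH,
) -> Mapping[str, Any]:
    # Build dummy words plus one bounding box per word, since the LayoutLMv2
    # tokenizer asserts that `boxes` accompany the words.
    words = [["hello"] * seq_length] * batch_size
    boxes = [[[48, 84, 73, 128]] * seq_length] * batch_size
    encoding = tokenizer(
        words,
        boxes=boxes,
        padding="max_length",
        truncation=True,
        max_length=seq_length,
        return_tensors=framework,
    )
    inputs = dict(encoding)
    # The visual branch still needs an `image` input; a zero tensor with the
    # 224x224 resolution used by the LayoutLMv2 feature extractor is assumed.
    inputs["image"] = torch.zeros(batch_size, 3, 224, 224, dtype=torch.float32)
    return inputs

The resulting dictionary can then be fed to the exporter in place of the inputs produced by the base class.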

fadi212 (Author) commented on Nov 16, 2021

Hi @michaelbenayoun,
I have made the recommended changes in the LayoutLMv2 config file.

# coding=utf-8
# Copyright Microsoft Research and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" LayoutLMv2 model configuration """

from ...configuration_utils import PretrainedConfig
from ...file_utils import is_detectron2_available
from ...utils import logging
from ...onnx import OnnxConfig, PatchingSpec
from typing import Any, List, Mapping, Optional
from transformers import TensorType
from transformers import LayoutLMv2Processor
from datasets import load_dataset
from PIL import Image
from ... import is_torch_available
from collections import OrderedDict

logger = logging.get_logger(__name__)

LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "layoutlmv2-base-uncased": "https://huggingface.co/microsoft/layoutlmv2-base-uncased/resolve/main/config.json",
    "layoutlmv2-large-uncased": "https://huggingface.co/microsoft/layoutlmv2-large-uncased/resolve/main/config.json",
    # See all LayoutLMv2 models at https://huggingface.co/models?filter=layoutlmv2
}

# soft dependency
if is_detectron2_available():
    import detectron2


class LayoutLMv2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a :class:`~transformers.LayoutLMv2Model`. It is used
    to instantiate an LayoutLMv2 model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the LayoutLMv2
    `microsoft/layoutlmv2-base-uncased <https://huggingface.co/microsoft/layoutlmv2-base-uncased>`__ architecture.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.

    Args:
        vocab_size (:obj:`int`, `optional`, defaults to 30522):
            Vocabulary size of the LayoutLMv2 model. Defines the number of different tokens that can be represented by
            the :obj:`inputs_ids` passed when calling :class:`~transformers.LayoutLMv2Model` or
            :class:`~transformers.TFLayoutLMv2Model`.
        hidden_size (:obj:`int`, `optional`, defaults to 768):
            Dimension of the encoder layers and the pooler layer.
        num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (:obj:`int`, `optional`, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (:obj:`int`, `optional`, defaults to 3072):
            Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string,
            :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
        hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
            The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        type_vocab_size (:obj:`int`, `optional`, defaults to 2):
            The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.LayoutLMv2Model`
            or :class:`~transformers.TFLayoutLMv2Model`.
        initializer_range (:obj:`float`, `optional`, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        max_2d_position_embeddings (:obj:`int`, `optional`, defaults to 1024):
            The maximum value that the 2D position embedding might ever be used with. Typically set this to something
            large just in case (e.g., 1024).
        max_rel_pos (:obj:`int`, `optional`, defaults to 128):
            The maximum number of relative positions to be used in the self-attention mechanism.
        rel_pos_bins (:obj:`int`, `optional`, defaults to 32):
            The number of relative position bins to be used in the self-attention mechanism.
        fast_qkv (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to use a single matrix for the queries, keys, values in the self-attention layers.
        max_rel_2d_pos (:obj:`int`, `optional`, defaults to 256):
            The maximum number of relative 2D positions in the self-attention mechanism.
        rel_2d_pos_bins (:obj:`int`, `optional`, defaults to 64):
            The number of 2D relative position bins in the self-attention mechanism.
        image_feature_pool_shape (:obj:`List[int]`, `optional`, defaults to [7, 7, 256]):
            The shape of the average-pooled feature map.
        coordinate_size (:obj:`int`, `optional`, defaults to 128):
            Dimension of the coordinate embeddings.
        shape_size (:obj:`int`, `optional`, defaults to 128):
            Dimension of the width and height embeddings.
        has_relative_attention_bias (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to use a relative attention bias in the self-attention mechanism.
        has_spatial_attention_bias (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to use a spatial attention bias in the self-attention mechanism.
        has_visual_segment_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to add visual segment embeddings.
        detectron2_config_args (:obj:`dict`, `optional`):
            Dictionary containing the configuration arguments of the Detectron2 visual backbone. Refer to `this file
            <https://github.com/microsoft/unilm/blob/master/layoutlmft/layoutlmft/models/layoutlmv2/detectron2_config.py>`__
            for details regarding default values.

    Example::

        >>> from transformers import LayoutLMv2Model, LayoutLMv2Config

        >>> # Initializing a LayoutLMv2 microsoft/layoutlmv2-base-uncased style configuration
        >>> configuration = LayoutLMv2Config()

        >>> # Initializing a model from the microsoft/layoutlmv2-base-uncased style configuration
        >>> model = LayoutLMv2Model(configuration)

        >>> # Accessing the model configuration
        >>> configuration = model.config
    """
    model_type = "layoutlmv2"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        max_2d_position_embeddings=1024,
        max_rel_pos=128,
        rel_pos_bins=32,
        fast_qkv=True,
        max_rel_2d_pos=256,
        rel_2d_pos_bins=64,
        convert_sync_batchnorm=True,
        image_feature_pool_shape=[7, 7, 256],
        coordinate_size=128,
        shape_size=128,
        has_relative_attention_bias=True,
        has_spatial_attention_bias=True,
        has_visual_segment_embedding=False,
        detectron2_config_args=None,
        **kwargs
    ):
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            intermediate_size=intermediate_size,
            hidden_act=hidden_act,
            hidden_dropout_prob=hidden_dropout_prob,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            max_position_embeddings=max_position_embeddings,
            type_vocab_size=type_vocab_size,
            initializer_range=initializer_range,
            layer_norm_eps=layer_norm_eps,
            pad_token_id=pad_token_id,
            **kwargs,
        )
        self.max_2d_position_embeddings = max_2d_position_embeddings
        self.max_rel_pos = max_rel_pos
        self.rel_pos_bins = rel_pos_bins
        self.fast_qkv = fast_qkv
        self.max_rel_2d_pos = max_rel_2d_pos
        self.rel_2d_pos_bins = rel_2d_pos_bins
        self.convert_sync_batchnorm = convert_sync_batchnorm
        self.image_feature_pool_shape = image_feature_pool_shape
        self.coordinate_size = coordinate_size
        self.shape_size = shape_size
        self.has_relative_attention_bias = has_relative_attention_bias
        self.has_spatial_attention_bias = has_spatial_attention_bias
        self.has_visual_segment_embedding = has_visual_segment_embedding
        self.detectron2_config_args = (
            detectron2_config_args if detectron2_config_args is not None else self.get_default_detectron2_config()
        )

    @classmethod
    def get_default_detectron2_config(self):
        return {
            "MODEL.MASK_ON": True,
            "MODEL.PIXEL_STD": [57.375, 57.120, 58.395],
            "MODEL.BACKBONE.NAME": "build_resnet_fpn_backbone",
            "MODEL.FPN.IN_FEATURES": ["res2", "res3", "res4", "res5"],
            "MODEL.ANCHOR_GENERATOR.SIZES": [[32], [64], [128], [256], [512]],
            "MODEL.RPN.IN_FEATURES": ["p2", "p3", "p4", "p5", "p6"],
            "MODEL.RPN.PRE_NMS_TOPK_TRAIN": 2000,
            "MODEL.RPN.PRE_NMS_TOPK_TEST": 1000,
            "MODEL.RPN.POST_NMS_TOPK_TRAIN": 1000,
            "MODEL.POST_NMS_TOPK_TEST": 1000,
            "MODEL.ROI_HEADS.NAME": "StandardROIHeads",
            "MODEL.ROI_HEADS.NUM_CLASSES": 5,
            "MODEL.ROI_HEADS.IN_FEATURES": ["p2", "p3", "p4", "p5"],
            "MODEL.ROI_BOX_HEAD.NAME": "FastRCNNConvFCHead",
            "MODEL.ROI_BOX_HEAD.NUM_FC": 2,
            "MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION": 14,
            "MODEL.ROI_MASK_HEAD.NAME": "MaskRCNNConvUpsampleHead",
            "MODEL.ROI_MASK_HEAD.NUM_CONV": 4,
            "MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION": 7,
            "MODEL.RESNETS.DEPTH": 101,
            "MODEL.RESNETS.SIZES": [[32], [64], [128], [256], [512]],
            "MODEL.RESNETS.ASPECT_RATIOS": [[0.5, 1.0, 2.0]],
            "MODEL.RESNETS.OUT_FEATURES": ["res2", "res3", "res4", "res5"],
            "MODEL.RESNETS.NUM_GROUPS": 32,
            "MODEL.RESNETS.WIDTH_PER_GROUP": 8,
            "MODEL.RESNETS.STRIDE_IN_1X1": False,
        }

    def get_detectron2_config(self):
        detectron2_config = detectron2.config.get_cfg()
        for k, v in self.detectron2_config_args.items():
            attributes = k.split(".")
            to_set = detectron2_config
            for attribute in attributes[:-1]:
                to_set = getattr(to_set, attribute)
            setattr(to_set, attributes[-1], v)

        return detectron2_config


class LayoutLMv2OnnxConfig(OnnxConfig):
    def __init__(
        self,
        config: PretrainedConfig,
        task: str = "default",
        patching_specs: List[PatchingSpec] = None,
    ):
        super().__init__(config, task=task, patching_specs=patching_specs)
        self.max_2d_positions = config.max_2d_position_embeddings - 1

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("bbox", {0: "batch", 1: "sequence"}),
                ("image", {0:"batch"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
                ("token_type_ids", {0: "batch", 1: "sequence"}),
            ]
        )

    def generate_dummy_inputs(
        self,
        processor: LayoutLMv2Processor,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        """
        Generate inputs to provide to the ONNX exporter for the specific framework

        Args:
            tokenizer: The tokenizer associated with this model configuration
            batch_size: The batch size (int) to export the model for (-1 means dynamic axis)
            seq_length: The sequence length (int) to export the model for (-1 means dynamic axis)
            is_pair: Indicate if the input is a pair (sentence 1, sentence 2)
            framework: The framework (optional) the tokenizer will generate tensor for
is_pair
        Returns:
            Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
        """
        datasets = load_dataset("nielsr/funsd")
        labels = datasets['train'].features['ner_tags'].feature.names
        example = datasets["test"][0]
        # print(example.keys())
        image = Image.open(example['image_path'])
        image = image.convert("RGB")

        if not framework == TensorType.PYTORCH:
            raise NotImplementedError("Exporting LayoutLM to ONNX is currently only supported for PyTorch.")

        if not is_torch_available():
            raise ValueError("Cannot generate dummy inputs without PyTorch installed.")
        import torch

        input_dict = processor(image, example['words'], boxes=example['bboxes'], word_labels=example['ner_tags'],
                                   return_tensors=framework)

        axis = 0
        for key_i in input_dict.data.keys():
            input_dict.data[key_i] = torch.cat((input_dict.data[key_i], input_dict.data[key_i]), axis)

        return input_dict.data

Now when I try to run the code below,

processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")
model = LayoutLMv2ForTokenClassification.from_pretrained("microsoft/layoutlmv2-base-uncased", torchscript=True)

onnx_config = LayoutLMv2OnnxConfig(model.config)
export(tokenizer=processor, model=model, config=onnx_config, opset=13, output=Path('onnx/layout.onnx'))

I am facing the error below.

Traceback (most recent call last):
  File "/home/muhammad/PycharmProjects/js_labs/Layoutv2/convert_lmv2.py", line 11, in <module>
    export(tokenizer=processor, model=model, config=onnx_config, opset=9, output=Path('onnx/layout.onnx'))
  File "/home/muhammad/PycharmProjects/js_labs/anaconda3/envs/onnx-env/lib/python3.7/site-packages/transformers/onnx/convert.py", line 125, in export
    opset_version=opset,
  File "/home/muhammad/PycharmProjects/js_labs/anaconda3/envs/onnx-env/lib/python3.7/site-packages/torch/onnx/__init__.py", line 320, in export
    custom_opsets, enable_onnx_checker, use_external_data_format)
  File "/home/muhammad/PycharmProjects/js_labs/anaconda3/envs/onnx-env/lib/python3.7/site-packages/torch/onnx/utils.py", line 111, in export
    custom_opsets=custom_opsets, use_external_data_format=use_external_data_format)
  File "/home/muhammad/PycharmProjects/js_labs/anaconda3/envs/onnx-env/lib/python3.7/site-packages/torch/onnx/utils.py", line 740, in _export
    val_add_node_names, val_use_external_data_format, model_file_location)
RuntimeError: ONNX export failed: Couldn't export operator aten::adaptive_avg_pool2d

One more thing: for the dummy input I have provided the image as ("image", {0: "batch"}). Is this mapping right, or do we have to provide the image in a different manner?
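
One way to sanity-check that mapping once an export succeeds is to inspect the input shapes recorded in the saved graph; a minimal sketch, assuming the file name from the snippet above and that only the batch axis should be symbolic:

import onnx

onnx_model = onnx.load("onnx/layout.onnx")
for graph_input in onnx_model.graph.input:
    # Symbolic (named) dims are the dynamic axes; fixed dims show up as integers.
    dims = [d.dim_param or d.dim_value for d in graph_input.type.tensor_type.shape.dim]
    print(graph_input.name, dims)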

mykolamelnykml commented on Nov 16, 2021

+1

NielsRogge (Contributor) commented on Nov 16, 2021

Hi,

It would be great if you could Google the errors before pinging us (we at Hugging Face are pretty busy). E.g. in this case, you can find the answer in the first result on Google: onnx/tutorials#63 (comment)

=> The reason is that LayoutLMv2 uses a visual backbone, which includes layers like AdaptiveAvgPool2d that aren't supported natively by ONNX.
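
For context, a minimal sketch of the workaround used later in this thread: swap the visual backbone's adaptive pooling for a fixed-size AvgPool2d before export. The kernel/stride of 8 assumes a 56x56 backbone feature map being pooled down to the 7x7 image_feature_pool_shape; verify against your model.

import torch
from torch import nn
from transformers import LayoutLMv2ForTokenClassification

model = LayoutLMv2ForTokenClassification.from_pretrained("microsoft/layoutlmv2-base-uncased")

# aten::adaptive_avg_pool2d is not exportable, so replace it with a fixed
# AvgPool2d that produces the same 7x7 output (assumes a 56x56 feature map).
model.layoutlmv2.visual.pool = nn.AvgPool2d(kernel_size=8, stride=8)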

fadi212 (Author) commented on Nov 18, 2021

Hi @NielsRogge,
I followed your guidance and made the required changes. I updated the pooling layer, and now I am faced with the error below.
I had googled the previous issue as well, but was not sure where to make the pooling-layer changes.
This time I searched for the issue too, but to no avail, as I am fairly new to ONNX.

Would you please point out where I am making an error in the code below?

from transformers.onnx import OnnxConfig, PatchingSpec
from transformers.configuration_utils import PretrainedConfig
from typing import Any, List, Mapping, Optional, Tuple, Union, Iterable
from collections import OrderedDict
from transformers import LayoutLMv2Processor
from datasets import load_dataset
from PIL import Image
import torch
from transformers import PreTrainedModel, TensorType
from torch.onnx import export
from transformers.file_utils import torch_version, is_torch_onnx_dict_inputs_support_available
from pathlib import Path
from transformers.utils import logging
from inspect import signature
from itertools import chain
from transformers import LayoutLMv2ForTokenClassification
from torch import nn
from torch.onnx import OperatorExportTypes

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class LayoutLMv2OnnxConfig(OnnxConfig):
    def __init__(
            self,
            config: PretrainedConfig,
            task: str = "default",
            patching_specs: List[PatchingSpec] = None,
    ):
        super().__init__(config, task=task, patching_specs=patching_specs)
        self.max_2d_positions = config.max_2d_position_embeddings - 1

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("bbox", {0: "batch", 1: "sequence"}),
                ("image", {0: "batch"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
                ("token_type_ids", {0: "batch", 1: "sequence"}),
            ]
        )

    def generate_dummy_inputs(
            self,
            processor: LayoutLMv2Processor,
            batch_size: int = -1,
            seq_length: int = -1,
            is_pair: bool = False,
            framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:

        datasets = load_dataset("nielsr/funsd")
        example = datasets["test"][0]
        image = Image.open(example['image_path'])
        image = image.convert("RGB")

        if not framework == TensorType.PYTORCH:
            raise NotImplementedError("Exporting LayoutLM to ONNX is currently only supported for PyTorch.")

        input_dict = processor(image, example['words'], boxes=example['bboxes'], word_labels=example['ner_tags'],
                               return_tensors=framework)

        axis = 0
        for key_i in input_dict.data.keys():
            input_dict.data[key_i] = torch.cat((input_dict.data[key_i], input_dict.data[key_i]), axis)

        return input_dict.data


class pool_layer(nn.Module):
    """Fixed-size average pooling used to replace the ONNX-unsupported AdaptiveAvgPool2d."""

    def __init__(self):
        super(pool_layer, self).__init__()
        self.fc = nn.AvgPool2d(kernel_size=[8, 8], stride=[8, 8])

    def forward(self, x):
        output = self.fc(x)
        return output


def ensure_model_and_config_inputs_match(
        model: PreTrainedModel, model_inputs: Iterable[str]
) -> Tuple[bool, List[str]]:
    """

    :param model:
    :param model_inputs:
    :return:
    """
    forward_parameters = signature(model.forward).parameters
    model_inputs_set = set(model_inputs)

    # We are fine if config_inputs has more keys than model_inputs
    forward_inputs_set = set(forward_parameters.keys())
    is_ok = model_inputs_set.issubset(forward_inputs_set)

    # Make sure the input order match (VERY IMPORTANT !!!!)
    matching_inputs = forward_inputs_set.intersection(model_inputs_set)
    ordered_inputs = [parameter for parameter in forward_parameters.keys() if parameter in matching_inputs]
    return is_ok, ordered_inputs


def export_model(
        processor: LayoutLMv2Processor, model: PreTrainedModel, config: LayoutLMv2OnnxConfig, opset: int, output: Path
) -> Tuple[List[str], List[str]]:
    """
    Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR

    Args:
        processor:
        model:
        config:
        opset:
        output:

    Returns:

    """

    if not is_torch_onnx_dict_inputs_support_available():
        raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")

    logger.info(f"Using framework PyTorch: {torch.__version__}")
    with torch.no_grad():
        model.config.return_dict = True
        model.eval()

        # Check if we need to override certain configuration item
        if config.values_override is not None:
            logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
            for override_config_key, override_config_value in config.values_override.items():
                logger.info(f"\t- {override_config_key} -> {override_config_value}")
                setattr(model.config, override_config_key, override_config_value)

        model_inputs = config.generate_dummy_inputs(processor, framework=TensorType.PYTORCH)
        inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
        print(matched_inputs)
        onnx_outputs = list(config.outputs.keys())

        if not inputs_match:
            raise ValueError("Model and config inputs doesn't match")

        config.patch_ops()
        model_inputs.pop("labels")
        export(
            model,
            (model_inputs,),
            f=output.as_posix(),
            input_names=list(config.inputs.keys()),
            output_names=onnx_outputs,
            dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
            do_constant_folding=True,
            use_external_data_format=config.use_external_data_format(model.num_parameters()),
            enable_onnx_checker=True,
            opset_version=opset,
        #    operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK
        )

        config.restore_ops()

    return matched_inputs, onnx_outputs


if __name__ == '__main__':
    processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")
    model = LayoutLMv2ForTokenClassification.from_pretrained("microsoft/layoutlmv2-base-uncased", torchscript = True)
    model.layoutlmv2.visual.pool = torch.nn.Sequential(pool_layer())
    onnx_config = LayoutLMv2OnnxConfig(model.config)
    export_model(processor=processor, model=model, config=onnx_config, opset=13, output=Path('onnx/layout.onnx'))

Running the above code raises the error below:

RuntimeError                              Traceback (most recent call last)
<ipython-input-6-134631b21e61> in <module>()
    168     model.layoutlmv2.visual.pool = torch.nn.Sequential(pool_layer())
    169     onnx_config = LayoutLMv2OnnxConfig(model.config)
--> 170     export_model(processor=processor, model=model, config=onnx_config, opset=13, output=Path('onnx/layout.onnx'))

4 frames
<ipython-input-6-134631b21e61> in export_model(processor, model, config, opset, output)
    154             use_external_data_format=config.use_external_data_format(model.num_parameters()),
    155             enable_onnx_checker=True,
--> 156             opset_version=opset,
    157         #    operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK
    158         )

/usr/local/lib/python3.7/dist-packages/torch/onnx/__init__.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
    274                         do_constant_folding, example_outputs,
    275                         strip_doc_string, dynamic_axes, keep_initializers_as_inputs,
--> 276                         custom_opsets, enable_onnx_checker, use_external_data_format)
    277 
    278 

/usr/local/lib/python3.7/dist-packages/torch/onnx/utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
     92             dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs,
     93             custom_opsets=custom_opsets, enable_onnx_checker=enable_onnx_checker,
---> 94             use_external_data_format=use_external_data_format)
     95 
     96 

/usr/local/lib/python3.7/dist-packages/torch/onnx/utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, opset_version, _retain_param_name, do_constant_folding, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, enable_onnx_checker, use_external_data_format, onnx_shape_inference, use_new_jit_passes)
    696                                 training=training,
    697                                 use_new_jit_passes=use_new_jit_passes,
--> 698                                 dynamic_axes=dynamic_axes)
    699 
    700             # TODO: Don't allocate a in-memory string for the protobuf

/usr/local/lib/python3.7/dist-packages/torch/onnx/utils.py in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, example_outputs, _retain_param_name, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, use_new_jit_passes, dynamic_axes)
    498     if do_constant_folding and _export_onnx_opset_version in torch.onnx.constant_folding_opset_versions:
    499         params_dict = torch._C._jit_pass_onnx_constant_fold(graph, params_dict,
--> 500                                                             _export_onnx_opset_version)
    501         torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
    502 

RuntimeError: Tensors must have same number of dimensions: got 2 and 1

viantirreau commented on Nov 19, 2021

@fadi212 Have you tried using another opset version, such as 11?
Speaking from complete ignorance here, but maybe worth a try :)

lalitr994 commented on Nov 23, 2021

My model is converted to ONNX, but when loading the model into ONNX Runtime I am getting the error below:
Type Error: Type parameter (T) bound to different types (tensor(double) and tensor(float) in node ()

@michaelbenayoun @wilbry @fadi212
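
A hedged diagnostic for this kind of mismatch, assuming it comes from an initializer that was exported as float64 (file name reused from the snippets above): list the double-precision initializers so the offending tensor can be traced back to the PyTorch module and cast to float32 before export.

import onnx

onnx_model = onnx.load("onnx/layout.onnx")
# Initializers stored as DOUBLE are a common source of the
# tensor(double)/tensor(float) type mismatch reported above.
for initializer in onnx_model.graph.initializer:
    if initializer.data_type == onnx.TensorProto.DOUBLE:
        print(initializer.name)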

NielsRogge (Contributor) commented on Nov 23, 2021

Hi,

Can you check out the solution provided here?

Also, if you managed to convert the model to ONNX, feel free to open a PR that we can review; it will benefit the community a lot.

Thanks!

13 remaining items
