Description
I am trying to export the LayoutLMv2 model to ONNX, but there is no support for that available in the transformers library.
I have tried to follow the approach used for LayoutLM, but it is not working.
Here is the ONNX config class for LayoutLMv2:
from collections import OrderedDict
from pathlib import Path
from typing import Any, List, Mapping, Optional

from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType, is_torch_available
from transformers.onnx import OnnxConfig, PatchingSpec, export


class LayoutLMv2OnnxConfig(OnnxConfig):
    def __init__(
        self,
        config: PretrainedConfig,
        task: str = "default",
        patching_specs: List[PatchingSpec] = None,
    ):
        super().__init__(config, task=task, patching_specs=patching_specs)
        self.max_2d_positions = config.max_2d_position_embeddings - 1

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("bbox", {0: "batch", 1: "sequence"}),
                ("image", {0: "batch", 1: "sequence"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
                ("token_type_ids", {0: "batch", 1: "sequence"}),
            ]
        )

    def generate_dummy_inputs(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        """
        Generate inputs to provide to the ONNX exporter for the specific framework.

        Args:
            tokenizer: The tokenizer associated with this model configuration
            batch_size: The batch size (int) to export the model for (-1 means dynamic axis)
            seq_length: The sequence length (int) to export the model for (-1 means dynamic axis)
            is_pair: Indicate if the input is a pair (sentence 1, sentence 2)
            framework: The framework (optional) the tokenizer will generate tensors for

        Returns:
            Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
        """
        input_dict = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)

        # Generate a dummy bbox and repeat it for every token in every example
        box = [48, 84, 73, 128]

        if framework != TensorType.PYTORCH:
            raise NotImplementedError("Exporting LayoutLMv2 to ONNX is currently only supported for PyTorch.")
        if not is_torch_available():
            raise ValueError("Cannot generate dummy inputs without PyTorch installed.")
        import torch

        batch_size, seq_length = input_dict["input_ids"].shape
        input_dict["bbox"] = torch.tensor([*[box] * seq_length]).tile(batch_size, 1, 1)
        return input_dict
onnx_config = LayoutLMv2OnnxConfig(model.config)
export(tokenizer=tokenizer, model=model, config=onnx_config, opset=12, output=Path('onnx/layoutlmv2.onnx'))
Running the export line raises this error:
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-25-99a1f167e396> in <module>()
----> 1 export(tokenizer=tokenizer, model=model, config=onnx_config, opset=12, output=Path('onnx/layoutlmv2.onnx'))
3 frames
/usr/local/lib/python3.7/dist-packages/transformers/models/layoutlmv2/tokenization_layoutlmv2.py in __call__(self, text, text_pair, boxes, word_labels, add_special_tokens, padding, truncation, max_length, stride, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, **kwargs)
449
450 words = text if text_pair is None else text_pair
--> 451 assert boxes is not None, "You must provide corresponding bounding boxes"
452 if is_batched:
453 assert len(words) == len(boxes), "You must provide words and boxes for an equal amount of examples"
AssertionError: You must provide corresponding bounding boxes
Activity
LysandreJik commented on Nov 11, 2021
I believe @NielsRogge can help out here
NielsRogge commented on Nov 11, 2021
I'm not an ONNX expert, however. Pinging @michaelbenayoun for this.
fadi212 commented on Nov 14, 2021
@michaelbenayoun, can you please help here?
wilbry commented on Nov 15, 2021
I think it might have to do with the fact that your dummy inputs don't have the image field, so the inputs might be off?
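If that is the case, a quick way to test it (just a sketch; the (batch_size, 3, 224, 224) shape is an assumption based on what LayoutLMv2Processor returns for the image field) would be to add a dummy image tensor inside generate_dummy_inputs before returning:

# Sketch: add a dummy image to the dummy inputs, inside generate_dummy_inputs
# (torch and input_dict already exist there at this point).
# The shape (batch_size, 3, 224, 224) is an assumption, not a confirmed spec.
batch_size, seq_length = input_dict["input_ids"].shape
input_dict["image"] = torch.zeros(batch_size, 3, 224, 224, dtype=torch.float32)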
michaelbenayoun commented on Nov 15, 2021
It seems to come from the LayoutLMv2Tokenizer, which takes boxes (bbox) as inputs. Here you are calling super().generate_dummy_inputs, which uses the tokenizer to create dummy inputs, but this does not provide the boxes to the tokenizer, hence the error. There are two ways of solving this issue:
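For illustration, one possible fix (a rough sketch; the dummy values, padding choices and shapes are assumptions, not a reference implementation) is to override generate_dummy_inputs so that dummy words and boxes are passed to the LayoutLMv2 tokenizer directly:

# Sketch: build dummy words and boxes ourselves and hand the boxes to the
# LayoutLMv2 tokenizer, so its "You must provide corresponding bounding boxes"
# assertion is not triggered. The image input would still need to be added separately.
def generate_dummy_inputs(self, tokenizer, batch_size=2, seq_length=8, is_pair=False, framework=None):
    dummy_words = [["hello"] * seq_length] * batch_size
    dummy_boxes = [[[48, 84, 73, 128]] * seq_length] * batch_size
    encoding = tokenizer(
        dummy_words,
        boxes=dummy_boxes,
        padding="max_length",
        truncation=True,
        max_length=seq_length,
        return_tensors=framework,
    )
    return dict(encoding)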
fadi212 commented on Nov 16, 2021
Hi @michaelbenayoun,
I have made the recommended changes in the LayoutLMv2 config file. Now when I try to run the code below, I am facing the error below.
One more thing: for the dummy inputs I have provided the image as "image", {0: "batch"}. Is this mapping right, or do we have to provide the image in a different manner?
mykolamelnykml commented on Nov 16, 2021
+1
NielsRogge commented on Nov 16, 2021
Hi,
It would be great if you could Google the errors before pinging us (we at Hugging Face are pretty busy). E.g. in this case, you can find the answer in the first result on Google: onnx/tutorials#63 (comment)
=> The reason is that LayoutLMv2 uses a visual backbone, which includes layers like AdaptiveAvgPool2d that aren't supported natively by ONNX.
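A common workaround is to replace the AdaptiveAvgPool2d with a fixed-size AvgPool2d before exporting, which works when the spatial size the pool receives is known. A minimal sketch, with the attribute path and sizes left as assumptions:

# Sketch: swap AdaptiveAvgPool2d for a fixed AvgPool2d before export.
# Assumes the pooled feature map has a known spatial size in_size = (H, W)
# and that the adaptive pool's target output size is out_size.
import torch.nn as nn

def make_fixed_pool(in_size, out_size):
    # For a fixed in_size, AdaptiveAvgPool2d(out_size) matches an AvgPool2d
    # with stride = in_size // out_size and a kernel chosen to cover the rest.
    stride = (in_size[0] // out_size[0], in_size[1] // out_size[1])
    kernel = (in_size[0] - (out_size[0] - 1) * stride[0],
              in_size[1] - (out_size[1] - 1) * stride[1])
    return nn.AvgPool2d(kernel_size=kernel, stride=stride)

# Hypothetical usage before calling export(); the attribute path below is an
# assumption and may differ depending on the model class and version:
# model.layoutlmv2.visual.pool = make_fixed_pool(in_size=(H, W), out_size=tuple(model.config.image_feature_pool_shape[:2]))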
fadi212 commented on Nov 18, 2021
Hi @NielsRogge,
I followed your guide and made the required changes. I updated the pooling layer, and now I am faced with the error below.
I had googled the previous issue as well, but I was not sure where to make the pooling layer changes.
This time I also searched for the issue, but to no avail, as I am quite new to ONNX.
Would you please point out where I am making an error in the code below?
Running the above code raises the error below:
viantirreau commented on Nov 19, 2021
@fadi212 Have you tried using another opset version, such as 11? Speaking from complete ignorance here, but maybe worth a try :)
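For reference, that would just mean changing the opset argument in the export call from the issue description:

# Same export call as in the description, only with opset 11 instead of 12.
export(tokenizer=tokenizer, model=model, config=onnx_config, opset=11, output=Path('onnx/layoutlmv2.onnx'))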
lalitr994 commented on Nov 23, 2021
My model is converted to ONNX, but when loading the model into onnxruntime I am getting the error below.
Type Error: Type parameter (T) bound to different types (tensor(double) and tensor(float) in node ()
@michaelbenayoun @wilbry @fadi212
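One common cause of this kind of mismatch (a guess, not verified on LayoutLMv2) is that part of the graph gets traced with float64 values. Casting the model to float32 and re-exporting is a cheap thing to try:

# Sketch: force all parameters/buffers to float32 before exporting again, so
# the traced graph does not mix tensor(double) and tensor(float).
model = model.float()
export(tokenizer=tokenizer, model=model, config=onnx_config, opset=12, output=Path('onnx/layoutlmv2.onnx'))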
NielsRogge commented on Nov 23, 2021
Hi,
Can you check out the solution provided here?
Also, if you managed to convert the model to ONNX, feel free to open a PR which we can review, it will benefit the community a lot.
Thanks!