
🌟 T5 V1.1 #6285


Closed
3 tasks done
timoschick opened this issue Aug 6, 2020 · 22 comments · Fixed by #8552

@timoschick

timoschick commented Aug 6, 2020

🌟 New model addition

Model description

T5 version t5.1.1.* is very similar to the original T5 model, with the following differences:

  • GEGLU activation in feed-forward hidden layer, rather than ReLU - see https://arxiv.org/abs/2002.05202 .
  • Dropout was turned off in pre-training (quality win). Dropout should be re-enabled during fine-tuning.
  • Pre-trained on C4 only without mixing in the downstream tasks.
  • no parameter sharing between embedding and classifier layer
  • "xl" and "xxl" replace "3B" and "11B". The model shapes are a bit different - larger d_model and smaller num_heads and d_ff.

The key reason why these models are interesting is that - unlike the originally released models - they were trained only on unlabeled data and not on any labeled data, making them applicable for few-shot learning experiments. As they are very similar to the original T5 models, I assume they are relatively easy to implement.

Open source status

(Also tagging @patrickvonplaten as he is mentioned in the who to tag guide for T5)

@patrickvonplaten
Contributor

Sorry for the long delay on this one - I hope to be able to take a look in the next two weeks :-)

@patrickvonplaten
Contributor

And thanks a lot for the very detailed description here!

@calclavia

Any update on this task?

@craffel

craffel commented Oct 1, 2020

Hi all, in case it is helpful Noam recently wrote up a maybe-exhaustive list of the differences between T5.1.1 and a vanilla Transformer. Copying it here:

Here are the main differences between t5.1.1.* and most other Transformer implementations.
Hope I have not forgotten anything. Please add where appropriate.

No positional embedding

(relies on relative attention - see below)

FFN Layers

No biases
version t5.1.1.* (but not t5.1.0.*) uses Gated-GELU activation
two input projections, approximate-gelu activation on one of them, then multiply componentwise, and apply the output projection
Approximate GELU:

import numpy as np

def approximate_gelu(x):
    cdf = 0.5 * (1.0 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x * x * x)))
    return x * cdf
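
For concreteness, here is a minimal sketch of that gated-GELU feed-forward block, assuming the approximate_gelu helper above; the weight names wi_0, wi_1, and wo are illustrative stand-ins for the two input projections and the output projection, and biases are omitted as noted:

def geglu_ffn(x, wi_0, wi_1, wo):
    # x: (seq, d_model); wi_0, wi_1: (d_model, d_ff); wo: (d_ff, d_model); no biases
    hidden_gelu = approximate_gelu(x @ wi_0)     # gated branch
    hidden_linear = x @ wi_1                     # linear branch
    return (hidden_gelu * hidden_linear) @ wo    # componentwise gate, then output projection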

Attention Layers

"relative position bias" - This is a simplified form of relative attention, due to the fact that other relative attention algorithms are slow on TPU. This is present in the encoder self-attention layers and decoder self-attention layers, but not the encoder-decoder attention layers.
A learned "bias" value is added to the attention logit. The bias is different by bucketed relative position. The biases are different across attention heads, but shared across different attention layers in the same stack.
relative_position = memory_position - query_position
bucket(relative_position) is determined by the function here: https://github.com/tensorflow/mesh/blob/5f802ae5492fd9207fd506a7ced189f6dbc38f2c/mesh_tensorflow/transformer/transformer_layers.py#L996
bidirectional=True for the encoder and False for the decoder.
The variables representing the four linear transformations have their num_heads and d_kv dimensions combined. This caused the code to run faster on TPU for some unknown reason.
No biases on the input and output linear transformations.
No explicit scaling of the logits by d_kv^-0.5 . This is folded into the initializers of the linear transformations. With Adafactor, it's equivalent.
Not in any of the t5.1 configs, but may be in other configs: "extra logit" - This is equivalent to appending a 0 to the set of logits prior to softmax, and truncating it after the softmax. This allows for attending to nothing, if all of the logits are much less than zero. It's not clear whether this is an improvement or just a stumbling block for compatibility.
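
To illustrate the relative position bias (this is a rough NumPy sketch, not the actual mesh TF code; bucket_fn stands in for the bucketing function linked above and all names are assumptions):

import numpy as np

def attention_logits_with_relative_bias(q, k, bias_table, bucket_fn):
    # q: (num_heads, seq_q, d_kv), k: (num_heads, seq_k, d_kv)
    # bias_table: (num_heads, num_buckets) learned biases, shared across layers in a stack
    logits = np.einsum('hqd,hkd->hqk', q, k)  # no explicit d_kv**-0.5 scaling (folded into init)
    relative_position = np.arange(k.shape[1])[None, :] - np.arange(q.shape[1])[:, None]
    buckets = np.vectorize(bucket_fn)(relative_position)  # (seq_q, seq_k) bucket indices
    return logits + bias_table[:, buckets]                # per-head bias added to the logits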

Embeddings

Encoder vocab embedding shared with decoder vocab embedding
in t5.1.0.* (but not in t5.1.1.*) this variable is also shared with the classifier layer. In that case, it is multiplied by d_model**-0.5 for use in the classifier.
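
To make the classifier difference concrete, here is a small sketch with hypothetical names (embedding, lm_head): in t5.1.0.* the shared vocab embedding is reused for the output logits with the d_model**-0.5 rescaling, while t5.1.1.* keeps a separate, untied classifier matrix.

def decoder_logits(hidden, embedding, d_model, lm_head=None):
    # hidden: (seq, d_model); embedding: (vocab, d_model), shared by encoder and decoder
    if lm_head is None:
        # t5.1.0.*: the vocab embedding doubles as the classifier, scaled by d_model**-0.5
        return hidden @ (embedding * d_model ** -0.5).T
    # t5.1.1.*: a separate, untied classifier matrix of shape (vocab, d_model)
    return hidden @ lm_head.T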

Residuals, etc.

Before layer stack, apply dropout
For each layer apply
Y = X + dropout(F(rms_norm(X)))
F is the core layer function, i.e. feed-forward, attention, etc.
RMS norm is a simplified version of layer norm.
After layer stack, apply rms_norm, then dropout.
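
As a short sketch of the residual recipe above (function names and the epsilon default are assumptions, not taken from the mesh TF code):

import numpy as np

def rms_norm(x, scale, eps=1e-6):
    # Simplified layer norm: rescale by the root mean square; no mean subtraction, no bias
    mean_square = np.mean(x * x, axis=-1, keepdims=True)
    return x * scale / np.sqrt(mean_square + eps)

def residual_block(x, layer_fn, scale, dropout_fn=lambda h: h):
    # Y = X + dropout(F(rms_norm(X))), where F is attention or the feed-forward block
    return x + dropout_fn(layer_fn(rms_norm(x, scale)))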

@arglog

arglog commented Oct 13, 2020

Hi @patrickvonplaten

Any updates on this? It would be exciting to be able to use the T5 v1.1 models in Hugging Face! Thanks!

@ratthachat
Contributor

Hi Patrick,
There are newly released T5.1.1 checkpoints which give SOTA on Natural Questions for non-retrieval models; I posted a discussion here. Maybe it's a bit more encouragement to integrate T5.1.1 into HF :D

@ratthachat
Contributor

@craffel Thanks for your clarification about T5.1.1.
However, I could not find any source code for T5.1.1. Could you provide a link to the source?

@craffel

craffel commented Oct 23, 2020

Hi, the source is all in the mesh TF transformer codebase
https://github.com/tensorflow/mesh/tree/master/mesh_tensorflow/transformer
Here is the gin config for t5.1.1.base
https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/models/gin/models/t5.1.1.base.gin

@acul3
Contributor

acul3 commented Oct 26, 2020

Multilingual t5 (mt5) has been released
https://github.com/google-research/multilingual-t5
https://arxiv.org/abs/2010.11934

It looks like it uses the same implementation as T5 v1.1.
Really looking forward to being able to use it in the Hugging Face library.

@neurowide

@julien-c thanks for your amazing NLP lib.
When do you plan to support mT5?
When will #6285 be released?
Cheers
Philippe

@patrickvonplaten
Contributor

Hey guys,

I will start adding mT5 next week

@sachinsharma9780

@patrickvonplaten : waiting for mt5 :)

@patrickvonplaten
Contributor

Yep will start working on it this week :-)

@patrickvonplaten
Contributor

Think a reasonable estimate for official release is in ~2 weeks: #8488

@patrickvonplaten
Contributor

patrickvonplaten commented Nov 12, 2020

T5 V1.1 and MT5 have the same architecture. I'm struggling a bit with finding a good name for it in the library.

Not sure if I like the names T5V1_1Model and T5V1_1ForConditionalGeneration, maybe T5v2Model is better?
MT5Model will be aliased to the new model architecture.

=> Going for T5v2Model and T5v2ForConditionalGeneration now. MT5Model will be aliased to it. If someone has better name suggestions please add a comment :-) Names are easy to change before integration.

@ratthachat
Contributor

ratthachat commented Nov 12, 2020

Hi @patrickvonplaten , thanks again !
I think T5v2 is a nicer name. However, "if" somebody releases an official T5v2 in the future (like GPT, GPT-2, GPT-3), it may cause confusion. Can it be T5v11 (no '_')?

@patrickvonplaten
Contributor

patrickvonplaten commented Nov 13, 2020

Yeah good point @ratthachat!

@craffel - We decided internally that we will make a new model file for T5v1.1 / mT5 as it's more in line with the library's philosophy. The best name I can think of at the moment is T5V2Model and T5V2ForConditionalGeneration respectively - IMO it's better than T5v1p1Model, ... However, if you were to release a T5v2, the naming would be a bit awkward.

Would be super interested in hearing your opinion about it! Or better name suggestions in case you have some :-)

@craffel

craffel commented Nov 13, 2020

It might be confusing to refer to T5.1.1 as T5 v2 since it would result in an inconsistent versioning system. I think T511Model is probably ok, but I defer to you all as to what HF's naming convention should be.

@agemagician
Contributor

agemagician commented Nov 13, 2020

I would suggest either:

  1. Follow @craffel's suggestion.
  2. Just have one version and adjust the JSON config file to load the correct configuration, since most of the code is exactly the same except for a few changes.

@ratthachat
Contributor

If possible and it does not cause any harm, I support @agemagician's choice 2 above.

@shenfe
Contributor

shenfe commented Nov 14, 2020

I haven't reproduced benchmark performance (such as GLUE CoLA, MRPC, etc.) with PyTorch T5.1.1 so far. Is anyone else trying this?

@shenfe
Contributor

shenfe commented Nov 15, 2020

I haven't reproduced benchmark performance (such as GLUE CoLA, MRPC, etc.) with PyTorch T5.1.1 so far. Is anyone else trying this?

I have now reproduced the mT5-small model by fine-tuning on the XNLI benchmark task. It seems to work.

@patrickvonplaten patrickvonplaten linked a pull request Nov 15, 2020 that will close this issue