🌟 T5 V1.1 #6285

Closed
@timoschick

Description

🌟 New model addition

Model description

T5 version t5.1.1.* is very similar to the original T5 model, with the following differences:

  • GEGLU activation in feed-forward hidden layer, rather than ReLU - see https://arxiv.org/abs/2002.05202.
  • Dropout was turned off in pre-training (quality win). Dropout should be re-enabled during fine-tuning.
  • Pre-trained on C4 only without mixing in the downstream tasks.
  • No parameter sharing between embedding and classifier layer.
  • "xl" and "xxl" replace "3B" and "11B". The model shapes are a bit different - larger d_model and smaller num_heads and d_ff.

The key reason why these models are interesting is that - unlike the originally released models - they were trained only on unlabeled data and not on any labeled data, making them applicable for few-shot learning experiments. As they are very similar to the original T5 models, I assume they are relatively easy to implement.

Open source status

(Also tagging @patrickvonplaten as he is mentioned in the who to tag guide for T5)

Activity

patrickvonplaten (Contributor) commented on Sep 20, 2020

Sorry for the long delay on this one - I hope to be able to take a look in the next two weeks :-)

patrickvonplaten (Contributor) commented on Sep 20, 2020

And thanks a lot for the very detailed description here!

calclavia commented on Sep 30, 2020

Any update on this task?

craffel commented on Oct 1, 2020

Hi all, in case it is helpful Noam recently wrote up a maybe-exhaustive list of the differences between T5.1.1 and a vanilla Transformer. Copying it here:

Here are the main differences between t5.1.1.* and most other Transformer implementations.
Hope I have not forgotten anything. Please add where appropriate.

No positional embedding

(relies on relative attention - see below)

FFN Layers

No biases
version t5.1.1.* (but not t5.1.0.*) uses Gated-GELU activation
two input projections, approximate-gelu activation on one of them, then multiply componentwise, and apply the output projection
Approximate GELU:

import numpy as np

def approximate_gelu(x):
    cdf = 0.5 * (1.0 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x * x * x)))
    return x * cdf
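
For illustration, here is a minimal NumPy sketch of the gated-GELU feed-forward block described above, reusing approximate_gelu from the snippet; the names wi_0, wi_1, and wo are placeholders for the two input projections and the output projection, not the actual mesh_tensorflow variable names.

def geglu_ffn(x, wi_0, wi_1, wo):
    # x: (..., d_model); wi_0, wi_1: (d_model, d_ff); wo: (d_ff, d_model); no bias terms anywhere
    hidden = approximate_gelu(x @ wi_0) * (x @ wi_1)  # GELU on one projection, componentwise product with the other
    return hidden @ wo  # output projection back to d_model

In the t5.1.0.* configs there is no gate: a single input projection followed by ReLU.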

Attention Layers

"relative position bias" - This is a simplified form of relative attention, due to the fact that other relative attention algorithms are slow on TPU. This is present in the encoder self-attention layers and decoder self-attention layers, but not the encoder-decoder attention layers.
A learned "bias" value is added to the attention logit. The bias is different by bucketed relative position. The biases are different across attention heads, but shared across different attention layers in the same stack.
relative_position = memory_position - query_position
bucket(relative_position) is determined by the function here: https://github.com/tensorflow/mesh/blob/5f802ae5492fd9207fd506a7ced189f6dbc38f2c/mesh_tensorflow/transformer/transformer_layers.py#L996
bidirectional=True for the encoder and False for the decoder.
The variables representing the four linear transformations have their num_heads and d_kv dimensions combined. This caused the code to run faster on TPU for some unknown reason.
No biases on the input and output linear transformations.
No explicit scaling of the logits by d_kv^-0.5. This is folded into the initializers of the linear transformations. With Adafactor, it's equivalent.
Not in any of the t5.1 configs, but may be in other configs: "extra logit" - This is equivalent to appending a 0 to the set of logits prior to softmax, and truncating it after the softmax. This allows for attending to nothing, if all of the logits are much less than zero. It's not clear whether this is an improvement or just a stumbling block for compatibility.
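
As a rough sketch (not the actual mesh_tensorflow code), the bias lookup described above could look like this in NumPy; bucket_fn stands in for the bucketing function linked above, and bias_table for the learned per-head variable shared across layers of a stack:

import numpy as np

def relative_position_bias(query_len, memory_len, bias_table, bucket_fn):
    # bias_table: learned (num_buckets, num_heads) values, shared across layers of the same stack
    query_position = np.arange(query_len)[:, None]
    memory_position = np.arange(memory_len)[None, :]
    relative_position = memory_position - query_position   # (query_len, memory_len)
    buckets = bucket_fn(relative_position)                  # bucket ids, per the linked mesh_tensorflow function
    bias = bias_table[buckets]                              # (query_len, memory_len, num_heads)
    return bias.transpose(2, 0, 1)                          # (num_heads, query_len, memory_len), added to the logits

A crude stand-in for bucket_fn, just to check the shapes, could be lambda rp: np.clip(rp, -8, 7) + 8 with num_buckets=16; the real bucketing is the log-spaced scheme in the linked code, with bidirectional=True for the encoder and False for the decoder.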

Embeddings

Encoder vocab embedding shared with decoder vocab embedding
in t5.1.0.* (but not in t5.1.1.*) this variable is also shared with the classifier layer. In that case, it is multiplied by d_model**-0.5 for use in the classifier.
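
To make the tied-vs-untied classifier concrete, a rough sketch (the names lm_logits, embedding, and lm_head are illustrative placeholders):

def lm_logits(hidden, embedding, lm_head=None):
    # hidden: (..., d_model); embedding: (vocab_size, d_model), shared by encoder and decoder
    if lm_head is None:  # t5.1.0.*: classifier tied to the embedding, which is rescaled by d_model**-0.5
        d_model = embedding.shape[-1]
        return hidden @ (embedding * d_model ** -0.5).T
    return hidden @ lm_head.T  # t5.1.1.*: separate, untied classifier weights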

Residuals, etc.

Before layer stack, apply dropout
For each layer apply
Y = X + dropout(F(rms_norm(X)))
F is the core layer function, i.e. feed-forward, attention, etc.
RMS norm is a simplified version of layer norm.
After layer stack, apply rms_norm, then dropout.
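
A minimal NumPy sketch of this residual pattern, assuming an epsilon of 1e-6 inside the norm (the exact value is an assumption here):

import numpy as np

def rms_norm(x, scale, eps=1e-6):
    # simplified layer norm: no mean subtraction and no bias, only rescaling by the root mean square
    return x / np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps) * scale

def layer_with_residual(x, layer_fn, scale, dropout_fn):
    # Y = X + dropout(F(rms_norm(X))), where F is the core attention or feed-forward function
    return x + dropout_fn(layer_fn(rms_norm(x, scale)))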

arglog commented on Oct 13, 2020

Hi @patrickvonplaten

Any updates on this? It would be exciting to be able to use the T5 v1.1 models in Hugging Face! Thanks!

ratthachat (Contributor) commented on Oct 18, 2020

Hi Patrick,
There are newly released T5.1.1 checkpoints which achieve SOTA on Natural Questions among non-retrieval models; I posted a discussion here. Maybe it's a bit more encouragement to integrate T5.1.1 into HF :D

ratthachat (Contributor) commented on Oct 23, 2020

@craffel Thanks for your clarification about T5.1.1. However, I could not find any source code for T5.1.1 - could you provide a link to the source?

craffel commented on Oct 23, 2020

acul3 (Contributor) commented on Oct 26, 2020

Multilingual T5 (mT5) has been released:
https://github.com/google-research/multilingual-t5
https://arxiv.org/abs/2010.11934

It looks like it uses the same implementation as T5 v1.1.
Really looking forward to being able to use it in the Hugging Face library.

19 remaining items

