Reformer, the Efficient Transformer, in Pytorch
This is a PyTorch implementation of Reformer (https://openreview.net/pdf?id=rkgNKkHtvB).
It includes LSH attention, reversible residual layers, and chunking. It has been validated on a toy auto-regressive task.
Install
```bash
pip install reformer_pytorch
```
Usage
The full Reformer
```python
# should fit in ~ 5gb - 8k tokens
import torch
from reformer_pytorch import Reformer

model = Reformer(
    emb = 512,
    depth = 12,
    max_seq_len = 8192,
    num_tokens = 20000,
    heads = 8,
    lsh_dropout = 0.1,
    causal = True,          # auto-regressive or not
    bucket_size = 64,       # average size of qk per bucket, 64 was recommended in paper
    n_hashes = 8,           # should keep at 8 per paper
    ff_chunks = 200,        # number of chunks for feedforward layer
    weight_tie = False,     # tie parameters of each layer for no memory per additional depth
    attn_chunks = 8,        # process lsh attention in chunks, only way for memory to fit when scaling to 16k tokens
    use_full_attn = False   # use full self attention, for comparison
).cuda()

x = torch.randint(0, 20000, (1, 8192)).long().cuda()
y = model(x)
```
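As a rough illustration only (not part of the repository's examples), a minimal auto-regressive training step could look like the sketch below. It assumes the model above maps token ids to per-position logits over the `num_tokens` vocabulary; the optimizer and learning rate are illustrative choices.

```python
import torch
import torch.nn.functional as F

# hypothetical training step; assumes `model` (from the example above) maps
# token ids of shape (batch, seq_len) to logits of shape (batch, seq_len, num_tokens)
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)

x = torch.randint(0, 20000, (1, 8192)).long().cuda()
logits = model(x)

# next-token prediction: position i predicts token i + 1
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, logits.size(-1)),
    x[:, 1:].reshape(-1)
)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```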
Self Attention with LSH
```python
import torch
from reformer_pytorch import LSHSelfAttention

attn = LSHSelfAttention(
    emb = 128,
    heads = 8,
    bucket_size = 64,
    n_hashes = 8,
    causal = False
)

x = torch.randn(10, 1024, 128)
y = attn(x)
```
LSH (locality sensitive hashing) Attention
```python
import torch
from reformer_pytorch import LSHAttention

attn = LSHAttention(
    bucket_size = 64,
    n_hashes = 16,
    causal = True
)

qk = torch.randn(10, 1024, 128)
v = torch.randn(10, 1024, 128)

attn_out, buckets = attn(qk, v)
# buckets will contain the bucket number (post-argmax) of each token of each batch
```
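For intuition about those bucket numbers, the sketch below reproduces the paper's angular LSH scheme (random rotations followed by an argmax over the concatenation of the rotated vectors and their negations). It is a schematic of the idea, not the library's internal code; `lsh_buckets` and `n_buckets` are illustrative names.

```python
import torch

# schematic of the angular LSH from the paper, not the library's internal implementation
def lsh_buckets(qk, n_buckets):
    # qk: (batch, seq_len, dim); vectors pointing in similar directions tend to share a bucket
    rotations = torch.randn(qk.size(-1), n_buckets // 2)     # random projection
    rotated = torch.einsum('bsd,dr->bsr', qk, rotations)     # (batch, seq_len, n_buckets // 2)
    # concatenate [xR, -xR] and take the argmax -> bucket id in [0, n_buckets)
    return torch.cat([rotated, -rotated], dim = -1).argmax(dim = -1)

qk = torch.randn(10, 1024, 128)
buckets = lsh_buckets(qk, n_buckets = 32)   # (10, 1024) bucket ids
```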
Todo
- Make it so the Reformer can be used as a decoder, where queries only attend to externally supplied keys / values
- Recurrence like Transformer XL
- All-attention learned memory key values