Pretrained Transformers as Universal Computation Engines
Kevin Lu, Aditya Grover, Pieter Abbeel, Igor Mordatch
Статья: https://arxiv.org/abs/2103.05247
Код: https://github.com/kzl/universal-computation
Прикольная работа весны этого года из Беркли/Google/FB, про которую нелишне ещё раз поговорить.
Общая идея в том, что обычные предобученные трансформеры для NLP (GPT-2) на самом деле довольно легко генерализуют на другие модальности (комп.зрение, вычислительные операции или предсказание структуры белка) без файнтюнинга self-attention или полносвязных слоёв. Получается, что предобучение на языковых задачах выучивает какие-то более глубокие и универсальные структуры вычислительной реальности, помогающие в довольно далёких задачах иной природы.
Для работы берут предобученный трансформер (в данном случае GPT-2), замораживают все его веса кроме линейных входных и выходных слоёв, а также позиционных эмбеддингов и параметров LayerNorm. Эти незамороженные веса потом будут файнтюниться, а такой замороженный трансформер называется Frozen Pretrained Transformer (FPT).
И этот FPT обучается решать задачи с качеством, аналогичным обучению полного трансформера или LSTM, несмотря на то, что файнтюнилась только 0.1% всех параметров и вообще не трогались параметры self-attention. Это интересно.
Задачи взяты следующие:
1. Bit memory: показывают пять битовых строк длины 1000, затем модели показывают одну из этих пяти с замаскированными с вероятностью 0.5 битами, и модель должна восстановить оригинальную строку.
2. Bit XOR: модели дают две битовых строки длины пять (по одному биту за раз), и она должна выдать поэлементный XOR.
3. ListOps: на вход дают последовательность операций со списками (похоже на LISP) и на выходе нужно получить результат этих операций.
4. MNIST: стандартный MNIST, где надо классифицировать картинку с цифрой, но на вход поступают 64 токена с патчами 4x4.
5. CIFAR-10: аналогично
6. CIFAR-10 LRA: модификация с Long-Range Arena, всё переведено в grayscale и flattened и подаётся по токену размера 1 за раз. По сути sequential CIFAR.
7. Remote Homology Detection: по последовательности аминокислот белка предсказать метку (всего 1195 классов). При этом никакого предобучения на базах последовательностей типа Pfam.
В случае GPT-2 Base n_dim=768, n_layers=12, входные и выходные размерности d_in, d_out зависят от задачи, максимальная длина последовательности l.
Для кейса с CIFAR-10 (d_out=10, d_in=16, l=64) получаются обучаемые параметры:
- выходной слой (768*10 = 7680 параметров)
- входной слой (16*768 = 13056)
- параметры LayerNorm (4*768*12 = 36684)
- позиционные эмбеддинги (64*768 = 49512)
Всё скейлится линейно по параметрам датасета. Для базовой модели GPT-2 со 124M параметров это всего лишь 0.086%, а для GPT-2 XL вообще 0.029%.
Поскольку веса внимания заморожены и не обучаются, а входы и выходы не соединяют разные токены, вся коммуникация между токенами заморожена.
Сравниваются с 1) трансформером, который обучается с нуля на задачу, с 2) LSTM также обучающемся сразу на задачу.
На перечисленном пуле задач FPT выглядит не хуже, а местами лучше альтернатив. То есть перенос на другие модальности работает.
В некоторых случаях с маленькими датасетами обучить 12-слойный трансформер с нуля не получается, поэтому берут трансформер поменьше, например трёхслойный для CIFAR. С другой стороны для FPT увеличение модели только улучшает результат.
Сравниваются с разным другим предобучением (случайная инициализация без предобучения, предобучение на bit memory задаче, предобученный ViT -- с ним есть тонкость, потому что он обучался как энкодер, а GPT-2 -- декодер).
Здесь, во-первых, интересно, что случайная инициализация даёт весьма хороший результат (частый кейс и сильный бэйзлайн, рандомные фичи рулят). Во-вторых, предобучение на языке явно рулит.
>>Click here to continue<<