TG Telegram Group & Channel
gonzo-обзоры ML статей | United States America (US)
Create: Update:

Head mixing convolution позволяет перемешивать внимание между разными головами в пределах одного временного шага. Все головы внимания разбиваются на группы заданного размера и перемешивание происходит внутри группы (его также можно рассматривать и как небольшой полносвязный слой). Это делается после софтмакса, но при желании можно делать и до, на логитах, тоже получается pre и post (по дефолту).

Итого, возможны четыре варианта блока MTA с разными комбинациями pre/post свёрток. Тут есть простор для оптимизации, так если оба варианта pre или post, то можно объединить это в одну трёхмерную свёртку.

Group normalization with depth scaling использует GroupNorm и независимый скейлинг для каждой головы по рецепту от Differential Transformer (https://arxiv.org/abs/2410.05258, может кстати тоже его разобрать?).

Эксперименты начинают с игрушечной задачи: модели дают последовательность блоков, каждый из N случайных букв. Далее следует L<N букв вопроса. Задача -- найти блок, содержащий все буквы из вопроса в любом порядке. Модель должна вывести все буквы целевого блока, или только его первый либо последний токен (три разные варианта задачи). Для стандартного трансформера задача сложная, так как требует L кусочков информации для определения целевого блока, и их надо закодировать в один query вектор. С MTA должно быть проще, так как он может найти позицию каждой буквы, а потом свёрткой увеличить вес внимания, если все L букв найдены вместе.

Проверили на N=5 и 8, L=2. Пример задачи (надо найти блок с pb):

hjnvt.qfjgt.whftb.bjtpq. ...(many blocks)... .pxjvf.ulhik.qoiax#pb


Обучали на 1M таких блоков, тестировали на отложенных 1K. Трансформер 4 слоя, 2 головы, размерность 256.

У MTA ошибка почти везде ноль или рядом, у обычного трансформера почти везде двузначные числа процентов. Размеры свёрток были c_q=2 (как L), c_k=2N-1, чтобы можно было покрыть весь блок. Свёртка для голов не использовалась.

Следующий эксперимент с LLM. Предобучили 880M модели с архитектурой LLaMa и сравнили обычный трансформер, Differential Transformer и MTA. Обучали на SlimPajama на 105B токенов. В MTA key-query convolution использовали в каждом четвёртом слое, а head convolution в каждом. Свёртки c_q=6, c_k=11, размер группы 2.

По перплексии MTA лучше (GroupNorm при этом важен). На наборе бенчмарков тоже обычно бьёт остальных, но не везде и разница часто в последней цифре (и непонятно какой доверительный интервал -- обучали дважды). В среднем лучше.

Проверили на отдельном пуле long-range dependency задач: LAMBADA, NeedleIn-A-Haystack и BabiLong. На ламбаде однозначно бьёт, на multi-needle (2,4,6) retrieval точность MTA обычно выше, причём без GroupNorm часто лучше. На BabiLong и QA1-5 у MTA тоже всё хорошо.

Приложили сколько-то визуализаций свёрточных ядер, заметное число близко к identity, но есть и более хитрые. Например, один с диагональной структурой, удобен чтобы находить точное совпадение с паттерном. Или есть аналог edge detection, усиливающий первый или последний из последовательных ключей с высоким вниманием. В свёртках по головам частый паттерн это контраст, вычитание одной головы из другой.

Абляции показали, что даже пары MTA слоёв достаточно для превосходства над бейзлайнами. Все предложенные компоненты что-то улучшают по перплексии.

В целом забавно. Кажется, свёртки по q/k это ещё не предел. Для каких-то задач и языков не удивлюсь, если более забористые и менее локальные интеракции рулят. Главное чтоб параметров много не добавляли. Здесь в примере с LLM разница была на уровне 0.001% (+10K параметров на фоне 880M).

По памяти и FLOPS текущая неоптимизированная имплементация сильно проигрывает у использующих обычное scaled dot product attention: памяти раза в три больше надо, флопсов меньше раз в пять. Но это скорее проблема отсутствия оптимизированного ядра для CUDA. Интересно, компиляция через XLA что бы дала.

Head mixing convolution позволяет перемешивать внимание между разными головами в пределах одного временного шага. Все головы внимания разбиваются на группы заданного размера и перемешивание происходит внутри группы (его также можно рассматривать и как небольшой полносвязный слой). Это делается после софтмакса, но при желании можно делать и до, на логитах, тоже получается pre и post (по дефолту).

Итого, возможны четыре варианта блока MTA с разными комбинациями pre/post свёрток. Тут есть простор для оптимизации, так если оба варианта pre или post, то можно объединить это в одну трёхмерную свёртку.

Group normalization with depth scaling использует GroupNorm и независимый скейлинг для каждой головы по рецепту от Differential Transformer (https://arxiv.org/abs/2410.05258, может кстати тоже его разобрать?).

Эксперименты начинают с игрушечной задачи: модели дают последовательность блоков, каждый из N случайных букв. Далее следует L<N букв вопроса. Задача -- найти блок, содержащий все буквы из вопроса в любом порядке. Модель должна вывести все буквы целевого блока, или только его первый либо последний токен (три разные варианта задачи). Для стандартного трансформера задача сложная, так как требует L кусочков информации для определения целевого блока, и их надо закодировать в один query вектор. С MTA должно быть проще, так как он может найти позицию каждой буквы, а потом свёрткой увеличить вес внимания, если все L букв найдены вместе.

Проверили на N=5 и 8, L=2. Пример задачи (надо найти блок с pb):
hjnvt.qfjgt.whftb.bjtpq. ...(many blocks)... .pxjvf.ulhik.qoiax#pb


Обучали на 1M таких блоков, тестировали на отложенных 1K. Трансформер 4 слоя, 2 головы, размерность 256.

У MTA ошибка почти везде ноль или рядом, у обычного трансформера почти везде двузначные числа процентов. Размеры свёрток были c_q=2 (как L), c_k=2N-1, чтобы можно было покрыть весь блок. Свёртка для голов не использовалась.

Следующий эксперимент с LLM. Предобучили 880M модели с архитектурой LLaMa и сравнили обычный трансформер, Differential Transformer и MTA. Обучали на SlimPajama на 105B токенов. В MTA key-query convolution использовали в каждом четвёртом слое, а head convolution в каждом. Свёртки c_q=6, c_k=11, размер группы 2.

По перплексии MTA лучше (GroupNorm при этом важен). На наборе бенчмарков тоже обычно бьёт остальных, но не везде и разница часто в последней цифре (и непонятно какой доверительный интервал -- обучали дважды). В среднем лучше.

Проверили на отдельном пуле long-range dependency задач: LAMBADA, NeedleIn-A-Haystack и BabiLong. На ламбаде однозначно бьёт, на multi-needle (2,4,6) retrieval точность MTA обычно выше, причём без GroupNorm часто лучше. На BabiLong и QA1-5 у MTA тоже всё хорошо.

Приложили сколько-то визуализаций свёрточных ядер, заметное число близко к identity, но есть и более хитрые. Например, один с диагональной структурой, удобен чтобы находить точное совпадение с паттерном. Или есть аналог edge detection, усиливающий первый или последний из последовательных ключей с высоким вниманием. В свёртках по головам частый паттерн это контраст, вычитание одной головы из другой.

Абляции показали, что даже пары MTA слоёв достаточно для превосходства над бейзлайнами. Все предложенные компоненты что-то улучшают по перплексии.

В целом забавно. Кажется, свёртки по q/k это ещё не предел. Для каких-то задач и языков не удивлюсь, если более забористые и менее локальные интеракции рулят. Главное чтоб параметров много не добавляли. Здесь в примере с LLM разница была на уровне 0.001% (+10K параметров на фоне 880M).

По памяти и FLOPS текущая неоптимизированная имплементация сильно проигрывает у использующих обычное scaled dot product attention: памяти раза в три больше надо, флопсов меньше раз в пять. Но это скорее проблема отсутствия оптимизированного ядра для CUDA. Интересно, компиляция через XLA что бы дала.
👍243👏1👌1


>>Click here to continue<<

gonzo-обзоры ML статей






Share with your best friend
VIEW MORE

United States America Popular Telegram Group (US)