Multi-Token Attention
Olga Golovneva, Tianlu Wang, Jason Weston, Sainbayar Sukhbaatar
Статья: https://arxiv.org/abs/2504.00927
Продолжаем разборы архитектур.
Как известно, веса внимания в классическом механизме внимания определяются одним вектором значений query и одним вектором значений key. Этот “single token attention” является своеобразным боттлнеком для отделения важных частей от всего остального. Новый подход Multi-Token Attention (MTA) позволяет устранить боттлнек и определять веса внимания на основе нескольких векторов query и keys одновременно
Напомним, что в стандартном внимании веса внимания определяются как softmax(QK/sqrt(d)). Для каждого токена есть вектор эмбеддинга, этот вектор проецируется в три отдельных вектора Q, K и V, и по скалярному произведению векторов Q и K различных токенов определяется их “похожесть” или “важность”. После нормализации на корень от размерности эмбеддинга и взятию софтмакса от результата получаются веса внимания A. Далее с этими весами взвешиваются и суммируются вектора V и генерятся новые эмбеддинги для каждого токена. На наличие множества голов, маски декодера и прочего мы в этом объяснении забиваем, если хотите лучше понять/вспомнить этот процесс, отсылаю к классике (https://jalammar.github.io/illustrated-transformer/).
Внутри и снаружи этого базового механизма внимания можно много чего модифицировать -- мы писали про температуру в софтмаксе (https://hottg.com/gonzo_ML/3013), про отказ от нормализации до или после слоёв внимания (https://hottg.com/gonzo_ML/3478), 100500 вариантов разреженного и прочего модифицированного внимания, которые даже перечислять долго (просто как пример -- Reformer, https://hottg.com/gonzo_ML/176, далее воспользуйтесь поиском по каналу). Текущая работа тоже где-то в этом пуле.
Допустим, мы хотим найти предложение, содержащее несколько элементов. Пусть для примера это будет предложение “Where did Alice see the rabbit?” и мы хотим найти одновременное упоминание Алисы и кролика, им соответствуют query вектора q_a и q_r. Стандартный механизм считает веса внимания описанным выше способом, мы можем “найти” места в контексте, содержащие эти слова, и нам надо бы проверить, что они находятся где-то в соседних позициях. Но стандартный механизм внимания не даёт этого сделать в пределах одного слоя (через увеличение глубины можно, но хотелось бы и без), поскольку никаких взаимодействий между отдельными attention maps в нём нет, и даже если мы используем отдельные головы внимания для обнаружения Алисы и кролика, то нет механизма для комбинирования этих весов внимания. Модификация внимания в MTA позволяет добавить это взаимодействие между attention maps для соседних позиций Q и K или между отдельными головами.
На уровне конкретных модификаций внутри стандартного механизма внимания появляются три новых блока:
1) key-query convolution: комбинирует несколько key и query внутри головы
2) head mixing convolution: шарит информацию между головами и усиливает важную
3) group normalization with depth scaling: улучшает поток градиентов
Key-query convolution перемешивает веса внимания от разных временных шагов и работает так: к логитам внимания перед софтсаксом (QK/sqrt(d)) применяется двумерная обучаемая свёртка по измерениям q и k, измерения для батча и голов внимания не трогаются. Каждая голова внимания учит свою свёртку. Внутри свёртки используется маска с занулением элементов, чтобы не залезать в будущее. Это был pre-softmax convolution, он будет использоваться по дефолту. Можно также сделать post-softmax convolution, тогда свёртка считается не поверх логитов, а уже после софтмакса. Это делает взаимодействия между весами внимания аддитивными, а не мультипликативными. Я кстати не до конца понял, почему они до софтмакса прям мультипликативные...
>>Click here to continue<<