FlashRNN - библиотека, которая реализует традиционные RNN, такие как LSTM, GRU и сети Элмана, а также новейшую архитектуру sLSTM в CUDA и Triton.
В отличие от распространенных современных моделей архитектуры Transformers, RNN обладают возможностями отслеживания состояния, оставаясь актуальными для решения задач моделирования временных рядов и логического мышления.
FlashRNN предлагает два варианта оптимизации: чередующийся и объединенный.
За автоматизацию настройки параметров FlashRNN отвечает библиотека ConstrINT
, которая решает задачи целочисленного удовлетворения ограничений, моделируя аппаратные ограничения в виде равенств, неравенств и ограничений делимости.
Эксперименты с FlashRNN показали существенное увеличение скорости работы: до 50 раз по сравнению с PyTorch. FlashRNN также позволяет использовать большие размеры скрытых состояний, чем нативная реализация Triton.
# Install FlashRNN
pip install flashrnn
# FlashRNN employs a functional structure, none of the parameters are tied to the `flashrnn` function:
import torch
from flashrnn import flashrnn
device = torch.device('cuda')
dtype = torch.bfloat16
B = 8 # batch size
T = 1024 # sequence length
N = 3 # number of heads
D = 256 # head dimension
G = 4 # number of gates / pre-activations for LSTM example
S = 2 # number of states
Wx = torch.randn([B, T, G, N, D], device=device, dtype=dtype, requires_grad=True)
R = torch.randn([G, N, D, D], device=device, dtype=dtype, requires_grad=True)
b = torch.randn([G, N, D], device=device, dtype=dtype, requires_grad=True)
states_initial = torch.randn([S, B, 1, N, D], device=device, dtype=dtype, requires_grad=True)
# available functions
# lstm, gru, elman, slstm
# available backend
# cuda_fused, cuda, triton and vanilla
states, last_states = flashrnn(Wx, R, b, states=states_initial, function="lstm", backend="cuda_fused")
# for LSTM the hidden h state is the first of [h, c]
# [S, B, T, N, D]
hidden_state = states[0]
@ai_machinelearning_big_data
#AI #ML #RNN #FlashRNN