Desmitificando la lógica de atención de los transformadores: desentrañando la intuición y la implementación

Finalmente calculamos las puntuaciones y la salida de atención.

Cálculo de salida de atención
import torch
import torch.nn as nn
from typing import List

def get_input_embeddings(words: List[str], embeddings_dim: int):
# we are creating random vector of embeddings_dim size for each words
# normally we train a tokenizer to get the embeddings.
# check the blog on tokenizer to learn about this part
embeddings = [torch.randn(embeddings_dim) for word in words]
return embeddings

text = "I should sleep now"
words = text.split(" ")
len(words) # 4

embeddings_dim = 512 # 512 dim because the original paper uses it. we can use other dim also
embeddings = get_input_embeddings(words, embeddings_dim=embeddings_dim)
embeddings[0].shape # torch.Size([512])

# initialize the query, key and value metrices
query_matrix = nn.Linear(embeddings_dim, embeddings_dim)
key_matrix = nn.Linear(embeddings_dim, embeddings_dim)
value_matrix = nn.Linear(embeddings_dim, embeddings_dim)
query_matrix.weight.shape, key_matrix.weight.shape, value_matrix.weight.shape # torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512])

# query, key and value vectors computation for each words embeddings
query_vectors = torch.stack([query_matrix(embedding) for embedding in embeddings])
key_vectors = torch.stack([key_matrix(embedding) for embedding in embeddings])
value_vectors = torch.stack([value_matrix(embedding) for embedding in embeddings])
query_vectors.shape, key_vectors.shape, value_vectors.shape # torch.Size([4, 512]), torch.Size([4, 512]), torch.Size([4, 512])

# compute the score
scores = torch.matmul(query_vectors, key_vectors.transpose(-2, -1)) / torch.sqrt(torch.tensor(embeddings_dim, dtype=torch.float32))
scores.shape # torch.Size([4, 4])

# compute the attention weights for each of the words with the other words
softmax = nn.Softmax(dim=-1)
attention_weights = softmax(scores)
attention_weights.shape # torch.Size([4, 4])

# attention output
output = torch.matmul(attention_weights, value_vectors)
output.shape # torch.Size([4, 512])

Porque nunca se puede tener demasiada atención. 😛 – yo

La atención que he mencionado anteriormente es la atención de una sola cabeza. En la atención de múltiples cabezas tenemos más de una cabeza, 8 cabezas en el papel original.

Tanto para el cálculo de la atención de múltiples cabezas como de una sola cabeza son los mismos hasta consulta (q0-q3), clave (k0-k3), valor (v0-v3) vector intermedio.

Atención de múltiples cabezas

Después de eso, dividimos el vector de consulta en partes iguales en el número de cabezas que tenemos. En la imagen de arriba tenemos 8 cabezas y los vectores de consulta, clave y valor tienen una dimensión de 512. Entonces creamos 8 vectores de 64 dimensiones.

Llevamos los primeros 64 vectores dim a la primera cabeza, el segundo conjunto de vectores a la segunda cabeza y así sucesivamente. En la imagen de arriba solo he mostrado el cálculo para la primera cabeza.

Después de tener las miniconsultas, claves y valores (los que tienen 64 dim) en un encabezado, calculamos la lógica restante igual que la atención de un solo encabezado. Finalmente, tenemos 4 vectores de 64 dimensiones de cada una de las cabezas.

Combinamos las primeras 64 salidas de cada cabezal para obtener el vector de salida final de 512 atenuaciones. Lo mismo para los resultados de los 3 vectores restantes.

Combina los resultados de las cabezas.

Los transformadores con cabezales múltiples tienen una mayor capacidad para representar relaciones complejas en los datos. Cada cabeza es capaz de aprender diferentes patrones. Múltiples cabezales también brindan la capacidad de atender a diferentes subespacios (64 vectores dim del vector original 512 dim) de la representación de entrada simultáneamente.

num_heads = 8
# batch dim is 1 since we are processing one text.
batch_size = 1

text = "I should sleep now"
words = text.split(" ")
len(words) # 4

embeddings_dim = 512
embeddings = get_input_embeddings(words, embeddings_dim=embeddings_dim)
embeddings[0].shape # torch.Size([512])

# initialize the query, key and value metrices
query_matrix = nn.Linear(embeddings_dim, embeddings_dim)
key_matrix = nn.Linear(embeddings_dim, embeddings_dim)
value_matrix = nn.Linear(embeddings_dim, embeddings_dim)
query_matrix.weight.shape, key_matrix.weight.shape, value_matrix.weight.shape # torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512])

# query, key and value vectors computation for each words embeddings
query_vectors = torch.stack([query_matrix(embedding) for embedding in embeddings])
key_vectors = torch.stack([key_matrix(embedding) for embedding in embeddings])
value_vectors = torch.stack([value_matrix(embedding) for embedding in embeddings])
query_vectors.shape, key_vectors.shape, value_vectors.shape # torch.Size([4, 512]), torch.Size([4, 512]), torch.Size([4, 512])

# (batch_size, num_heads, seq_len, embeddings_dim)
query_vectors_view = query_vectors.view(batch_size, -1, num_heads, embeddings_dim//num_heads).transpose(1, 2)
key_vectors_view = key_vectors.view(batch_size, -1, num_heads, embeddings_dim//num_heads).transpose(1, 2)
value_vectors_view = value_vectors.view(batch_size, -1, num_heads, embeddings_dim//num_heads).transpose(1, 2)
query_vectors_view.shape, key_vectors_view.shape, value_vectors_view.shape
# torch.Size([1, 8, 4, 64]),
# torch.Size([1, 8, 4, 64]),
# torch.Size([1, 8, 4, 64])

# We are splitting the each vectors into 8 heads.
# Assuming we have one text (batch size of 1), So we split
# the embedding vectors also into 8 parts. Each head will
# take these parts. If we do this one head at a time.
head1_query_vector = query_vectors_view[0, 0, ...]
head1_key_vector = key_vectors_view[0, 0, ...]
head1_value_vector = value_vectors_view[0, 0, ...]
head1_query_vector.shape, head1_key_vector.shape, head1_value_vector.shape

# The above vectors are of same size as before only the feature dim is changed from 512 to 64
# compute the score
scores_head1 = torch.matmul(head1_query_vector, head1_key_vector.permute(1, 0)) / torch.sqrt(torch.tensor(embeddings_dim//num_heads, dtype=torch.float32))
scores_head1.shape # torch.Size([4, 4])

# compute the attention weights for each of the words with the other words
softmax = nn.Softmax(dim=-1)
attention_weights_head1 = softmax(scores_head1)
attention_weights_head1.shape # torch.Size([4, 4])

output_head1 = torch.matmul(attention_weights_head1, head1_value_vector)
output_head1.shape # torch.Size([4, 512])

# we can compute the output for all the heads
outputs = []
for head_idx in range(num_heads):
head_idx_query_vector = query_vectors_view[0, head_idx, ...]
head_idx_key_vector = key_vectors_view[0, head_idx, ...]
head_idx_value_vector = value_vectors_view[0, head_idx, ...]
scores_head_idx = torch.matmul(head_idx_query_vector, head_idx_key_vector.permute(1, 0)) / torch.sqrt(torch.tensor(embeddings_dim//num_heads, dtype=torch.float32))

softmax = nn.Softmax(dim=-1)
attention_weights_idx = softmax(scores_head_idx)
output = torch.matmul(attention_weights_idx, head_idx_value_vector)
outputs.append(output)

[out.shape for out in outputs]
# [torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64])]

# stack the result from each heads for the corresponding words
word0_outputs = torch.cat([out[0] for out in outputs])
word0_outputs.shape

# lets do it for all the words
attn_outputs = []
for i in range(len(words)):
attn_output = torch.cat([out[i] for out in outputs])
attn_outputs.append(attn_output)
[attn_output.shape for attn_output in attn_outputs] # [torch.Size([512]), torch.Size([512]), torch.Size([512]), torch.Size([512])]

# Now lets do it in vectorize way.
# We can not permute the last two dimension of the key vector.
key_vectors_view.permute(0, 1, 3, 2).shape # torch.Size([1, 8, 64, 4])

# Transpose the key vector on the last dim
score = torch.matmul(query_vectors_view, key_vectors_view.permute(0, 1, 3, 2)) # Q*k
score = torch.softmax(score, dim=-1)

# reshape the results
attention_results = torch.matmul(score, value_vectors_view)
attention_results.shape # [1, 8, 4, 64]

# merge the results
attention_results = attention_results.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, embeddings_dim)
attention_results.shape # torch.Size([1, 4, 512])

El código implementado en este blog está agregado en este cuaderno. Siéntete libre de editar y probar cosas.

Espero que hayas disfrutado este blog. 🤗 Si está interesado en leer sobre el transformador de visión, consulte este blog: