| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from typing import List, Optional |
| |
|
| | import numpy as np |
| | import torch |
| | import transformer_engine as te |
| | from einops import rearrange |
| | from torch import nn |
| | from torch.utils.checkpoint import checkpoint |
| | from transformer_engine.pytorch.attention import DotProductAttention, apply_rotary_pos_emb |
| |
|
| | |
| |
|
| |
|
class FeedForward(nn.Module):
    """
    Transformer FFN with optional gating.

    The first linear layer projects to the hidden size and an activation is
    applied; when gated, the result is multiplied elementwise by a second
    d_model -> d_ff projection of the input before the output projection.

    Note: although a Dropout module is constructed from `dropout`, forward()
    asserts `dropout.p == 0.0` and never applies it — dropout is intentionally
    skipped in this implementation.

    Parameters:
        d_model (int): Dimensionality of input features.
        d_ff (int): Dimensionality of the hidden layer.
        dropout (float, optional): Dropout rate stored on the module; must be 0.0
            at forward time (dropout is skipped). Defaults to 0.1.
        activation (callable, optional): The activation function applied after the first linear layer.
            Defaults to nn.ReLU().
        is_gated (bool, optional): If set to True, incorporates gating mechanism to the feed-forward layer.
            Defaults to False.
        bias (bool, optional): If set to True, adds a bias to the linear layers. Defaults to False.

    Example:
        >>> ff = FeedForward(d_model=512, d_ff=2048, dropout=0.0)
        >>> x = torch.randn(64, 10, 512)  # Example input tensor
        >>> output = ff(x)
        >>> print(output.shape)  # Expected shape: (64, 10, 512)
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        dropout: float = 0.1,
        activation=nn.ReLU(),
        is_gated: bool = False,
        bias: bool = False,
    ) -> None:
        super().__init__()

        self.layer1 = nn.Linear(d_model, d_ff, bias=bias)
        self.layer2 = nn.Linear(d_ff, d_model, bias=bias)

        self.dropout = nn.Dropout(dropout)
        self.activation = activation
        self.is_gated = is_gated
        if is_gated:
            # Gate projection has no bias regardless of `bias` (matches original design).
            self.linear_gate = nn.Linear(d_model, d_ff, bias=False)

    def forward(self, x: torch.Tensor):
        g = self.activation(self.layer1(x))
        if self.is_gated:
            x = g * self.linear_gate(x)
        else:
            x = g
        # Dropout is deliberately unsupported; fail loudly if a nonzero rate was configured.
        assert self.dropout.p == 0.0, "we skip dropout"
        return self.layer2(x)
| |
|
| |
|
class GPT2FeedForward(FeedForward):
    """GPT-2-style feed-forward block: non-gated FFN with GELU activation.

    The activation + second projection are wrapped in activation checkpointing
    so their intermediates are recomputed during backward instead of stored.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, bias: bool = False):
        super().__init__(
            d_model=d_model,
            d_ff=d_ff,
            dropout=dropout,
            activation=nn.GELU(),
            is_gated=False,
            bias=bias,
        )

    def forward(self, x: torch.Tensor):
        # Dropout is deliberately unsupported (same contract as FeedForward).
        assert self.dropout.p == 0.0, "we skip dropout"

        hidden = self.layer1(x)

        def _act_then_project(h):
            return self.layer2(self.activation(h))

        # Trade compute for memory: recompute GELU + layer2 on backward.
        return checkpoint(_act_then_project, hidden, use_reentrant=False)
| |
|
| |
|
| | |
| |
|
| |
|
def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor:
    """
    Rescale `x` so that the mean squared value of its elements (over `dim`) is one.

    Args:
        x (torch.Tensor): The input tensor to normalize.
        dim (list, optional): Dimensions to reduce over. If None, normalizes over
            all dimensions except the first (batch) dimension.
        eps (float, optional): Small constant added to the divisor for numerical
            stability.

    Returns:
        torch.Tensor: The normalized tensor, in `x`'s dtype.
    """
    reduce_dims = list(range(1, x.ndim)) if dim is None else dim
    # Compute the norm in float32 for stability regardless of x's dtype.
    norm = torch.linalg.vector_norm(x, dim=reduce_dims, keepdim=True, dtype=torch.float32)
    # sqrt(#norms / #elements) converts the vector norm into an RMS value,
    # so dividing by it drives the average squared element toward one.
    denom = eps + np.sqrt(norm.numel() / x.numel()) * norm
    return x / denom.to(x.dtype)
| |
|
| |
|
def get_normalization(name: str, channels: int):
    """Return a normalization module selected by its one-letter code.

    "I" -> nn.Identity (no normalization); "R" -> Transformer Engine RMSNorm
    over `channels`. Raises ValueError for any other code.
    """
    if name == "I":
        return nn.Identity()
    if name == "R":
        return te.pytorch.RMSNorm(channels, eps=1e-6)
    raise ValueError(f"Normalization {name} not found")
| |
|
| |
|
class BaseAttentionOp(nn.Module):
    """Common base class for pluggable attention operations (see `Attention.attn_op`)."""

    def __init__(self):
        super().__init__()
| |
|
| |
|
class Attention(nn.Module):
    """
    Generalized attention impl.

    Allowing for both self-attention and cross-attention configurations depending on whether a `context_dim` is provided.
    If `context_dim` is None, self-attention is assumed.

    Parameters:
        query_dim (int): Dimension of each query vector.
        context_dim (int, optional): Dimension of each context vector. If None, self-attention is assumed.
        heads (int, optional): Number of attention heads. Defaults to 8.
        dim_head (int, optional): Dimension of each head. Defaults to 64.
        dropout (float, optional): Dropout rate applied to the output of the attention block. Defaults to 0.0.
        attn_op (BaseAttentionOp, optional): Custom attention operation to be used instead of the default.
        qkv_bias (bool, optional): If True, adds a learnable bias to query, key, and value projections. Defaults to False.
        out_bias (bool, optional): If True, adds a learnable bias to the output projection. Defaults to False.
        qkv_norm (str, optional): A string representing normalization strategies for query, key, and value projections.
            Defaults to "SSI".
        qkv_norm_mode (str, optional): A string representing normalization mode for query, key, and value projections.
            Defaults to 'per_head'. Only support 'per_head'.
        backend (str, optional): Attention backend, "transformer_engine" or "torch". Defaults to "transformer_engine".
        qkv_format (str, optional): Tensor layout string passed to the attention op (e.g. "bshd"). Defaults to "bshd".

    Examples:
        >>> attn = Attention(query_dim=128, context_dim=256, heads=4, dim_head=32, dropout=0.1)
        >>> query = torch.randn(10, 128)  # Batch size of 10
        >>> context = torch.randn(10, 256)  # Batch size of 10
        >>> output = attn(query, context)  # Perform the attention operation

    Note:
        https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    """

    def __init__(
        self,
        query_dim: int,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
        attn_op: Optional[BaseAttentionOp] = None,
        qkv_bias: bool = False,
        out_bias: bool = False,
        qkv_norm: str = "SSI",
        qkv_norm_mode: str = "per_head",
        backend: str = "transformer_engine",
        qkv_format: str = "bshd",
    ) -> None:
        super().__init__()

        # Cross-attention when a separate context dim is given; self-attention otherwise.
        self.is_selfattn = context_dim is None

        inner_dim = dim_head * heads
        context_dim = query_dim if context_dim is None else context_dim

        self.heads = heads
        self.dim_head = dim_head
        self.qkv_norm_mode = qkv_norm_mode
        self.qkv_format = qkv_format

        if self.qkv_norm_mode == "per_head":
            # Q/K/V are normalized per attention head, so the norm operates on dim_head channels.
            norm_dim = dim_head
        else:
            raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")

        self.backend = backend

        # Each projection is Sequential(linear, norm). cal_qkv invokes the two
        # stages separately so normalization can run after the per-head reshape
        # while keeping checkpoint compatibility with the pre-07/24/2024 layout.
        self.to_q = nn.Sequential(
            nn.Linear(query_dim, inner_dim, bias=qkv_bias),
            get_normalization(qkv_norm[0], norm_dim),
        )
        self.to_k = nn.Sequential(
            nn.Linear(context_dim, inner_dim, bias=qkv_bias),
            get_normalization(qkv_norm[1], norm_dim),
        )
        self.to_v = nn.Sequential(
            nn.Linear(context_dim, inner_dim, bias=qkv_bias),
            get_normalization(qkv_norm[2], norm_dim),
        )

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim, bias=out_bias),
            nn.Dropout(dropout),
        )

        if attn_op:  # use the caller-supplied attention op as-is
            self.attn_op = attn_op
        elif self.backend == "transformer_engine":
            sequence_parallel = False
            self.attn_op: BaseAttentionOp = DotProductAttention(
                self.heads,
                self.dim_head,
                num_gqa_groups=self.heads,
                attention_dropout=0,
                qkv_format=qkv_format,
                attn_mask_type="no_mask",
                tp_size=1,
                tp_group=None,
                sequence_parallel=sequence_parallel,
            )
        else:
            raise ValueError(f"Backend {backend} not found")

    def cal_qkv(
        self, x, context=None, mask=None, rope_emb=None, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute per-head query, key, and value tensors for `x` (and `context` for cross-attention).

        self.to_q, self.to_k, self.to_v are nn.Sequential with projection + normalization layers.
        Before 07/24/2024, these modules normalize across all heads.
        After 07/24/2024, to support tensor parallelism and follow the common practice in the community,
        we support to normalize per head.
        To keep the checkpoint compatibility with the previous code,
        we keep the nn.Sequential but call the projection and the normalization layers separately.
        We use a flag `self.qkv_norm_mode` to control the normalization behavior.
        The default value of `self.qkv_norm_mode` is "per_head", which means we normalize per head.

        Note: `mask` is accepted for interface uniformity but is not used here.
        """
        del kwargs

        if self.qkv_norm_mode == "per_head":
            # Stage 1: linear projections only (index [0] of each Sequential).
            q = self.to_q[0](x)
            context = x if context is None else context
            k = self.to_k[0](context)
            v = self.to_v[0](context)
            # Split the joint projection into heads before normalizing.
            q, k, v = map(
                lambda t: rearrange(t, "b ... (n c) -> b ... n c", n=self.heads, c=self.dim_head),
                (q, k, v),
            )
        else:
            raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")

        # Stage 2: per-head normalization (index [1] of each Sequential).
        q = self.to_q[1](q)
        k = self.to_k[1](k)
        v = self.to_v[1](v)
        # Rotary embeddings only make sense for self-attention (q and k share positions).
        if self.is_selfattn and rope_emb is not None:
            q = apply_rotary_pos_emb(q, rope_emb, tensor_format=self.qkv_format, fused=True)
            k = apply_rotary_pos_emb(k, rope_emb, tensor_format=self.qkv_format, fused=True)
        return q, k, v

    def cal_attn(self, q, k, v, mask=None):
        """Run the attention op on per-head q/k/v and apply the output projection.

        Note: `mask` is only forwarded on the "torch" backend; the TE path uses
        attn_mask_type="no_mask".
        """
        if self.backend == "transformer_engine":
            seq_dim = self.qkv_format.index("s")
            assert (
                q.shape[seq_dim] > 1 and k.shape[seq_dim] > 1
            ), "Seqlen must be larger than 1 for TE Attention starting with 1.8 TE version."
            out = self.attn_op(q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None)
            # TE returns heads already merged into the channel dim.
            return self.to_out(out)
        elif self.backend == "torch":
            out = self.attn_op(q, k, v, mask=mask)
            return self.to_out(rearrange(out, " b ... n c -> b ... (n c)"))
        else:
            raise ValueError(f"Backend {self.backend} not found")

    def forward(
        self,
        x,
        context=None,
        mask=None,
        rope_emb=None,
        **kwargs,
    ):
        """
        Args:
            x (Tensor): The query tensor of shape [B, Mq, K]
            context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
        """
        q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
        return self.cal_attn(q, k, v, mask)
| |
|