
Cosmos #10660

Draft · wants to merge 1 commit into main

Conversation

a-r-r-o-w (Member)

The cosmos is within us. We are made of star-stuff. We are a way for the universe to know itself.

WIP.

test attention
from typing import Optional
from einops import rearrange

import torch
import torch.nn as nn


class RMSNorm(torch.nn.Module):
    def __init__(
        self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None
    ):
        super().__init__()
        self.eps = eps
        self.learnable_scale = elementwise_affine
        if self.learnable_scale:
            self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
        else:
            self.register_parameter("weight", None)

    def forward(self, x):
        r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
        if self.weight is None:
            return r
        else:
            return r * self.weight.to(dtype=x.dtype, device=x.device)


def get_normalization(name: str, channels: int):
    if name == "I":
        return nn.Identity()
    elif name == "R":
    #     return te.pytorch.RMSNorm(channels, eps=1e-6)
        return RMSNorm(channels, eps=1e-6)
    else:
        raise ValueError(f"Normalization {name} not found")


class Attention(nn.Module):
    def __init__(
        self,
        query_dim: int,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
        qkv_bias: bool = False,
        out_bias: bool = False,
        qkv_norm: str = "SSI",
        qkv_norm_mode: str = "per_head",
        backend: str = "transformer_engine",
        qkv_format: str = "bshd",
    ) -> None:
        super().__init__()

        self.is_selfattn = context_dim is None  # self attention

        inner_dim = dim_head * heads
        context_dim = query_dim if context_dim is None else context_dim

        self.heads = heads
        self.dim_head = dim_head
        self.qkv_norm_mode = qkv_norm_mode
        self.qkv_format = qkv_format

        if self.qkv_norm_mode == "per_head":
            norm_dim = dim_head
        else:
            raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")

        self.backend = backend

        self.to_q = nn.Sequential(
            nn.Linear(query_dim, inner_dim, bias=qkv_bias),
            get_normalization(qkv_norm[0], norm_dim),
        )
        self.to_k = nn.Sequential(
            nn.Linear(context_dim, inner_dim, bias=qkv_bias),
            get_normalization(qkv_norm[1], norm_dim),
        )
        self.to_v = nn.Sequential(
            nn.Linear(context_dim, inner_dim, bias=qkv_bias),
            get_normalization(qkv_norm[2], norm_dim),
        )

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim, bias=out_bias),
            nn.Dropout(dropout),
        )

    def cal_qkv(
        self, x, context=None, mask=None, rope_emb=None, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        q = self.to_q[0](x)
        context = x if context is None else context
        k = self.to_k[0](context)
        v = self.to_v[0](context)
        q, k, v = map(
            # lambda t: rearrange(t, "b ... (n c) -> b ... n c", n=self.heads, c=self.dim_head),
            lambda t: rearrange(t, "s b (n c) -> b n s c", n=self.heads, c=self.dim_head),
            (q, k, v),
        )

        q = self.to_q[1](q)
        k = self.to_k[1](k)
        v = self.to_v[1](v)
        if self.is_selfattn and rope_emb is not None:  # only apply to self-attention!
            # q = apply_rotary_pos_emb(q, rope_emb, tensor_format=self.qkv_format, fused=True)
            # k = apply_rotary_pos_emb(k, rope_emb, tensor_format=self.qkv_format, fused=True)
            # apply_rotary_pos_emb inlined
            q_shape = q.shape
            q = q.reshape(*q.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
            q = torch.cat([rope_emb[..., 0] * q[..., 0], rope_emb[..., 1] * q[..., 1]], dim=-1)
            # q = rope_emb[..., 0] * q[..., 0] + rope_emb[..., 1] * q[..., 1]
            q = q.movedim(-1, -2).reshape(*q_shape).to(x.dtype)

            # apply_rotary_pos_emb inlined
            k_shape = k.shape
            k = k.reshape(*k.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
            k = torch.cat([rope_emb[..., 0] * k[..., 0], rope_emb[..., 1] * k[..., 1]], dim=-1)
            # k = rope_emb[..., 0] * k[..., 0] + rope_emb[..., 1] * k[..., 1]
            k = k.movedim(-1, -2).reshape(*k_shape).to(x.dtype)
        return q, k, v

    def cal_attn(self, q, k, v, mask=None):
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        out = rearrange(out, "b n s c -> s b (n c)")
        out = self.to_out(out)
        return out

    def forward(
        self,
        x,
        context=None,
        mask=None,
        rope_emb=None,
        **kwargs,
    ):
        """
        Args:
            x (Tensor): The query tensor of shape [S, B, D] (sequence-first, matching the "s b (n c) -> b n s c" rearrange in cal_qkv)
            context (Optional[Tensor]): The key/value tensor of shape [S, B, D_ctx], or None to use x as context (self-attention)
        """
        q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
        return self.cal_attn(q, k, v, mask)


@torch.no_grad()
def match_rms_norm():
    from diffusers.models.normalization import RMSNorm as DiffusersRMSNorm

    theirs_rmsnorm = RMSNorm(128, elementwise_affine=True, eps=1e-6)
    ours_rmsnorm = DiffusersRMSNorm(128, eps=1e-6, elementwise_affine=True)
    ours_rmsnorm.weight.data.copy_(theirs_rmsnorm.weight.data)

    input = torch.randn(1, 128)
    theirs_output = theirs_rmsnorm(input)
    ours_output = ours_rmsnorm(input)

    print(sum(p.numel() for p in theirs_rmsnorm.parameters()))
    print(sum(p.numel() for p in ours_rmsnorm.parameters()))
    print(torch.allclose(theirs_output, ours_output))


@torch.no_grad()
def match_attention():
    from diffusers.models.attention import Attention as DiffusersAttention

    theirs_attention = Attention(128, 128, heads=8, dim_head=16, qkv_bias=False, out_bias=False, qkv_norm="RRI")
    ours_attention = DiffusersAttention(128, 128, heads=8, dim_head=16, qk_norm="rms_norm", out_bias=False, elementwise_affine=False)
    ours_attention.to_q.weight.data.copy_(theirs_attention.to_q[0].weight.data)
    ours_attention.to_k.weight.data.copy_(theirs_attention.to_k[0].weight.data)
    ours_attention.to_v.weight.data.copy_(theirs_attention.to_v[0].weight.data)
    ours_attention.to_out[0].weight.data.copy_(theirs_attention.to_out[0].weight.data)

    input = torch.randn(1, 42, 128)
    theirs_output = rearrange(theirs_attention(rearrange(input, "b s c -> s b c")), "s b c -> b s c")
    ours_output = ours_attention(input)

    print(sum(p.numel() for p in theirs_attention.parameters()))
    print(sum(p.numel() for p in ours_attention.parameters()))
    print(torch.allclose(theirs_output, ours_output, atol=1e-3))


match_rms_norm()
match_attention()
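
As a quick smoke test of the module above (separate from the parity check): this Attention expects sequence-first input of shape [S, B, D], matching the "s b (n c) -> b n s c" rearrange in cal_qkv, and qkv_norm is passed as "RRI" explicitly because the default "SSI" has no handler in get_normalization here. Shapes below are illustrative.

# Illustrative smoke test; assumes the Attention/RMSNorm definitions above are in scope.
attn = Attention(query_dim=128, heads=8, dim_head=16, qkv_norm="RRI")
x = torch.randn(42, 2, 128)  # [S, B, D], sequence-first
out = attn(x)                # self-attention path; rope_emb not exercised here
print(out.shape)             # torch.Size([42, 2, 128])
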
test ff
from typing import Optional

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class FeedForward(nn.Module):
    def __init__(
        self,
        d_model: int,
        d_ff: int,
        dropout: float = 0.1,
        activation=nn.ReLU(),
        is_gated: bool = False,
        bias: bool = False,
    ) -> None:
        super().__init__()

        self.layer1 = nn.Linear(d_model, d_ff, bias=bias)
        self.layer2 = nn.Linear(d_ff, d_model, bias=bias)

        self.dropout = nn.Dropout(dropout)
        self.activation = activation
        self.is_gated = is_gated
        if is_gated:
            self.linear_gate = nn.Linear(d_model, d_ff, bias=False)

    def forward(self, x: torch.Tensor):
        g = self.activation(self.layer1(x))
        if self.is_gated:
            x = g * self.linear_gate(x)
        else:
            x = g
        assert self.dropout.p == 0.0, "we skip dropout"
        return self.layer2(x)


class GPT2FeedForward(FeedForward):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, bias: bool = False):
        super().__init__(
            d_model=d_model,
            d_ff=d_ff,
            dropout=dropout,
            activation=nn.GELU(),
            is_gated=False,
            bias=bias,
        )

    def forward(self, x: torch.Tensor):
        assert self.dropout.p == 0.0, "we skip dropout"

        x = self.layer1(x)

        def activation_layer2_forward(x):
            x = self.activation(x)
            x = self.layer2(x)
            return x

        x = checkpoint(activation_layer2_forward, x, use_reentrant=False)
        return x


@torch.no_grad()
def match_ff():
    from diffusers.models.attention import FeedForward as DiffusersFeedForward

    theirs_ff = FeedForward(128, 512, 0.0, activation=nn.GELU(), is_gated=True, bias=False)
    ours_ff = DiffusersFeedForward(128, mult=4, dropout=0.0, activation_fn="geglu", bias=False)
    # diffusers GEGLU fuses both branches into one projection and chunks its output:
    # rows [:512] are the ungated half, rows [512:] are the GELU-gated half.
    ours_ff.net[0].proj.weight.data[:512, :].copy_(theirs_ff.linear_gate.weight.data)
    ours_ff.net[0].proj.weight.data[512:, :].copy_(theirs_ff.layer1.weight.data)
    ours_ff.net[2].weight.data.copy_(theirs_ff.layer2.weight.data)

    input = torch.randn(1, 128)
    theirs_output = theirs_ff(input)
    ours_output = ours_ff(input)

    print(sum(p.numel() for p in theirs_ff.parameters()))
    print(sum(p.numel() for p in ours_ff.parameters()))
    print(torch.allclose(theirs_output, ours_output))


match_ff()
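
For context on the row split in match_ff (an explanatory sketch, not part of the port): diffusers' GEGLU fuses both branches into a single projection and chunks its output, leaving the first half ungated and passing the second half through GELU, which is why linear_gate maps to rows [:512] and layer1 to rows [512:]. A standalone illustration with made-up weights:

# Standalone sketch: fused GEGLU-style projection vs. two separate linears (illustrative weights).
import torch.nn.functional as F

d_model, d_ff = 128, 512
w_ungated = torch.randn(d_ff, d_model)  # plays the role of linear_gate above
w_gated = torch.randn(d_ff, d_model)    # plays the role of layer1 above (GELU-activated)

x = torch.randn(1, d_model)
fused = x @ torch.cat([w_ungated, w_gated], dim=0).t()  # one [2 * d_ff, d_model] projection
hidden, gate = fused.chunk(2, dim=-1)                   # GEGLU-style split
print(torch.allclose(hidden * F.gelu(gate), (x @ w_ungated.t()) * F.gelu(x @ w_gated.t())))
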
test timesteps
import math
from typing import Optional

import torch
import torch.nn as nn

class Timesteps(nn.Module):
    def __init__(self, num_channels):
        super().__init__()
        self.num_channels = num_channels

    def forward(self, timesteps):
        in_dtype = timesteps.dtype
        half_dim = self.num_channels // 2
        exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
        exponent = exponent / (half_dim - 0.0)  # i.e. downscale_freq_shift = 0.0

        emb = torch.exp(exponent)
        emb = timesteps[:, None].float() * emb[None, :]

        sin_emb = torch.sin(emb)
        cos_emb = torch.cos(emb)
        emb = torch.cat([cos_emb, sin_emb], dim=-1)

        return emb.to(in_dtype)


class TimestepEmbedding(nn.Module):
    def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False):
        super().__init__()
        self.linear_1 = nn.Linear(in_features, out_features, bias=not use_adaln_lora)
        self.activation = nn.SiLU()
        self.use_adaln_lora = use_adaln_lora
        if use_adaln_lora:
            self.linear_2 = nn.Linear(out_features, 3 * out_features, bias=False)
        else:
            self.linear_2 = nn.Linear(out_features, out_features, bias=True)

    def forward(self, sample: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        emb = self.linear_1(sample)
        emb = self.activation(emb)
        emb = self.linear_2(emb)

        if self.use_adaln_lora:
            adaln_lora_B_3D = emb
            emb_B_D = sample
        else:
            emb_B_D = emb
            adaln_lora_B_3D = None

        return emb_B_D, adaln_lora_B_3D


class CosmosTimestepEmbedding(nn.Module):
    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(in_features, out_features, bias=False)
        self.activation = nn.SiLU()
        self.linear_2 = nn.Linear(out_features, 3 * out_features, bias=False)
    
    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        emb = self.linear_1(hidden_states)
        emb = self.activation(emb)
        emb = self.linear_2(emb)
        return hidden_states, emb


@torch.no_grad()
def match_timestep():
    from diffusers.models.embeddings import Timesteps as DiffusersTimesteps

    theirs_timesteps = Timesteps(256)
    ours_timesteps = DiffusersTimesteps(256, flip_sin_to_cos=True, downscale_freq_shift=0.0)

    input = torch.tensor([1000.0], dtype=torch.float32)
    theirs_output = theirs_timesteps(input)
    ours_output = ours_timesteps(input)

    print(torch.allclose(theirs_output, ours_output))


@torch.no_grad()
def match_timestep_embedding():
    theirs_temb = TimestepEmbedding(256, 256, use_adaln_lora=True)
    ours_temb = CosmosTimestepEmbedding(256, 256)
    ours_temb.linear_1.weight.data.copy_(theirs_temb.linear_1.weight.data)
    ours_temb.linear_2.weight.data.copy_(theirs_temb.linear_2.weight.data)

    input = torch.randn(1, 256)
    theirs_output = theirs_temb(input)
    ours_output = ours_temb(input)

    print(sum(p.numel() for p in theirs_temb.parameters()))
    print(sum(p.numel() for p in ours_temb.parameters()))
    print(torch.allclose(theirs_output[0], ours_output[0]))
    print(torch.allclose(theirs_output[1], ours_output[1]))


match_timestep()
match_timestep_embedding()
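
One assumption about downstream usage (not verified in this PR): the 3 * out_features output of CosmosTimestepEmbedding presumably feeds adaLN-style modulation, so a consumer would chunk it into three per-channel vectors. A hypothetical sketch:

# Hypothetical consumer of CosmosTimestepEmbedding; the shift/scale/gate naming is an assumption.
temb = CosmosTimestepEmbedding(256, 256)
sample = torch.randn(2, 256)
hidden_states, emb = temb(sample)          # emb: [B, 3 * 256]
shift, scale, gate = emb.chunk(3, dim=-1)  # each: [B, 256]
print(shift.shape, scale.shape, gate.shape)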

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
