diff --git a/diffsynth_engine/configs/__init__.py b/diffsynth_engine/configs/__init__.py index 9474d9c..2761cef 100644 --- a/diffsynth_engine/configs/__init__.py +++ b/diffsynth_engine/configs/__init__.py @@ -1,7 +1,9 @@ from .base import PipelineConfig from .qwen_image import QwenImagePipelineConfig +from .wan import WanPipelineConfig __all__ = [ "PipelineConfig", "QwenImagePipelineConfig", + "WanPipelineConfig", ] diff --git a/diffsynth_engine/configs/wan.py b/diffsynth_engine/configs/wan.py new file mode 100644 index 0000000..6b437a4 --- /dev/null +++ b/diffsynth_engine/configs/wan.py @@ -0,0 +1,8 @@ +from dataclasses import dataclass + +from diffsynth_engine.configs.base import PipelineConfig + + +@dataclass +class WanPipelineConfig(PipelineConfig): + pass diff --git a/diffsynth_engine/forward_context.py b/diffsynth_engine/forward_context.py index 71a61a9..0415e0c 100644 --- a/diffsynth_engine/forward_context.py +++ b/diffsynth_engine/forward_context.py @@ -3,6 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + from contextlib import contextmanager from dataclasses import dataclass from typing import TYPE_CHECKING, Optional @@ -13,7 +15,7 @@ @dataclass class ForwardContext: - attn_metadata: Optional["AttentionMetadata"] = None + attn_metadata: Optional[AttentionMetadata] = None attn_type: Optional[str] = None @@ -42,7 +44,7 @@ def override_forward_context(forward_context: Optional[ForwardContext] = None): @contextmanager def set_forward_context( - attn_metadata: Optional["AttentionMetadata"] = None, + attn_metadata: Optional[AttentionMetadata] = None, attn_type: Optional[str] = None, ): """A context manager to that stores the current forward context.""" diff --git a/diffsynth_engine/models/base.py b/diffsynth_engine/models/base.py index a21405a..3b10d95 100644 --- a/diffsynth_engine/models/base.py +++ b/diffsynth_engine/models/base.py @@ -15,6 +15,9 @@ class DiffusionModel(nn.Module, ConfigMixin): config_name = CONFIG_NAME + _keep_in_fp32_modules: list[str] | None = None + _keys_to_ignore_on_load_unexpected: list[str] | None = None + @property def dtype(self) -> torch.dtype: param = next(self.parameters(), None) @@ -37,8 +40,30 @@ def from_pretrained( with init_empty_weights(): model = cls.from_config(config_dict) - # load model weights - state_dict = load_model_weights(model_path, subfolder, device, dtype) + # avoid precision loss + if dtype is not None and dtype != torch.float32 and cls._keep_in_fp32_modules: + state_dict = load_model_weights(model_path, subfolder, device, dtype=None) + for key in state_dict: + if any(m in key.split(".") for m in cls._keep_in_fp32_modules): + state_dict[key] = state_dict[key].to(device=device, dtype=torch.float32) + else: + state_dict[key] = state_dict[key].to(device=device, dtype=dtype) + else: + state_dict = load_model_weights(model_path, subfolder, device, dtype) + + # filter unexpected keys that the model explicitly ignores + if cls._keys_to_ignore_on_load_unexpected: + keys_to_remove = [ + key for key in state_dict if any(pattern in key for pattern in cls._keys_to_ignore_on_load_unexpected) + ] + for key in keys_to_remove: + del state_dict[key] + if keys_to_remove: + logger.info( + f"Dropped {len(keys_to_remove)} unexpected key(s) matching " + f"{cls._keys_to_ignore_on_load_unexpected} from state_dict." + ) + model.load_state_dict(state_dict, strict=True, assign=True) model.to(device=device) return model diff --git a/diffsynth_engine/models/wan/__init__.py b/diffsynth_engine/models/wan/__init__.py new file mode 100644 index 0000000..4436a2e --- /dev/null +++ b/diffsynth_engine/models/wan/__init__.py @@ -0,0 +1,6 @@ +from .autoencoder_kl_wan import AutoencoderKLWan +from .transformer_wan import WanTransformer3DModel +from .transformer_wan_animate import WanAnimateTransformer3DModel +from .transformer_wan_vace import WanVACETransformer3DModel + +__all__ = ["AutoencoderKLWan", "WanTransformer3DModel", "WanAnimateTransformer3DModel", "WanVACETransformer3DModel"] diff --git a/diffsynth_engine/models/wan/autoencoder_kl_wan.py b/diffsynth_engine/models/wan/autoencoder_kl_wan.py new file mode 100644 index 0000000..5e06ba9 --- /dev/null +++ b/diffsynth_engine/models/wan/autoencoder_kl_wan.py @@ -0,0 +1,1665 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_wan.py + +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import register_to_config +from diffusers.models.activations import get_activation +from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution +from diffusers.models.modeling_outputs import AutoencoderKLOutput + +from diffsynth_engine.models.base import DiffusionModel +from diffsynth_engine.utils import logging + +logger = logging.get_logger(__name__) + +CACHE_T = 2 + + +class AvgDown3D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert in_channels * self.factor % out_channels == 0 + self.group_size = in_channels * self.factor // out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t + pad = (0, 0, 0, 0, pad_t, 0) + x = F.pad(x, pad) + B, C, T, H, W = x.shape + x = x.view( + B, + C, + T // self.factor_t, + self.factor_t, + H // self.factor_s, + self.factor_s, + W // self.factor_s, + self.factor_s, + ) + x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous() + x = x.view( + B, + C * self.factor, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.view( + B, + self.out_channels, + self.group_size, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.mean(dim=2) + return x + + +class DupUp3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert out_channels * self.factor % in_channels == 0 + self.repeats = out_channels * self.factor // in_channels + + def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor: + x = x.repeat_interleave(self.repeats, dim=1) + x = x.view( + x.size(0), + self.out_channels, + self.factor_t, + self.factor_s, + self.factor_s, + x.size(2), + x.size(3), + x.size(4), + ) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() + x = x.view( + x.size(0), + self.out_channels, + x.size(2) * self.factor_t, + x.size(4) * self.factor_s, + x.size(6) * self.factor_s, + ) + if first_chunk: + x = x[:, :, self.factor_t - 1 :, :, :] + return x + + +class WanCausalConv3d(nn.Conv3d): + r""" + A custom 3D causal convolution layer with feature caching support. + + This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature + caching for efficient inference. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0 + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int, int], + stride: int | tuple[int, int, int] = 1, + padding: int | tuple[int, int, int] = 0, + ) -> None: + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + # Set up causal padding + self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + return super().forward(x) + + +class WanRMS_norm(nn.Module): + r""" + A custom RMS normalization layer. + + Args: + dim (int): The number of dimensions to normalize over. + channel_first (bool, optional): Whether the input tensor has channels as the first dimension. + Default is True. + images (bool, optional): Whether the input represents image data. Default is True. + bias (bool, optional): Whether to include a learnable bias term. Default is False. + """ + + def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class WanUpsample(nn.Upsample): + r""" + Perform upsampling while ensuring the output tensor has the same data type as the input. + + Args: + x (torch.Tensor): Input tensor to be upsampled. + + Returns: + torch.Tensor: Upsampled tensor with the same data type as the input. + """ + + def forward(self, x): + return super().forward(x.float()).type_as(x) + + +class WanResample(nn.Module): + r""" + A custom resampling module for 2D and 3D data. + + Args: + dim (int): The number of input/output channels. + mode (str): The resampling mode. Must be one of: + - 'none': No resampling (identity operation). + - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution. + - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution. + - 'downsample2d': 2D downsampling with zero-padding and convolution. + - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution. + """ + + def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None: + super().__init__() + self.dim = dim + self.mode = mode + + # default to dim //2 + if upsample_out_dim is None: + upsample_out_dim = dim // 2 + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, upsample_out_dim, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, upsample_out_dim, 3, padding=1), + ) + self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.resample(x) + x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + +class WanResidualBlock(nn.Module): + r""" + A custom residual block module. + + Args: + in_dim (int): Number of input channels. + out_dim (int): Number of output channels. + dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0. + non_linearity (str, optional): Type of non-linearity to use. Default is "silu". + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + dropout: float = 0.0, + non_linearity: str = "silu", + ) -> None: + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.nonlinearity = get_activation(non_linearity) + + # layers + self.norm1 = WanRMS_norm(in_dim, images=False) + self.conv1 = WanCausalConv3d(in_dim, out_dim, 3, padding=1) + self.norm2 = WanRMS_norm(out_dim, images=False) + self.dropout = nn.Dropout(dropout) + self.conv2 = WanCausalConv3d(out_dim, out_dim, 3, padding=1) + self.conv_shortcut = WanCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # Apply shortcut connection + h = self.conv_shortcut(x) + + # First normalization and activation + x = self.norm1(x) + x = self.nonlinearity(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + # Second normalization and activation + x = self.norm2(x) + x = self.nonlinearity(x) + + # Dropout + x = self.dropout(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv2(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv2(x) + + # Add residual connection + return x + h + + +class WanAttentionBlock(nn.Module): + r""" + Causal self-attention with a single head. + + Args: + dim (int): The number of channels in the input tensor. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = WanRMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + def forward(self, x): + identity = x + batch_size, channels, time, height, width = x.size() + + x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width) + x = self.norm(x) + + # compute query, key, value + qkv = self.to_qkv(x) + qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) + qkv = qkv.permute(0, 1, 3, 2).contiguous() + q, k, v = qkv.chunk(3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention(q, k, v) + + x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width) + + # output projection + x = self.proj(x) + + # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w] + x = x.view(batch_size, time, channels, height, width) + x = x.permute(0, 2, 1, 3, 4) + + return x + identity + + +class WanMidBlock(nn.Module): + """ + Middle block for WanVAE encoder and decoder. + + Args: + dim (int): Number of input/output channels. + dropout (float): Dropout rate. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1): + super().__init__() + self.dim = dim + + # Create the components + resnets = [WanResidualBlock(dim, dim, dropout, non_linearity)] + attentions = [] + for _ in range(num_layers): + attentions.append(WanAttentionBlock(dim)) + resnets.append(WanResidualBlock(dim, dim, dropout, non_linearity)) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # First residual block + x = self.resnets[0](x, feat_cache=feat_cache, feat_idx=feat_idx) + + # Process through attention and residual blocks + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + x = attn(x) + + x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx) + + return x + + +class WanResidualDownBlock(nn.Module): + def __init__(self, in_dim, out_dim, dropout, num_res_blocks, temperal_downsample=False, down_flag=False): + super().__init__() + + # Shortcut path with downsample + self.avg_shortcut = AvgDown3D( + in_dim, + out_dim, + factor_t=2 if temperal_downsample else 1, + factor_s=2 if down_flag else 1, + ) + + # Main path with residual blocks and downsample + resnets = [] + for _ in range(num_res_blocks): + resnets.append(WanResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + self.resnets = nn.ModuleList(resnets) + + # Add the final downsample block + if down_flag: + mode = "downsample3d" if temperal_downsample else "downsample2d" + self.downsampler = WanResample(out_dim, mode=mode) + else: + self.downsampler = None + + def forward(self, x, feat_cache=None, feat_idx=[0]): + x_copy = x.clone() + for resnet in self.resnets: + x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx) + if self.downsampler is not None: + x = self.downsampler(x, feat_cache=feat_cache, feat_idx=feat_idx) + + return x + self.avg_shortcut(x_copy) + + +class WanEncoder3d(nn.Module): + r""" + A 3D encoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_downsample (list of bool): Whether to downsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__( + self, + in_channels: int = 3, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + non_linearity: str = "silu", + is_residual: bool = False, # wan 2.2 vae use a residual downblock + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.nonlinearity = get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1) + + # downsample blocks + self.down_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if is_residual: + self.down_blocks.append( + WanResidualDownBlock( + in_dim, + out_dim, + dropout, + num_res_blocks, + temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False, + down_flag=i != len(dim_mult) - 1, + ) + ) + else: + for _ in range(num_res_blocks): + self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + self.down_blocks.append(WanAttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" + self.down_blocks.append(WanResample(out_dim, mode=mode)) + scale /= 2.0 + + # middle blocks + self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1) + + # output blocks + self.norm_out = WanRMS_norm(out_dim, images=False) + self.conv_out = WanCausalConv3d(out_dim, z_dim, 3, padding=1) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## downsamples + for layer in self.down_blocks: + if feat_cache is not None: + x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx) + else: + x = layer(x) + + ## middle + x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + + return x + + +class WanResidualUpBlock(nn.Module): + """ + A block that handles upsampling for the WanVAE decoder. + + Args: + in_dim (int): Input dimension + out_dim (int): Output dimension + num_res_blocks (int): Number of residual blocks + dropout (float): Dropout rate + temperal_upsample (bool): Whether to upsample on temporal dimension + up_flag (bool): Whether to upsample or not + non_linearity (str): Type of non-linearity to use + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + dropout: float = 0.0, + temperal_upsample: bool = False, + up_flag: bool = False, + non_linearity: str = "silu", + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + if up_flag: + self.avg_shortcut = DupUp3D( + in_dim, + out_dim, + factor_t=2 if temperal_upsample else 1, + factor_s=2, + ) + else: + self.avg_shortcut = None + + # create residual blocks + resnets = [] + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity)) + current_dim = out_dim + + self.resnets = nn.ModuleList(resnets) + + # Add upsampling layer if needed + if up_flag: + upsample_mode = "upsample3d" if temperal_upsample else "upsample2d" + self.upsampler = WanResample(out_dim, mode=upsample_mode, upsample_out_dim=out_dim) + else: + self.upsampler = None + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + """ + Forward pass through the upsampling block. + + Args: + x (torch.Tensor): Input tensor + feat_cache (list, optional): Feature cache for causal convolutions + feat_idx (list, optional): Feature index for cache management + + Returns: + torch.Tensor: Output tensor + """ + x_copy = x.clone() + + for resnet in self.resnets: + if feat_cache is not None: + x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx) + else: + x = resnet(x) + + if self.upsampler is not None: + if feat_cache is not None: + x = self.upsampler(x, feat_cache=feat_cache, feat_idx=feat_idx) + else: + x = self.upsampler(x) + + if self.avg_shortcut is not None: + x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk) + + return x + + +class WanUpBlock(nn.Module): + """ + A block that handles upsampling for the WanVAE decoder. + + Args: + in_dim (int): Input dimension + out_dim (int): Output dimension + num_res_blocks (int): Number of residual blocks + dropout (float): Dropout rate + upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d') + non_linearity (str): Type of non-linearity to use + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + dropout: float = 0.0, + upsample_mode: str | None = None, + non_linearity: str = "silu", + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # Create layers list + resnets = [] + # Add residual blocks and attention if needed + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity)) + current_dim = out_dim + + self.resnets = nn.ModuleList(resnets) + + # Add upsampling layer if needed + self.upsamplers = None + if upsample_mode is not None: + self.upsamplers = nn.ModuleList([WanResample(out_dim, mode=upsample_mode)]) + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None): + """ + Forward pass through the upsampling block. + + Args: + x (torch.Tensor): Input tensor + feat_cache (list, optional): Feature cache for causal convolutions + feat_idx (list, optional): Feature index for cache management + + Returns: + torch.Tensor: Output tensor + """ + for resnet in self.resnets: + if feat_cache is not None: + x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx) + else: + x = resnet(x) + + if self.upsamplers is not None: + if feat_cache is not None: + x = self.upsamplers[0](x, feat_cache=feat_cache, feat_idx=feat_idx) + else: + x = self.upsamplers[0](x) + return x + + +class WanDecoder3d(nn.Module): + r""" + A 3D decoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_upsample (list of bool): Whether to upsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + non_linearity: str = "silu", + out_channels: int = 3, + is_residual: bool = False, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + self.nonlinearity = get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + + # init block + self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.mid_block = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1) + + # upsample blocks + self.up_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i > 0 and not is_residual: + # wan vae 2.1 + in_dim = in_dim // 2 + + # determine if we need upsampling + up_flag = i != len(dim_mult) - 1 + # determine upsampling mode, if not upsampling, set to None + upsample_mode = None + if up_flag and temperal_upsample[i]: + upsample_mode = "upsample3d" + elif up_flag: + upsample_mode = "upsample2d" + + # Create and add the upsampling block + if is_residual: + up_block = WanResidualUpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + temperal_upsample=temperal_upsample[i] if up_flag else False, + up_flag=up_flag, + non_linearity=non_linearity, + ) + else: + up_block = WanUpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + upsample_mode=upsample_mode, + non_linearity=non_linearity, + ) + self.up_blocks.append(up_block) + + # output blocks + self.norm_out = WanRMS_norm(out_dim, images=False) + self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1) + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## middle + x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx) + + ## upsamples + for up_block in self.up_blocks: + x = up_block(x, feat_cache=feat_cache, feat_idx=feat_idx, first_chunk=first_chunk) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + +def patchify(x, patch_size): + if patch_size == 1: + return x + + if x.dim() != 5: + raise ValueError(f"Invalid input shape: {x.shape}") + + # x shape: [batch_size, channels, frames, height, width] + batch_size, channels, frames, height, width = x.shape + + # Ensure height and width are divisible by patch_size + if height % patch_size != 0 or width % patch_size != 0: + raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})") + + # Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size] + x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size) + + # Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size] + x = x.permute(0, 1, 6, 4, 2, 3, 5).contiguous() + x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size) + return x + + +def unpatchify(x, patch_size): + if patch_size == 1: + return x + + if x.dim() != 5: + raise ValueError(f"Invalid input shape: {x.shape}") + + # x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width] + batch_size, c_patches, frames, height, width = x.shape + channels = c_patches // (patch_size * patch_size) + + # Reshape to [b, c, patch_size, patch_size, f, h, w] + x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width) + + # Rearrange to [b, c, f, h * patch_size, w * patch_size] + x = x.permute(0, 1, 4, 5, 3, 6, 2).contiguous() + x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size) + return x + + +class AutoencoderKLWan(DiffusionModel): + r""" + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. + Introduced in [Wan 2.1]. + """ + + @register_to_config + def __init__( + self, + base_dim: int = 96, + decoder_base_dim: int | None = None, + z_dim: int = 16, + dim_mult: list[int] = [1, 2, 4, 4], + num_res_blocks: int = 2, + attn_scales: list[float] = [], + temperal_downsample: list[bool] = [False, True, True], + dropout: float = 0.0, + latents_mean: list[float] = [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ], + latents_std: list[float] = [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, + ], + is_residual: bool = False, + in_channels: int = 3, + out_channels: int = 3, + patch_size: int | None = None, + scale_factor_temporal: int | None = 4, + scale_factor_spatial: int | None = 8, + ) -> None: + super().__init__() + + self.z_dim = z_dim + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + if decoder_base_dim is None: + decoder_base_dim = base_dim + + self.encoder = WanEncoder3d( + in_channels=in_channels, + dim=base_dim, + z_dim=z_dim * 2, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_downsample=temperal_downsample, + dropout=dropout, + is_residual=is_residual, + ) + self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1) + self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1) + + self.decoder = WanDecoder3d( + dim=decoder_base_dim, + z_dim=z_dim, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_upsample=self.temperal_upsample, + dropout=dropout, + out_channels=out_channels, + is_residual=is_residual, + ) + + self.spatial_compression_ratio = scale_factor_spatial + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 192 + self.tile_sample_stride_width = 192 + + # Precompute and cache conv counts for encoder and decoder for clear_cache speedup + self._cached_conv_counts = { + "decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules()) + if self.decoder is not None + else 0, + "encoder": sum(isinstance(m, WanCausalConv3d) for m in self.encoder.modules()) + if self.encoder is not None + else 0, + } + + def enable_tiling( + self, + tile_sample_min_height: int | None = None, + tile_sample_min_width: int | None = None, + tile_sample_stride_height: int | None = None, + tile_sample_stride_width: int | None = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + + def disable_tiling(self) -> None: + """Disable tiled VAE decoding.""" + self.use_tiling = False + + def enable_slicing(self) -> None: + """Enable sliced VAE decoding (process one batch element at a time).""" + self.use_slicing = True + + def disable_slicing(self) -> None: + """Disable sliced VAE decoding.""" + self.use_slicing = False + + def clear_cache(self): + # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call + self._conv_num = self._cached_conv_counts["decoder"] + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = self._cached_conv_counts["encoder"] + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + def _encode(self, x: torch.Tensor): + _, _, num_frame, height, width = x.shape + + self.clear_cache() + if self.config.patch_size is not None: + x = patchify(x, patch_size=self.config.patch_size) + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + iter_ = 1 + (num_frame - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + out = torch.cat([out, out_], 2) + + enc = self.quant_conv(out) + self.clear_cache() + return enc + + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]: + r""" + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True): + _, _, num_frame, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + self.clear_cache() + x = self.post_quant_conv(z) + for i in range(num_frame): + self._conv_idx = [0] + if i == 0: + out = self.decoder( + x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True + ) + else: + out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + + if self.config.patch_size is not None: + out = unpatchify(out, patch_size=self.config.patch_size) + + out = torch.clamp(out, min=-1.0, max=1.0) + + self.clear_cache() + if not return_dict: + return (out,) + + return DecoderOutput(sample=out) + + def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor: + r""" + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + _, _, num_frames, height, width = x.shape + encode_spatial_compression_ratio = self.spatial_compression_ratio + if self.config.patch_size is not None: + assert encode_spatial_compression_ratio % self.config.patch_size == 0 + encode_spatial_compression_ratio = self.spatial_compression_ratio // self.config.patch_size + + latent_height = height // encode_spatial_compression_ratio + latent_width = width // encode_spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // encode_spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // encode_spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // encode_spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // encode_spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + self.clear_cache() + time = [] + frame_range = 1 + (num_frames - 1) // 4 + for k in range(frame_range): + self._enc_conv_idx = [0] + if k == 0: + tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + else: + tile = x[ + :, + :, + 1 + 4 * (k - 1) : 1 + 4 * k, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] + tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + tile = self.quant_conv(tile) + time.append(tile) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + _, _, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + tile_sample_stride_height = self.tile_sample_stride_height + tile_sample_stride_width = self.tile_sample_stride_width + if self.config.patch_size is not None: + sample_height = sample_height // self.config.patch_size + sample_width = sample_width // self.config.patch_size + tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size + tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size + blend_height = self.tile_sample_min_height // self.config.patch_size - tile_sample_stride_height + blend_width = self.tile_sample_min_width // self.config.patch_size - tile_sample_stride_width + else: + blend_height = self.tile_sample_min_height - tile_sample_stride_height + blend_width = self.tile_sample_min_width - tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + self.clear_cache() + time = [] + for k in range(num_frames): + self._conv_idx = [0] + tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width] + tile = self.post_quant_conv(tile) + decoded = self.decoder( + tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0) + ) + time.append(decoded) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if self.config.patch_size is not None: + dec = unpatchify(dec, patch_size=self.config.patch_size) + + dec = torch.clamp(dec, min=-1.0, max=1.0) + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def _build_1d_mask(self, length: int, left_bound: bool, right_bound: bool, border_width: int) -> torch.Tensor: + """Build a 1D linear ramp mask for tile blending.""" + mask = torch.ones((length,)) + if not left_bound: + mask[:border_width] = (torch.arange(border_width) + 1) / border_width + if not right_bound: + mask[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,)) + return mask + + def _build_spatial_mask( + self, + data: torch.Tensor, + is_bound: tuple[bool, bool, bool, bool], + border_width: tuple[int, int], + ) -> torch.Tensor: + """Build a 2D spatial mask for tile blending using linear ramps at borders.""" + _, _, _, spatial_height, spatial_width = data.shape + height_mask = self._build_1d_mask(spatial_height, is_bound[0], is_bound[1], border_width[0]) + width_mask = self._build_1d_mask(spatial_width, is_bound[2], is_bound[3], border_width[1]) + + height_mask = height_mask.unsqueeze(1).expand(spatial_height, spatial_width) + width_mask = width_mask.unsqueeze(0).expand(spatial_height, spatial_width) + + mask = torch.stack([height_mask, width_mask]).min(dim=0).values + mask = mask.reshape(1, 1, 1, spatial_height, spatial_width) + return mask + + def tiled_encode_with_mask( + self, + x: torch.Tensor, + tile_size: tuple[int, int] = (256, 256), + tile_stride: tuple[int, int] = (192, 192), + ) -> torch.Tensor: + """ + Encode using mask-weighted spatial tiling. + + This approach uses smooth gradient masks at tile boundaries for blending, + which can produce better results than the simple blend approach. + + Args: + x: Input tensor [B, C, T, H, W]. + tile_size: (height, width) of each tile in pixel space. + tile_stride: (height, width) stride between tiles in pixel space. + + Returns: + Encoded latent tensor. + """ + _, _, num_frames, height, width = x.shape + size_h, size_w = tile_size + stride_h, stride_w = tile_stride + + if self.config.patch_size is not None: + x = patchify(x, patch_size=self.config.patch_size) + _, _, _, height, width = x.shape + + encode_spatial_compression_ratio = self.spatial_compression_ratio + if self.config.patch_size is not None: + encode_spatial_compression_ratio = self.spatial_compression_ratio // self.config.patch_size + size_h = size_h // self.config.patch_size + size_w = size_w // self.config.patch_size + stride_h = stride_h // self.config.patch_size + stride_w = stride_w // self.config.patch_size + + # Build tile tasks + tasks = [] + for h in range(0, height, stride_h): + if h - stride_h >= 0 and h - stride_h + size_h >= height: + continue + for w in range(0, width, stride_w): + if w - stride_w >= 0 and w - stride_w + size_w >= width: + continue + tasks.append((h, h + size_h, w, w + size_w)) + + out_t = 1 + (num_frames - 1) // 4 + latent_h = height // encode_spatial_compression_ratio + latent_w = width // encode_spatial_compression_ratio + + weight = torch.zeros((1, 1, out_t, latent_h, latent_w), dtype=x.dtype, device="cpu") + values = torch.zeros((1, self.z_dim, out_t, latent_h, latent_w), dtype=x.dtype, device="cpu") + + for h, h_end, w, w_end in tasks: + tile = x[:, :, :, h:h_end, w:w_end] + + self.clear_cache() + iter_ = 1 + (num_frames - 1) // 4 + for k in range(iter_): + self._enc_conv_idx = [0] + if k == 0: + enc_out = self.encoder( + tile[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx + ) + else: + enc_chunk = self.encoder( + tile[:, :, 1 + 4 * (k - 1) : 1 + 4 * k, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + enc_out = torch.cat([enc_out, enc_chunk], 2) + enc_out = self.quant_conv(enc_out) + + # Build posterior and sample + enc_out_cpu = enc_out.to("cpu") + + mask = self._build_spatial_mask( + enc_out_cpu, + is_bound=(h == 0, h_end >= height, w == 0, w_end >= width), + border_width=( + (size_h - stride_h) // encode_spatial_compression_ratio, + (size_w - stride_w) // encode_spatial_compression_ratio, + ), + ).to(dtype=x.dtype, device="cpu") + + target_h = h // encode_spatial_compression_ratio + target_w = w // encode_spatial_compression_ratio + values[ + :, + :, + :, + target_h : target_h + enc_out_cpu.shape[3], + target_w : target_w + enc_out_cpu.shape[4], + ] += enc_out_cpu * mask + weight[ + :, + :, + :, + target_h : target_h + enc_out_cpu.shape[3], + target_w : target_w + enc_out_cpu.shape[4], + ] += mask + + self.clear_cache() + result = values / weight + return result.to(x.device) + + def tiled_decode_with_mask( + self, + z: torch.Tensor, + tile_size: tuple[int, int] = (32, 32), + tile_stride: tuple[int, int] = (24, 24), + ) -> torch.Tensor: + """ + Decode using mask-weighted spatial tiling. + + This approach uses smooth gradient masks at tile boundaries for blending, + which can produce better results than the simple blend approach. + + Args: + z: Input latent tensor [B, C, T, H, W]. + tile_size: (height, width) of each tile in latent space. + tile_stride: (height, width) stride between tiles in latent space. + + Returns: + Decoded video tensor, clamped to [-1, 1]. + """ + _, _, num_frames, height, width = z.shape + size_h, size_w = tile_size + stride_h, stride_w = tile_stride + + upsampling_factor = self.spatial_compression_ratio + + # Build tile tasks + tasks = [] + for h in range(0, height, stride_h): + if h - stride_h >= 0 and h - stride_h + size_h >= height: + continue + for w in range(0, width, stride_w): + if w - stride_w >= 0 and w - stride_w + size_w >= width: + continue + tasks.append((h, h + size_h, w, w + size_w)) + + out_t = num_frames * 4 - 3 + out_channels = self.config.out_channels + out_h = height * upsampling_factor + out_w = width * upsampling_factor + if self.config.patch_size is not None: + out_h = out_h // self.config.patch_size + out_w = out_w // self.config.patch_size + + weight = torch.zeros((1, 1, out_t, out_h, out_w), dtype=z.dtype, device="cpu") + values = torch.zeros((1, out_channels, out_t, out_h, out_w), dtype=z.dtype, device="cpu") + + for h, h_end, w, w_end in tasks: + tile_z = z[:, :, :, h:h_end, w:w_end] + + self.clear_cache() + tile_x = self.post_quant_conv(tile_z) + for k in range(num_frames): + self._conv_idx = [0] + if k == 0: + dec_out = self.decoder( + tile_x[:, :, k : k + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx, + first_chunk=True, + ) + else: + dec_chunk = self.decoder( + tile_x[:, :, k : k + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx, + ) + dec_out = torch.cat([dec_out, dec_chunk], 2) + + if self.config.patch_size is not None: + dec_out = unpatchify(dec_out, patch_size=self.config.patch_size) + + dec_out_cpu = dec_out.to("cpu") + + effective_upsampling = upsampling_factor + if self.config.patch_size is not None: + effective_upsampling = upsampling_factor // self.config.patch_size + + mask = self._build_spatial_mask( + dec_out_cpu, + is_bound=(h == 0, h_end >= height, w == 0, w_end >= width), + border_width=( + (size_h - stride_h) * effective_upsampling, + (size_w - stride_w) * effective_upsampling, + ), + ).to(dtype=z.dtype, device="cpu") + + target_h = h * effective_upsampling + target_w = w * effective_upsampling + values[ + :, + :, + :, + target_h : target_h + dec_out_cpu.shape[3], + target_w : target_w + dec_out_cpu.shape[4], + ] += dec_out_cpu * mask + weight[ + :, + :, + :, + target_h : target_h + dec_out_cpu.shape[3], + target_w : target_w + dec_out_cpu.shape[4], + ] += mask + + self.clear_cache() + result = values / weight + result = result.float().clamp(-1, 1) + return result.to(z.device) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: torch.Generator | None = None, + ) -> DecoderOutput | torch.Tensor: + """ + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make sampling + deterministic. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If `return_dict` is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + x = sample + posterior = self.encode(x).latent_dist + + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, return_dict=return_dict) + return dec diff --git a/diffsynth_engine/models/wan/transformer_wan.py b/diffsynth_engine/models/wan/transformer_wan.py new file mode 100644 index 0000000..cc2d631 --- /dev/null +++ b/diffsynth_engine/models/wan/transformer_wan.py @@ -0,0 +1,602 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_wan.py + +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +import torch.nn as nn +from diffusers.configuration_utils import register_to_config +from diffusers.models.attention import FeedForward +from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.normalization import FP32LayerNorm + +from diffsynth_engine.distributed.utils import sequence_parallel_shard, sequence_parallel_unshard +from diffsynth_engine.forward_context import get_forward_context +from diffsynth_engine.layers.attention import USPAttention +from diffsynth_engine.models.base import DiffusionModel +from diffsynth_engine.utils import logging + +logger = logging.get_logger(__name__) + + +def apply_wan_rotary_emb( + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, +) -> torch.Tensor: + """ + Apply rotary embeddings to input tensors using the given cosine and sine frequency tensors. This function applies + rotary embeddings to the given query or key `hidden_states` tensor using the provided frequency tensors + `freqs_cos` and `freqs_sin`. The input tensor is reshaped into real and imaginary components, and the frequency + tensors are indexed for broadcasting compatibility. The resulting tensor contains rotary embeddings and is returned + as a real tensor. + + Args: + hidden_states (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, S, H, D] + freqs_cos (`torch.Tensor`): Precomputed cosine frequency tensor. + freqs_sin (`torch.Tensor`): Precomputed sine frequency tensor. + + Returns: + `torch.Tensor`: Modified query or key tensor with rotary embeddings. + """ + x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) + cos = freqs_cos[..., 0::2] + sin = freqs_sin[..., 1::2] + out = torch.empty_like(hidden_states) + out[..., 0::2] = x1 * cos - x2 * sin + out[..., 1::2] = x1 * sin + x2 * cos + return out.type_as(hidden_states) + + +class WanAttention(nn.Module): + def __init__( + self, + dim: int, + heads: int = 8, + dim_head: int = 64, + eps: float = 1e-5, + dropout: float = 0.0, + added_kv_proj_dim: int | None = None, + cross_attention_dim_head: int | None = None, + ): + super().__init__() + + self.inner_dim = dim_head * heads + self.heads = heads + self.added_kv_proj_dim = added_kv_proj_dim + self.cross_attention_dim_head = cross_attention_dim_head + self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads + + self.to_q = nn.Linear(dim, self.inner_dim, bias=True) + self.to_k = nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_v = nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_out = nn.ModuleList( + [ + nn.Linear(self.inner_dim, dim, bias=True), + nn.Dropout(dropout), + ] + ) + self.norm_q = nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) + self.norm_k = nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) + + self.add_k_proj = self.add_v_proj = None + if added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) + self.norm_added_k = nn.RMSNorm(dim_head * heads, eps=eps) + + forward_context = get_forward_context() + self.usp_attn = USPAttention( + num_heads=heads, + head_size=dim_head, + attn_type=forward_context.attn_type, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + # I2V: split image and text from encoder_hidden_states + encoder_hidden_states_img = None + if self.add_k_proj is not None and encoder_hidden_states is not None: + # 512 is the context length of the text encoder, hardcoded for now + image_context_length = encoder_hidden_states.shape[1] - 512 + encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length] + encoder_hidden_states = encoder_hidden_states[:, image_context_length:] + + # QKV projections + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + query = self.to_q(hidden_states) + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + # QK normalization + query = self.norm_q(query) + key = self.norm_k(key) + + # Reshape: [B, S, H*D] -> [B, S, H, D] + query = query.unflatten(2, (self.heads, -1)) + key = key.unflatten(2, (self.heads, -1)) + value = value.unflatten(2, (self.heads, -1)) + + # Apply rotary embeddings + if rotary_emb is not None: + query = apply_wan_rotary_emb(query, *rotary_emb) + key = apply_wan_rotary_emb(key, *rotary_emb) + + # I2V: cross attention with image encoder hidden states + hidden_states_img = None + if encoder_hidden_states_img is not None: + key_img = self.add_k_proj(encoder_hidden_states_img) + value_img = self.add_v_proj(encoder_hidden_states_img) + key_img = self.norm_added_k(key_img) + + key_img = key_img.unflatten(2, (self.heads, -1)) + value_img = value_img.unflatten(2, (self.heads, -1)) + + hidden_states_img = self.usp_attn(query, key_img, value_img) + hidden_states_img = hidden_states_img.flatten(2, 3) + hidden_states_img = hidden_states_img.type_as(query) + + # Self attention + hidden_states = self.usp_attn(query, key, value, attn_mask=attention_mask) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + # Combine I2V attention output + if hidden_states_img is not None: + hidden_states = hidden_states + hidden_states_img + + # Output projection + hidden_states = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + +class WanImageEmbedding(nn.Module): + def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None): + super().__init__() + + self.norm1 = FP32LayerNorm(in_features) + self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu") + self.norm2 = FP32LayerNorm(out_features) + if pos_embed_seq_len is not None: + self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features)) + else: + self.pos_embed = None + + def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor: + if self.pos_embed is not None: + batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape + encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim) + encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed + + hidden_states = self.norm1(encoder_hidden_states_image) + hidden_states = self.ff(hidden_states) + hidden_states = self.norm2(hidden_states) + return hidden_states + + +class WanTimeTextImageEmbedding(nn.Module): + def __init__( + self, + dim: int, + time_freq_dim: int, + time_proj_dim: int, + text_embed_dim: int, + image_embed_dim: int | None = None, + pos_embed_seq_len: int | None = None, + ): + super().__init__() + + self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) + self.act_fn = nn.SiLU() + self.time_proj = nn.Linear(dim, time_proj_dim) + self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh") + + self.image_embedder = None + if image_embed_dim is not None: + self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len) + + def forward( + self, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: torch.Tensor | None = None, + timestep_seq_len: int | None = None, + ): + timestep = self.timesteps_proj(timestep) + if timestep_seq_len is not None: + timestep = timestep.unflatten(0, (-1, timestep_seq_len)) + + # Compute time embedding in fp32 to avoid precision loss + with torch.amp.autocast(device_type=timestep.device.type, dtype=torch.float32): + timestep = timestep.float() + temb = self.time_embedder(timestep) + timestep_proj = self.time_proj(self.act_fn(temb)) + timestep_proj = timestep_proj.type_as(encoder_hidden_states) + temb = temb.type_as(encoder_hidden_states) + + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + if encoder_hidden_states_image is not None: + encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image) + + return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image + + +class WanRotaryPosEmbed(nn.Module): + def __init__( + self, + attention_head_dim: int, + patch_size: tuple[int, int, int], + max_seq_len: int, + theta: float = 10000.0, + ): + super().__init__() + + self.attention_head_dim = attention_head_dim + self.patch_size = patch_size + self.max_seq_len = max_seq_len + + h_dim = w_dim = 2 * (attention_head_dim // 6) + t_dim = attention_head_dim - h_dim - w_dim + + self.t_dim = t_dim + self.h_dim = h_dim + self.w_dim = w_dim + + with torch.device("cpu"): + freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 + + freqs_cos = [] + freqs_sin = [] + + for dim in [t_dim, h_dim, w_dim]: + freq_cos, freq_sin = get_1d_rotary_pos_embed( + dim, + max_seq_len, + theta, + use_real=True, + repeat_interleave_real=True, + freqs_dtype=freqs_dtype, + ) + freqs_cos.append(freq_cos) + freqs_sin.append(freq_sin) + + self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False) + self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False) + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.patch_size + ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w + + split_sizes = [self.t_dim, self.h_dim, self.w_dim] + + freqs_cos = self.freqs_cos.split(split_sizes, dim=1) + freqs_sin = self.freqs_sin.split(split_sizes, dim=1) + + freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) + freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) + freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) + + freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) + freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) + freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) + + freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1) + freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1) + + return freqs_cos, freqs_sin + + +class WanTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + ffn_dim: int, + num_heads: int, + qk_norm: str = "rms_norm_across_heads", + cross_attn_norm: bool = False, + eps: float = 1e-6, + added_kv_proj_dim: int | None = None, + ): + super().__init__() + + # 1. Self-attention + self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.attn1 = WanAttention( + dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + eps=eps, + cross_attention_dim_head=None, + ) + + # 2. Cross-attention + self.attn2 = WanAttention( + dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + eps=eps, + added_kv_proj_dim=added_kv_proj_dim, + cross_attention_dim_head=dim // num_heads, + ) + self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + + # 3. Feed-forward + self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate") + self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False) + + self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + rotary_emb: torch.Tensor, + ) -> torch.Tensor: + if temb.ndim == 4: + # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table.unsqueeze(0) + temb.float() + ).chunk(6, dim=2) + # batch_size, seq_len, 1, inner_dim + shift_msa = shift_msa.squeeze(2) + scale_msa = scale_msa.squeeze(2) + gate_msa = gate_msa.squeeze(2) + c_shift_msa = c_shift_msa.squeeze(2) + c_scale_msa = c_scale_msa.squeeze(2) + c_gate_msa = c_gate_msa.squeeze(2) + else: + # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table + temb.float() + ).chunk(6, dim=1) + + # 1. Self-attention + norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) + attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb) + hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) + + # 2. Cross-attention + norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None) + hidden_states = hidden_states + attn_output + + # 3. Feed-forward + norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as( + hidden_states + ) + ff_output = self.ffn(norm_hidden_states) + hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) + + return hidden_states + + +class WanTransformer3DModel(DiffusionModel): + r""" + A Transformer model for video-like data used in the Wan model. + + Args: + patch_size (`tuple[int]`, defaults to `(1, 2, 2)`): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch). + num_attention_heads (`int`, defaults to `40`): + Fixed length for text embeddings. + attention_head_dim (`int`, defaults to `128`): + The number of channels in each head. + in_channels (`int`, defaults to `16`): + The number of channels in the input. + out_channels (`int`, defaults to `16`): + The number of channels in the output. + text_dim (`int`, defaults to `4096`): + Input dimension for text embeddings. + freq_dim (`int`, defaults to `256`): + Dimension for sinusoidal time embeddings. + ffn_dim (`int`, defaults to `13824`): + Intermediate dimension in feed-forward network. + num_layers (`int`, defaults to `40`): + The number of layers of transformer blocks to use. + cross_attn_norm (`bool`, defaults to `True`): + Enable cross-attention normalization. + qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`): + Query/key normalization type. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + image_dim (`int`, *optional*, defaults to `None`): + Dimension for image embeddings. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + rope_max_seq_len (`int`, defaults to `1024`): + Maximum sequence length for rotary position embeddings. + pos_embed_seq_len (`int`, *optional*, defaults to `None`): + Positional embedding sequence length. + """ + + _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] + _keys_to_ignore_on_load_unexpected = ["norm_added_q"] + + @register_to_config + def __init__( + self, + patch_size: tuple[int, ...] = (1, 2, 2), + num_attention_heads: int = 40, + attention_head_dim: int = 128, + in_channels: int = 16, + out_channels: int = 16, + text_dim: int = 4096, + freq_dim: int = 256, + ffn_dim: int = 13824, + num_layers: int = 40, + cross_attn_norm: bool = True, + qk_norm: str | None = "rms_norm_across_heads", + eps: float = 1e-6, + image_dim: int | None = None, + added_kv_proj_dim: int | None = None, + rope_max_seq_len: int = 1024, + pos_embed_seq_len: int | None = None, + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + # 1. Patch & position embedding + self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) + self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + + # 2. Condition embeddings + # image_embedding_dim=1280 for I2V model + self.condition_embedder = WanTimeTextImageEmbedding( + dim=inner_dim, + time_freq_dim=freq_dim, + time_proj_dim=inner_dim * 6, + text_embed_dim=text_dim, + image_embed_dim=image_dim, + pos_embed_seq_len=pos_embed_seq_len, + ) + + # 3. Transformer blocks + self.blocks = nn.ModuleList( + [ + WanTransformerBlock( + inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim + ) + for _ in range(num_layers) + ] + ) + + # 4. Output norm & projection + self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) + self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: torch.Tensor | None = None, + return_dict: bool = True, + ) -> torch.Tensor | Transformer2DModelOutput: + """ + The [`WanTransformer3DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`): + Input `hidden_states`. + timestep (`torch.LongTensor`): + Used to indicate denoising step. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + encoder_hidden_states_image (`torch.Tensor`, *optional*): + Conditional image embeddings for image-conditioned generation. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + rotary_emb = self.rope(hidden_states) + + hidden_states = self.patch_embedding(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + original_seq_len = hidden_states.shape[1] + + # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v) + if timestep.ndim == 2: + ts_seq_len = timestep.shape[1] + timestep = timestep.flatten() # batch_size * seq_len + else: + ts_seq_len = None + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len + ) + if ts_seq_len is not None: + # batch_size, seq_len, 6, inner_dim + timestep_proj = timestep_proj.unflatten(2, (6, -1)) + else: + # batch_size, 6, inner_dim + timestep_proj = timestep_proj.unflatten(1, (6, -1)) + + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + + rotary_emb_cos, rotary_emb_sin = rotary_emb + hidden_states, rotary_emb_cos, rotary_emb_sin = sequence_parallel_shard( + [hidden_states, rotary_emb_cos, rotary_emb_sin], + seq_dims=[1, 1, 1], + ) + rotary_emb = (rotary_emb_cos, rotary_emb_sin) + + # 4. Transformer blocks + if torch.is_grad_enabled() and self.gradient_checkpointing: + for block in self.blocks: + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb + ) + else: + for block in self.blocks: + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + + (hidden_states,) = sequence_parallel_unshard([hidden_states], seq_dims=[1], seq_lens=[original_seq_len]) + + # 5. Output norm, projection & unpatchify + if temb.ndim == 3: + # batch_size, seq_len, inner_dim (wan 2.2 ti2v) + shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2) + shift = shift.squeeze(2) + scale = scale.squeeze(2) + else: + # batch_size, inner_dim + shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1) + + shift = shift.to(hidden_states.device) + scale = scale.to(hidden_states.device) + + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/diffsynth_engine/models/wan/transformer_wan_animate.py b/diffsynth_engine/models/wan/transformer_wan_animate.py new file mode 100644 index 0000000..9a59cef --- /dev/null +++ b/diffsynth_engine/models/wan/transformer_wan_animate.py @@ -0,0 +1,812 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_wan_animate.py + +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import register_to_config +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.normalization import FP32LayerNorm + +from diffsynth_engine.distributed.utils import sequence_parallel_shard, sequence_parallel_unshard +from diffsynth_engine.forward_context import get_forward_context +from diffsynth_engine.layers.attention import USPAttention +from diffsynth_engine.models.base import DiffusionModel +from diffsynth_engine.models.wan.transformer_wan import ( + WanRotaryPosEmbed, + WanTimeTextImageEmbedding, + WanTransformerBlock, +) +from diffsynth_engine.utils import logging + +logger = logging.get_logger(__name__) + +WAN_ANIMATE_MOTION_ENCODER_CHANNEL_SIZES = { + "4": 512, + "8": 512, + "16": 512, + "32": 512, + "64": 256, + "128": 128, + "256": 64, + "512": 32, + "1024": 16, +} + + +class FusedLeakyReLU(nn.Module): + """ + Fused LeakyRelu with scale factor and channel-wise bias. + """ + + def __init__(self, negative_slope: float = 0.2, scale: float = 2**0.5, bias_channels: int | None = None): + super().__init__() + self.negative_slope = negative_slope + self.scale = scale + self.channels = bias_channels + + if self.channels is not None: + self.bias = nn.Parameter(torch.zeros(self.channels)) + else: + self.bias = None + + def forward(self, x: torch.Tensor, channel_dim: int = 1) -> torch.Tensor: + if self.bias is not None: + # Expand self.bias to have all singleton dims except at self.channel_dim + expanded_shape = [1] * x.ndim + expanded_shape[channel_dim] = self.bias.shape[0] + bias = self.bias.reshape(*expanded_shape) + x = x + bias + return F.leaky_relu(x, self.negative_slope) * self.scale + + +class MotionConv2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + bias: bool = True, + blur_kernel: tuple[int, ...] | None = None, + blur_upsample_factor: int = 1, + use_activation: bool = True, + ): + super().__init__() + self.use_activation = use_activation + self.in_channels = in_channels + + # Handle blurring (applying a FIR filter with the given kernel) if available + self.blur = False + if blur_kernel is not None: + p = (len(blur_kernel) - stride) + (kernel_size - 1) + self.blur_padding = ((p + 1) // 2, p // 2) + + kernel = torch.tensor(blur_kernel) + # Convert kernel to 2D if necessary + if kernel.ndim == 1: + kernel = kernel[None, :] * kernel[:, None] + # Normalize kernel + kernel = kernel / kernel.sum() + if blur_upsample_factor > 1: + kernel = kernel * (blur_upsample_factor**2) + self.register_buffer("blur_kernel", kernel, persistent=False) + self.blur = True + + # Main Conv2d parameters (with scale factor) + self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size)) + self.scale = 1 / math.sqrt(in_channels * kernel_size**2) + + self.stride = stride + self.padding = padding + + # If using an activation function, the bias will be fused into the activation + if bias and not self.use_activation: + self.bias = nn.Parameter(torch.zeros(out_channels)) + else: + self.bias = None + + if self.use_activation: + self.act_fn = FusedLeakyReLU(bias_channels=out_channels) + else: + self.act_fn = None + + def forward(self, x: torch.Tensor, channel_dim: int = 1) -> torch.Tensor: + # Apply blur if using + if self.blur: + # NOTE: the original implementation uses a 2D upfirdn operation with the upsampling and downsampling rates + # set to 1, which should be equivalent to a 2D convolution + expanded_kernel = self.blur_kernel[None, None, :, :].expand(self.in_channels, 1, -1, -1) + x = F.conv2d(x, expanded_kernel.to(x.dtype), padding=self.blur_padding, groups=self.in_channels) + + # Main Conv2D with scaling + x = x.to(self.weight.dtype) + x = F.conv2d(x, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding) + + # Activation with fused bias, if using + if self.use_activation: + x = self.act_fn(x, channel_dim=channel_dim) + return x + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," + f" kernel_size={self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" + ) + + +class MotionLinear(nn.Module): + def __init__( + self, + in_dim: int, + out_dim: int, + bias: bool = True, + use_activation: bool = False, + ): + super().__init__() + self.use_activation = use_activation + + # Linear weight with scale factor + self.weight = nn.Parameter(torch.randn(out_dim, in_dim)) + self.scale = 1 / math.sqrt(in_dim) + + # If an activation is present, the bias will be fused to it + if bias and not self.use_activation: + self.bias = nn.Parameter(torch.zeros(out_dim)) + else: + self.bias = None + + if self.use_activation: + self.act_fn = FusedLeakyReLU(bias_channels=out_dim) + else: + self.act_fn = None + + def forward(self, input: torch.Tensor, channel_dim: int = 1) -> torch.Tensor: + out = F.linear(input, self.weight * self.scale, bias=self.bias) + if self.use_activation: + out = self.act_fn(out, channel_dim=channel_dim) + return out + + def __repr__(self): + return ( + f"{self.__class__.__name__}(in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}," + f" bias={self.bias is not None})" + ) + + +class MotionEncoderResBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + kernel_size_skip: int = 1, + blur_kernel: tuple[int, ...] = (1, 3, 3, 1), + downsample_factor: int = 2, + ): + super().__init__() + self.downsample_factor = downsample_factor + + # 3 x 3 Conv + fused leaky ReLU + self.conv1 = MotionConv2d( + in_channels, + in_channels, + kernel_size, + stride=1, + padding=kernel_size // 2, + use_activation=True, + ) + + # 3 x 3 Conv that downsamples 2x + fused leaky ReLU + self.conv2 = MotionConv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=self.downsample_factor, + padding=0, + blur_kernel=blur_kernel, + use_activation=True, + ) + + # 1 x 1 Conv that downsamples 2x in skip connection + self.conv_skip = MotionConv2d( + in_channels, + out_channels, + kernel_size=kernel_size_skip, + stride=self.downsample_factor, + padding=0, + bias=False, + blur_kernel=blur_kernel, + use_activation=False, + ) + + def forward(self, x: torch.Tensor, channel_dim: int = 1) -> torch.Tensor: + x_out = self.conv1(x, channel_dim) + x_out = self.conv2(x_out, channel_dim) + + x_skip = self.conv_skip(x, channel_dim) + + x_out = (x_out + x_skip) / math.sqrt(2) + return x_out + + +class WanAnimateMotionEncoder(nn.Module): + def __init__( + self, + size: int = 512, + style_dim: int = 512, + motion_dim: int = 20, + out_dim: int = 512, + motion_blocks: int = 5, + channels: dict[str, int] | None = None, + ): + super().__init__() + self.size = size + + # Appearance encoder: conv layers + if channels is None: + channels = WAN_ANIMATE_MOTION_ENCODER_CHANNEL_SIZES + + self.conv_in = MotionConv2d(3, channels[str(size)], 1, use_activation=True) + + self.res_blocks = nn.ModuleList() + in_channels = channels[str(size)] + log_size = int(math.log(size, 2)) + for i in range(log_size, 2, -1): + out_channels = channels[str(2 ** (i - 1))] + self.res_blocks.append(MotionEncoderResBlock(in_channels, out_channels)) + in_channels = out_channels + + self.conv_out = MotionConv2d(in_channels, style_dim, 4, padding=0, bias=False, use_activation=False) + + # Motion encoder: linear layers + # NOTE: there are no activations in between the linear layers here, which is weird but I believe matches the + # original code. + linears = [MotionLinear(style_dim, style_dim) for _ in range(motion_blocks - 1)] + linears.append(MotionLinear(style_dim, motion_dim)) + self.motion_network = nn.ModuleList(linears) + + self.motion_synthesis_weight = nn.Parameter(torch.randn(out_dim, motion_dim)) + + def forward(self, face_image: torch.Tensor, channel_dim: int = 1) -> torch.Tensor: + if (face_image.shape[-2] != self.size) or (face_image.shape[-1] != self.size): + raise ValueError( + f"Face pixel values has resolution ({face_image.shape[-1]}, {face_image.shape[-2]}) but is expected" + f" to have resolution ({self.size}, {self.size})" + ) + + # Appearance encoding through convs + face_image = self.conv_in(face_image, channel_dim) + for block in self.res_blocks: + face_image = block(face_image, channel_dim) + face_image = self.conv_out(face_image, channel_dim) + motion_feat = face_image.squeeze(-1).squeeze(-1) + + # Motion feature extraction + for linear_layer in self.motion_network: + motion_feat = linear_layer(motion_feat, channel_dim=channel_dim) + + # Motion synthesis via Linear Motion Decomposition + weight = self.motion_synthesis_weight + 1e-8 + # Upcast the QR orthogonalization operation to FP32 + original_motion_dtype = motion_feat.dtype + motion_feat = motion_feat.to(torch.float32) + weight = weight.to(torch.float32) + + Q = torch.linalg.qr(weight)[0].to(device=motion_feat.device) + + motion_feat_diag = torch.diag_embed(motion_feat) # Alpha, diagonal matrix + motion_decomposition = torch.matmul(motion_feat_diag, Q.T) + motion_vec = torch.sum(motion_decomposition, dim=1) + + motion_vec = motion_vec.to(dtype=original_motion_dtype) + + return motion_vec + + +class WanAnimateFaceEncoder(nn.Module): + def __init__( + self, + in_dim: int, + out_dim: int, + hidden_dim: int = 1024, + num_heads: int = 4, + kernel_size: int = 3, + eps: float = 1e-6, + pad_mode: str = "replicate", + ): + super().__init__() + self.num_heads = num_heads + self.time_causal_padding = (kernel_size - 1, 0) + self.pad_mode = pad_mode + + self.act = nn.SiLU() + + self.conv1_local = nn.Conv1d(in_dim, hidden_dim * num_heads, kernel_size=kernel_size, stride=1) + self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size, stride=2) + self.conv3 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size, stride=2) + + self.norm1 = nn.LayerNorm(hidden_dim, eps, elementwise_affine=False) + self.norm2 = nn.LayerNorm(hidden_dim, eps, elementwise_affine=False) + self.norm3 = nn.LayerNorm(hidden_dim, eps, elementwise_affine=False) + + self.out_proj = nn.Linear(hidden_dim, out_dim) + + self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, out_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + batch_size = x.shape[0] + + # Reshape to channels-first to apply causal Conv1d over frame dim + x = x.permute(0, 2, 1) + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + x = self.conv1_local(x) # [B, C, T_padded] --> [B, N * C, T] + x = x.unflatten(1, (self.num_heads, -1)).flatten(0, 1) # [B, N * C, T] --> [B * N, C, T] + # Reshape back to channels-last to apply LayerNorm over channel dim + x = x.permute(0, 2, 1) + x = self.norm1(x) + x = self.act(x) + + x = x.permute(0, 2, 1) + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + x = self.conv2(x) + x = x.permute(0, 2, 1) + x = self.norm2(x) + x = self.act(x) + + x = x.permute(0, 2, 1) + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + x = self.conv3(x) + x = x.permute(0, 2, 1) + x = self.norm3(x) + x = self.act(x) + + x = self.out_proj(x) + x = x.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) # [B * N, T, C_out] --> [B, T, N, C_out] + + padding = self.padding_tokens.repeat(batch_size, x.shape[1], 1, 1).to(device=x.device) + x = torch.cat([x, padding], dim=-2) # [B, T, N, C_out] --> [B, T, N + 1, C_out] + + return x + + +class WanAnimateFaceBlockCrossAttention(nn.Module): + """ + Temporally-aligned cross attention with the face motion signal in the Wan Animate Face Blocks. + """ + + def __init__( + self, + dim: int, + heads: int = 8, + dim_head: int = 64, + eps: float = 1e-6, + cross_attention_dim_head: int | None = None, + ): + super().__init__() + self.inner_dim = dim_head * heads + self.heads = heads + self.cross_attention_dim_head = cross_attention_dim_head + self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads + + # 1. Pre-Attention Norms for the hidden_states (video latents) and encoder_hidden_states (motion vector). + # NOTE: this is not used in "vanilla" WanAttention + self.pre_norm_q = nn.LayerNorm(dim, eps, elementwise_affine=False) + self.pre_norm_kv = nn.LayerNorm(dim, eps, elementwise_affine=False) + + # 2. QKV and Output Projections + self.to_q = nn.Linear(dim, self.inner_dim, bias=True) + self.to_k = nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_v = nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_out = nn.Linear(self.inner_dim, dim, bias=True) + + # 3. QK Norm + # NOTE: this is applied after the reshape, so only over dim_head rather than dim_head * heads + self.norm_q = nn.RMSNorm(dim_head, eps=eps, elementwise_affine=True) + self.norm_k = nn.RMSNorm(dim_head, eps=eps, elementwise_affine=True) + + # 4. Attention + forward_context = get_forward_context() + self.usp_attn = USPAttention( + num_heads=heads, + head_size=dim_head, + attn_type=forward_context.attn_type, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + # encoder_hidden_states corresponds to the motion vec + # attention_mask corresponds to the motion mask (if any) + hidden_states = self.pre_norm_q(hidden_states) + encoder_hidden_states = self.pre_norm_kv(encoder_hidden_states) + + # B --> batch_size, T --> reduced inference segment len, N --> face_encoder_num_heads + 1, C --> dim + B, T, N, C = encoder_hidden_states.shape + + query = self.to_q(hidden_states) + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + # [B, S, H * D] --> [B, S, H, D] + query = query.unflatten(2, (self.heads, -1)) + # [B, T, N, H * D_kv] --> [B, T, N, H, D_kv] + key = key.view(B, T, N, self.heads, -1) + value = value.view(B, T, N, self.heads, -1) + + query = self.norm_q(query) + key = self.norm_k(key) + + # [B, S, H, D] --> [B * T, S / T, H, D] + query = query.unflatten(1, (T, -1)).flatten(0, 1) + # [B, T, N, H, D_kv] --> [B * T, N, H, D_kv] + key = key.flatten(0, 1) + value = value.flatten(0, 1) + + hidden_states = self.usp_attn(query, key, value) + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) + hidden_states = hidden_states.unflatten(0, (B, T)).flatten(1, 2) + + hidden_states = self.to_out(hidden_states) + + if attention_mask is not None: + attention_mask = attention_mask.flatten(start_dim=1) + hidden_states = hidden_states * attention_mask + + return hidden_states + + +class WanAnimateTransformer3DModel(DiffusionModel): + r""" + A Transformer model for video-like data used in the WanAnimate model. + + Args: + patch_size (`tuple[int]`, defaults to `(1, 2, 2)`): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch). + num_attention_heads (`int`, defaults to `40`): + Fixed length for text embeddings. + attention_head_dim (`int`, defaults to `128`): + The number of channels in each head. + in_channels (`int`, defaults to `36`): + The number of channels in the input. + latent_channels (`int`, defaults to `16`): + The number of latent channels. + out_channels (`int`, defaults to `16`): + The number of channels in the output. + text_dim (`int`, defaults to `4096`): + Input dimension for text embeddings. + freq_dim (`int`, defaults to `256`): + Dimension for sinusoidal time embeddings. + ffn_dim (`int`, defaults to `13824`): + Intermediate dimension in feed-forward network. + num_layers (`int`, defaults to `40`): + The number of layers of transformer blocks to use. + cross_attn_norm (`bool`, defaults to `True`): + Enable cross-attention normalization. + qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`): + Query/key normalization type. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + image_dim (`int`, *optional*, defaults to `1280`): + The number of channels to use for the image embedding. If `None`, no projection is used. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + rope_max_seq_len (`int`, defaults to `1024`): + Maximum sequence length for rotary position embeddings. + pos_embed_seq_len (`int`, *optional*, defaults to `None`): + Positional embedding sequence length. + motion_encoder_channel_sizes (`dict[str, int]`, *optional*): + Channel sizes used by the motion encoder. + motion_encoder_size (`int`, defaults to `512`): + Input resolution used by the motion encoder. + motion_style_dim (`int`, defaults to `512`): + Motion style dimension. + motion_dim (`int`, defaults to `20`): + Motion vector dimension. + motion_encoder_dim (`int`, defaults to `512`): + Output dimension of the motion encoder. + face_encoder_hidden_dim (`int`, defaults to `1024`): + Hidden dimension of the face encoder. + face_encoder_num_heads (`int`, defaults to `4`): + Number of attention heads in the face encoder. + inject_face_latents_blocks (`int`, defaults to `5`): + Interval of transformer blocks at which face latents are injected. + motion_encoder_batch_size (`int`, defaults to `8`): + Default batch size for motion encoder inference. + """ + + _keep_in_fp32_modules = [ + "time_embedder", + "scale_shift_table", + "norm1", + "norm2", + "norm3", + "motion_synthesis_weight", + ] + _keys_to_ignore_on_load_unexpected = ["norm_added_q"] + + @register_to_config + def __init__( + self, + patch_size: tuple[int] = (1, 2, 2), + num_attention_heads: int = 40, + attention_head_dim: int = 128, + in_channels: int | None = 36, + latent_channels: int | None = 16, + out_channels: int | None = 16, + text_dim: int = 4096, + freq_dim: int = 256, + ffn_dim: int = 13824, + num_layers: int = 40, + cross_attn_norm: bool = True, + qk_norm: str | None = "rms_norm_across_heads", + eps: float = 1e-6, + image_dim: int | None = 1280, + added_kv_proj_dim: int | None = None, + rope_max_seq_len: int = 1024, + pos_embed_seq_len: int | None = None, + motion_encoder_channel_sizes: dict[str, int] | None = None, + motion_encoder_size: int = 512, + motion_style_dim: int = 512, + motion_dim: int = 20, + motion_encoder_dim: int = 512, + face_encoder_hidden_dim: int = 1024, + face_encoder_num_heads: int = 4, + inject_face_latents_blocks: int = 5, + motion_encoder_batch_size: int = 8, + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + # Allow either only in_channels or only latent_channels to be set for convenience + if in_channels is None and latent_channels is not None: + in_channels = 2 * latent_channels + 4 + elif in_channels is not None and latent_channels is None: + latent_channels = (in_channels - 4) // 2 + elif in_channels is not None and latent_channels is not None: + assert in_channels == 2 * latent_channels + 4, "in_channels should be 2 * latent_channels + 4" + else: + raise ValueError("At least one of `in_channels` and `latent_channels` must be supplied.") + out_channels = out_channels or latent_channels + + # 1. Patch & position embedding + self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) + self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + self.pose_patch_embedding = nn.Conv3d(latent_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + + # 2. Condition embeddings + self.condition_embedder = WanTimeTextImageEmbedding( + dim=inner_dim, + time_freq_dim=freq_dim, + time_proj_dim=inner_dim * 6, + text_embed_dim=text_dim, + image_embed_dim=image_dim, + pos_embed_seq_len=pos_embed_seq_len, + ) + + # 3. Motion encoder + self.motion_encoder = WanAnimateMotionEncoder( + size=motion_encoder_size, + style_dim=motion_style_dim, + motion_dim=motion_dim, + out_dim=motion_encoder_dim, + channels=motion_encoder_channel_sizes, + ) + + # 4. Face encoder + self.face_encoder = WanAnimateFaceEncoder( + in_dim=motion_encoder_dim, + out_dim=inner_dim, + hidden_dim=face_encoder_hidden_dim, + num_heads=face_encoder_num_heads, + ) + + # 5. Transformer blocks + self.blocks = nn.ModuleList( + [ + WanTransformerBlock( + dim=inner_dim, + ffn_dim=ffn_dim, + num_heads=num_attention_heads, + qk_norm=qk_norm, + cross_attn_norm=cross_attn_norm, + eps=eps, + added_kv_proj_dim=added_kv_proj_dim, + ) + for _ in range(num_layers) + ] + ) + + # 6. Face adapter + self.face_adapter = nn.ModuleList( + [ + WanAnimateFaceBlockCrossAttention( + dim=inner_dim, + heads=num_attention_heads, + dim_head=inner_dim // num_attention_heads, + eps=eps, + cross_attention_dim_head=inner_dim // num_attention_heads, + ) + for _ in range(num_layers // inject_face_latents_blocks) + ] + ) + + # 7. Output norm & projection + self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) + self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: torch.Tensor | None = None, + pose_hidden_states: torch.Tensor | None = None, + face_pixel_values: torch.Tensor | None = None, + motion_encode_batch_size: int | None = None, + return_dict: bool = True, + ) -> torch.Tensor | Transformer2DModelOutput: + """ + Forward pass of Wan2.2-Animate transformer model. + + Args: + hidden_states (`torch.Tensor` of shape `(B, 2C + 4, T + 1, H, W)`): + Input noisy video latents of shape `(B, 2C + 4, T + 1, H, W)`, where B is the batch size, C is the + number of latent channels (16 for Wan VAE), T is the number of latent frames in an inference segment, H + is the latent height, and W is the latent width. + timestep (`torch.LongTensor`): + The current timestep in the denoising loop. + encoder_hidden_states (`torch.Tensor`): + Text embeddings from the text encoder (umT5 for Wan Animate). + encoder_hidden_states_image (`torch.Tensor`): + CLIP visual features of the reference (character) image. + pose_hidden_states (`torch.Tensor` of shape `(B, C, T, H, W)`): + Pose video latents. + face_pixel_values (`torch.Tensor` of shape `(B, C', S, H', W')`): + Face video in pixel space (not latent space). Typically C' = 3 and H' and W' are the height/width of + the face video in pixels. Here S is the inference segment length, usually set to 77. + motion_encode_batch_size (`int`, *optional*): + The batch size for batched encoding of the face video via the motion encoder. Will default to + `self.config.motion_encoder_batch_size` if not set. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return the output as a dict or tuple. + Returns: + [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: + If `return_dict` is True, a [`~models.transformer_2d.Transformer2DModelOutput`] whose `sample` is the + denoised video latent is returned, otherwise a plain `tuple` whose first element is that tensor is + returned. + """ + + # Check that shapes match up + if pose_hidden_states is not None and pose_hidden_states.shape[2] + 1 != hidden_states.shape[2]: + raise ValueError( + f"pose_hidden_states frame dim (dim 2) is {pose_hidden_states.shape[2]} but must be one less than the" + f" hidden_states's corresponding frame dim: {hidden_states.shape[2]}" + ) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + # 1. Rotary position embedding + rotary_emb = self.rope(hidden_states) + + # 2. Patch embedding + hidden_states = self.patch_embedding(hidden_states) + pose_hidden_states = self.pose_patch_embedding(pose_hidden_states) + # Add pose embeddings to hidden states + hidden_states[:, :, 1:] = hidden_states[:, :, 1:] + pose_hidden_states + # Calling contiguous() here is important so that we don't recompile when performing regional compilation + hidden_states = hidden_states.flatten(2).transpose(1, 2).contiguous() + + original_seq_len = hidden_states.shape[1] + + # 3. Condition embeddings (time, text, image) + # Wan Animate is based on Wan 2.1 and thus uses Wan 2.1's timestep logic + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=None + ) + + # batch_size, 6, inner_dim + timestep_proj = timestep_proj.unflatten(1, (6, -1)) + + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + + # 4. Get motion features from the face video + # Motion vector computation from face pixel values + batch_size, channels, num_face_frames, height, width = face_pixel_values.shape + # Rearrange from (B, C, T, H, W) to (B*T, C, H, W) + face_pixel_values = face_pixel_values.permute(0, 2, 1, 3, 4).reshape(-1, channels, height, width) + + # Extract motion features using motion encoder + # Perform batched motion encoder inference to allow trading off inference speed for memory usage + motion_encode_batch_size = motion_encode_batch_size or self.config.motion_encoder_batch_size + face_batches = torch.split(face_pixel_values, motion_encode_batch_size) + motion_vec_batches = [] + for face_batch in face_batches: + motion_vec_batch = self.motion_encoder(face_batch) + motion_vec_batches.append(motion_vec_batch) + motion_vec = torch.cat(motion_vec_batches) + motion_vec = motion_vec.view(batch_size, num_face_frames, -1) + + # Now get face features from the motion vector + motion_vec = self.face_encoder(motion_vec) + + # Add padding at the beginning (prepend zeros) + pad_face = torch.zeros_like(motion_vec[:, :1]) + motion_vec = torch.cat([pad_face, motion_vec], dim=1) + + rotary_emb_cos, rotary_emb_sin = rotary_emb + hidden_states, rotary_emb_cos, rotary_emb_sin = sequence_parallel_shard( + [hidden_states, rotary_emb_cos, rotary_emb_sin], + seq_dims=[1, 1, 1], + ) + rotary_emb = (rotary_emb_cos, rotary_emb_sin) + + # 5. Transformer blocks with face adapter integration + for block_idx, block in enumerate(self.blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb + ) + else: + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + + # Face adapter integration: apply after every 5th block (0, 5, 10, 15, ...) + if block_idx % self.config.inject_face_latents_blocks == 0: + face_adapter_block_idx = block_idx // self.config.inject_face_latents_blocks + face_adapter_output = self.face_adapter[face_adapter_block_idx](hidden_states, motion_vec) + face_adapter_output = face_adapter_output.to(device=hidden_states.device) + hidden_states = face_adapter_output + hidden_states + + (hidden_states,) = sequence_parallel_unshard([hidden_states], seq_dims=[1], seq_lens=[original_seq_len]) + + # 6. Output norm, projection & unpatchify + # batch_size, inner_dim + shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1) + + hidden_states_original_dtype = hidden_states.dtype + hidden_states = self.norm_out(hidden_states.float()) + shift = shift.to(hidden_states.device) + scale = scale.to(hidden_states.device) + hidden_states = (hidden_states * (1 + scale) + shift).to(dtype=hidden_states_original_dtype) + + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/diffsynth_engine/models/wan/transformer_wan_vace.py b/diffsynth_engine/models/wan/transformer_wan_vace.py new file mode 100644 index 0000000..8b6fc5e --- /dev/null +++ b/diffsynth_engine/models/wan/transformer_wan_vace.py @@ -0,0 +1,384 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_wan_vace.py + +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +import torch.nn as nn +from diffusers.configuration_utils import register_to_config +from diffusers.models.attention import FeedForward +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.normalization import FP32LayerNorm + +from diffsynth_engine.distributed.utils import sequence_parallel_shard, sequence_parallel_unshard +from diffsynth_engine.models.base import DiffusionModel +from diffsynth_engine.models.wan.transformer_wan import ( + WanAttention, + WanRotaryPosEmbed, + WanTimeTextImageEmbedding, + WanTransformerBlock, +) +from diffsynth_engine.utils import logging + +logger = logging.get_logger(__name__) + + +class WanVACETransformerBlock(nn.Module): + def __init__( + self, + dim: int, + ffn_dim: int, + num_heads: int, + qk_norm: str = "rms_norm_across_heads", + cross_attn_norm: bool = False, + eps: float = 1e-6, + added_kv_proj_dim: int | None = None, + apply_input_projection: bool = False, + apply_output_projection: bool = False, + ): + super().__init__() + + # 1. Input projection + self.proj_in = None + if apply_input_projection: + self.proj_in = nn.Linear(dim, dim) + + # 2. Self-attention + self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.attn1 = WanAttention( + dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + eps=eps, + cross_attention_dim_head=None, + ) + + # 3. Cross-attention + self.attn2 = WanAttention( + dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + eps=eps, + added_kv_proj_dim=added_kv_proj_dim, + cross_attention_dim_head=dim // num_heads, + ) + self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + + # 4. Feed-forward + self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate") + self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False) + + # 5. Output projection + self.proj_out = None + if apply_output_projection: + self.proj_out = nn.Linear(dim, dim) + + self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + control_hidden_states: torch.Tensor, + temb: torch.Tensor, + rotary_emb: torch.Tensor, + ) -> torch.Tensor: + if self.proj_in is not None: + control_hidden_states = self.proj_in(control_hidden_states) + control_hidden_states = control_hidden_states + hidden_states + + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table.to(temb.device) + temb.float() + ).chunk(6, dim=1) + + # 1. Self-attention + norm_hidden_states = (self.norm1(control_hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as( + control_hidden_states + ) + attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb) + control_hidden_states = (control_hidden_states.float() + attn_output * gate_msa).type_as(control_hidden_states) + + # 2. Cross-attention + norm_hidden_states = self.norm2(control_hidden_states.float()).type_as(control_hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None) + control_hidden_states = control_hidden_states + attn_output + + # 3. Feed-forward + norm_hidden_states = (self.norm3(control_hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as( + control_hidden_states + ) + ff_output = self.ffn(norm_hidden_states) + control_hidden_states = (control_hidden_states.float() + ff_output.float() * c_gate_msa).type_as( + control_hidden_states + ) + + conditioning_states = None + if self.proj_out is not None: + conditioning_states = self.proj_out(control_hidden_states) + + return conditioning_states, control_hidden_states + + +class WanVACETransformer3DModel(DiffusionModel): + r""" + A Transformer model for video-like data used in the Wan VACE model. + + This model extends the base Wan Transformer with a VACE control branch that injects + conditioning signals at specified layers for controllable video generation. + + Args: + patch_size (`tuple[int]`, defaults to `(1, 2, 2)`): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch). + num_attention_heads (`int`, defaults to `40`): + The number of attention heads. + attention_head_dim (`int`, defaults to `128`): + The number of channels in each head. + in_channels (`int`, defaults to `16`): + The number of channels in the input. + out_channels (`int`, defaults to `16`): + The number of channels in the output. + text_dim (`int`, defaults to `4096`): + Input dimension for text embeddings. + freq_dim (`int`, defaults to `256`): + Dimension for sinusoidal time embeddings. + ffn_dim (`int`, defaults to `13824`): + Intermediate dimension in feed-forward network. + num_layers (`int`, defaults to `40`): + The number of layers of transformer blocks to use. + cross_attn_norm (`bool`, defaults to `True`): + Enable cross-attention normalization. + qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`): + Query/key normalization type. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + image_dim (`int`, *optional*, defaults to `None`): + Dimension for image embeddings (used in I2V models). + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels for added key and value projections. + rope_max_seq_len (`int`, defaults to `1024`): + Maximum sequence length for rotary position embeddings. + pos_embed_seq_len (`int`, *optional*, defaults to `None`): + Positional embedding sequence length. + vace_layers (`list[int]`, defaults to `[0, 5, 10, 15, 20, 25, 30, 35]`): + Layer indices where VACE control signals are injected. + vace_in_channels (`int`, defaults to `96`): + Number of input channels for the VACE patch embedding. + """ + + _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] + _keys_to_ignore_on_load_unexpected = ["norm_added_q"] + + @register_to_config + def __init__( + self, + patch_size: tuple[int, ...] = (1, 2, 2), + num_attention_heads: int = 40, + attention_head_dim: int = 128, + in_channels: int = 16, + out_channels: int = 16, + text_dim: int = 4096, + freq_dim: int = 256, + ffn_dim: int = 13824, + num_layers: int = 40, + cross_attn_norm: bool = True, + qk_norm: str | None = "rms_norm_across_heads", + eps: float = 1e-6, + image_dim: int | None = None, + added_kv_proj_dim: int | None = None, + rope_max_seq_len: int = 1024, + pos_embed_seq_len: int | None = None, + vace_layers: list[int] = [0, 5, 10, 15, 20, 25, 30, 35], + vace_in_channels: int = 96, + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + if max(vace_layers) >= num_layers: + raise ValueError(f"VACE layers {vace_layers} exceed the number of transformer layers {num_layers}.") + if 0 not in vace_layers: + raise ValueError("VACE layers must include layer 0.") + + # 1. Patch & position embedding + self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) + self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + self.vace_patch_embedding = nn.Conv3d(vace_in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + + # 2. Condition embeddings + # image_embedding_dim=1280 for I2V model + self.condition_embedder = WanTimeTextImageEmbedding( + dim=inner_dim, + time_freq_dim=freq_dim, + time_proj_dim=inner_dim * 6, + text_embed_dim=text_dim, + image_embed_dim=image_dim, + pos_embed_seq_len=pos_embed_seq_len, + ) + + # 3. Transformer blocks + self.blocks = nn.ModuleList( + [ + WanTransformerBlock( + inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim + ) + for _ in range(num_layers) + ] + ) + + # 4. VACE control blocks + self.vace_blocks = nn.ModuleList( + [ + WanVACETransformerBlock( + inner_dim, + ffn_dim, + num_attention_heads, + qk_norm, + cross_attn_norm, + eps, + added_kv_proj_dim, + apply_input_projection=i == 0, + apply_output_projection=True, + ) + for i in range(len(vace_layers)) + ] + ) + + # 5. Output norm & projection + self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) + self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: torch.Tensor | None = None, + control_hidden_states: torch.Tensor = None, + control_hidden_states_scale: torch.Tensor = None, + return_dict: bool = True, + ) -> torch.Tensor | Transformer2DModelOutput: + """ + The [`WanVACETransformer3DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`): + Input `hidden_states`. + timestep (`torch.LongTensor`): + Used to indicate denoising step. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + encoder_hidden_states_image (`torch.Tensor`, *optional*): + Conditional image embeddings for image-conditioned generation. + control_hidden_states (`torch.Tensor`, *optional*): + Control latents used by the VACE control branch. + control_hidden_states_scale (`torch.Tensor`, *optional*): + Per-VACE-layer scale applied to the control hidden states. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + if control_hidden_states_scale is None: + control_hidden_states_scale = control_hidden_states.new_ones(len(self.config.vace_layers)) + control_hidden_states_scale = torch.unbind(control_hidden_states_scale) + if len(control_hidden_states_scale) != len(self.config.vace_layers): + raise ValueError( + f"Length of `control_hidden_states_scale` {len(control_hidden_states_scale)} should be " + f"equal to {len(self.config.vace_layers)}." + ) + + # 1. Rotary position embedding + rotary_emb = self.rope(hidden_states) + + # 2. Patch embedding + hidden_states = self.patch_embedding(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + control_hidden_states = self.vace_patch_embedding(control_hidden_states) + control_hidden_states = control_hidden_states.flatten(2).transpose(1, 2) + control_hidden_states_padding = control_hidden_states.new_zeros( + batch_size, hidden_states.size(1) - control_hidden_states.size(1), control_hidden_states.size(2) + ) + control_hidden_states = torch.cat([control_hidden_states, control_hidden_states_padding], dim=1) + + original_seq_len = hidden_states.shape[1] + + # 3. Time embedding + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, encoder_hidden_states, encoder_hidden_states_image + ) + timestep_proj = timestep_proj.unflatten(1, (6, -1)) + + # 4. Image embedding + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + + rotary_emb_cos, rotary_emb_sin = rotary_emb + hidden_states, control_hidden_states, rotary_emb_cos, rotary_emb_sin = sequence_parallel_shard( + [hidden_states, control_hidden_states, rotary_emb_cos, rotary_emb_sin], + seq_dims=[1, 1, 1, 1], + ) + rotary_emb = (rotary_emb_cos, rotary_emb_sin) + + # 5. VACE control blocks + control_hidden_states_list = [] + for i, block in enumerate(self.vace_blocks): + conditioning_states, control_hidden_states = block( + hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb + ) + control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i])) + control_hidden_states_list = control_hidden_states_list[::-1] + + # 6. Transformer blocks + for i, block in enumerate(self.blocks): + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + if i in self.config.vace_layers: + control_hint, scale = control_hidden_states_list.pop() + hidden_states = hidden_states + control_hint * scale + + (hidden_states,) = sequence_parallel_unshard([hidden_states], seq_dims=[1], seq_lens=[original_seq_len]) + + # 7. Output norm, projection & unpatchify + shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1) + + shift = shift.to(hidden_states.device) + scale = scale.to(hidden_states.device) + + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/diffsynth_engine/pipelines/base.py b/diffsynth_engine/pipelines/base.py index 55d48f2..10f026e 100644 --- a/diffsynth_engine/pipelines/base.py +++ b/diffsynth_engine/pipelines/base.py @@ -29,14 +29,19 @@ def __call__(self, *args, **kwargs): raise NotImplementedError() @staticmethod - def init_transformer(model_cls: Type[nn.Module], pipeline_config: PipelineConfig, empty_weights: bool = False): + def init_transformer( + model_cls: Type[nn.Module], + pipeline_config: PipelineConfig, + empty_weights: bool = False, + subfolder: str = "transformer", + ): use_fsdp = pipeline_config.use_fsdp and is_world_group_initialized() with set_forward_context(attn_type=pipeline_config.attn_type): with init_empty_weights(): config = model_cls.load_config( pipeline_config.model_path, - subfolder="transformer", + subfolder=subfolder, local_files_only=True, ) model = model_cls.from_config(config) @@ -51,7 +56,7 @@ def init_transformer(model_cls: Type[nn.Module], pipeline_config: PipelineConfig state_dict = load_model_weights( pipeline_config.model_path, - subfolder="transformer", + subfolder=subfolder, device="cpu" if use_fsdp else pipeline_config.device, dtype=pipeline_config.model_dtype, broadcast_from_rank0=not use_fsdp, @@ -76,6 +81,7 @@ def init_text_encoder( pipeline_config: PipelineConfig, key_mapping: dict = None, empty_weights: bool = False, + strict: bool = True, ): use_fsdp = pipeline_config.use_fsdp and is_world_group_initialized() @@ -91,7 +97,13 @@ def init_text_encoder( return model if use_fsdp: - for layer in model.model.language_model.layers: + if hasattr(model, "model") and hasattr(model.model, "language_model"): + layers = model.model.language_model.layers + elif hasattr(model, "encoder") and hasattr(model.encoder, "block"): + layers = model.encoder.block + else: + raise ValueError(f"Unsupported text encoder model for FSDP: {type(model).__name__}") + for layer in layers: fully_shard(layer) fully_shard(model) @@ -110,10 +122,14 @@ def init_text_encoder( set_model_state_dict( model, state_dict, - options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True), + options=StateDictOptions( + full_state_dict=True, + broadcast_from_rank0=True, + strict=strict, + ), ) else: - model.load_state_dict(state_dict, strict=True, assign=True) + model.load_state_dict(state_dict, strict=strict, assign=True) model.to(device=pipeline_config.device) del state_dict @@ -167,5 +183,3 @@ def progress_bar(self, iterable=None, total=None): def set_progress_bar_config(self, **kwargs): self._progress_bar_config = kwargs - - # TODO: preprocess & postprocess & LoRA diff --git a/diffsynth_engine/pipelines/wan/__init__.py b/diffsynth_engine/pipelines/wan/__init__.py new file mode 100644 index 0000000..8186b95 --- /dev/null +++ b/diffsynth_engine/pipelines/wan/__init__.py @@ -0,0 +1,11 @@ +from .pipeline_wan_animate import WanAnimatePipeline +from .pipeline_wan_i2v import WanImageToVideoPipeline +from .pipeline_wan_t2v import WanTextToVideoPipeline +from .pipeline_wan_vace import WanVACEPipeline + +__all__ = [ + "WanTextToVideoPipeline", + "WanImageToVideoPipeline", + "WanAnimatePipeline", + "WanVACEPipeline", +] diff --git a/diffsynth_engine/pipelines/wan/pipeline_wan_animate.py b/diffsynth_engine/pipelines/wan/pipeline_wan_animate.py new file mode 100644 index 0000000..29e4687 --- /dev/null +++ b/diffsynth_engine/pipelines/wan/pipeline_wan_animate.py @@ -0,0 +1,1274 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan_animate.py + +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import html +import os +from copy import deepcopy +from typing import TYPE_CHECKING, Any, Callable + +import PIL +import regex as re +import torch +import torch.nn.functional as F +from accelerate import init_empty_weights +from diffusers.pipelines.wan.image_processor import WanAnimateImageProcessor +from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput +from diffusers.schedulers import UniPCMultistepScheduler +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel + +from diffsynth_engine.configs.wan import WanPipelineConfig +from diffsynth_engine.distributed.parallel_state import ( + get_cfg_group, + is_cfg_group_initialized, +) +from diffsynth_engine.forward_context import set_forward_context +from diffsynth_engine.models.wan import AutoencoderKLWan, WanAnimateTransformer3DModel +from diffsynth_engine.pipelines.base import Pipeline +from diffsynth_engine.registry import get_attn_backend +from diffsynth_engine.utils import logging + +logger = logging.get_logger(__name__) + +if TYPE_CHECKING: + from diffusers.image_processor import PipelineImageInput + + +def basic_clean(text): + try: + import ftfy + + text = ftfy.fix_text(text) + except ImportError: + pass + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class WanAnimatePipeline(Pipeline): + r""" + Pipeline for unified character animation and replacement using Wan-Animate. + + WanAnimatePipeline takes a character image, pose video, and face video as input, and generates a video in two + modes: + + 1. **Animation mode**: The model generates a video of the character image that mimics the human motion in the input + pose and face videos. The character is animated based on the provided motion controls, creating a new animated + video of the character. + + 2. **Replacement mode**: The model replaces a character in a background video with the provided character image, + using the pose and face videos for motion control. This mode requires additional `background_video` and + `mask_video` inputs. The mask video should have black regions where the original content should be preserved and + white regions where the new character should be generated. + + Args: + pipeline_config (`WanPipelineConfig`): + Configuration for the pipeline. + tokenizer (`AutoTokenizer`): + Tokenizer from T5, specifically the google/umt5-xxl variant. + text_encoder (`UMT5EncoderModel`): + T5 text encoder, specifically the google/umt5-xxl variant. + image_encoder (`CLIPVisionModel`): + CLIP vision model for encoding input images. + image_processor (`CLIPImageProcessor`): + CLIP image processor for preprocessing input images. + vae (`AutoencoderKLWan`): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + scheduler (`UniPCMultistepScheduler`): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + transformer (`WanAnimateTransformer3DModel`): + Conditional Transformer to denoise the input latents. + """ + + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + pipeline_config: WanPipelineConfig, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + vae: AutoencoderKLWan, + scheduler: UniPCMultistepScheduler, + image_processor: CLIPImageProcessor, + image_encoder: CLIPVisionModel, + transformer: WanAnimateTransformer3DModel, + ): + super().__init__(pipeline_config) + + self.tokenizer = tokenizer + self.text_encoder = text_encoder + self.vae = vae + self.image_encoder = image_encoder + self.transformer = transformer + self.scheduler = scheduler + self.image_processor = image_processor + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if self.vae is not None else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if self.vae is not None else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.video_processor_for_mask = VideoProcessor( + vae_scale_factor=self.vae_scale_factor_spatial, do_normalize=False, do_convert_grayscale=True + ) + + spatial_patch_size = self.transformer.config.patch_size[-2:] if self.transformer is not None else (2, 2) + self.vae_image_processor = WanAnimateImageProcessor( + vae_scale_factor=self.vae_scale_factor_spatial, + spatial_patch_size=spatial_patch_size, + resample="bilinear", + fill_color=0, + ) + + head_dim = self.transformer.config.attention_head_dim + self.attn_backend = get_attn_backend(pipeline_config.attn_type) + if not self.attn_backend.supports_head_size(head_dim): + raise ValueError(f"Attention backend {pipeline_config.attn_type!r} does not support head size {head_dim}.") + + @classmethod + def from_pretrained(cls, model_path_or_config: str | WanPipelineConfig): + """ + Load a WanAnimatePipeline from a pretrained model path or config. + + Args: + model_path_or_config: Either a string path to the model directory or a WanPipelineConfig instance. + + Returns: + WanAnimatePipeline: The loaded pipeline. + """ + if isinstance(model_path_or_config, str): + pipeline_config = WanPipelineConfig(model_path=model_path_or_config) + else: + pipeline_config = model_path_or_config + + if not os.path.exists(pipeline_config.model_path): + raise FileNotFoundError(f"Model path not found: {pipeline_config.model_path}") + + # Load transformer + transformer = cls.init_transformer(WanAnimateTransformer3DModel, pipeline_config) + + # Load scheduler + scheduler = UniPCMultistepScheduler.from_pretrained( + pipeline_config.model_path, + subfolder="scheduler", + ) + + # Load VAE + vae = cls.init_vae(AutoencoderKLWan, pipeline_config) + + # Load text encoder + text_encoder = cls.init_text_encoder(UMT5EncoderModel, pipeline_config, strict=False) + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained( + pipeline_config.model_path, + subfolder="tokenizer", + ) + + # Load image encoder + image_encoder = cls.init_image_encoder(pipeline_config) + + # Load image processor + image_processor = None + image_processor_path = os.path.join(pipeline_config.model_path, "image_processor") + if os.path.isdir(image_processor_path): + image_processor = CLIPImageProcessor.from_pretrained( + pipeline_config.model_path, + subfolder="image_processor", + ) + logger.info("Loaded image_processor from `image_processor` subfolder.") + + return cls( + pipeline_config=pipeline_config, + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + image_encoder=image_encoder, + image_processor=image_processor, + transformer=transformer, + scheduler=scheduler, + ) + + @staticmethod + def init_image_encoder(pipeline_config: WanPipelineConfig, empty_weights: bool = False): + logger.info("Initializing image encoder...") + image_encoder_path = os.path.join(pipeline_config.model_path, "image_encoder") + if not os.path.isdir(image_encoder_path): + logger.warning(f"image_encoder not found in {pipeline_config.model_path}.") + return None + + if empty_weights: + with init_empty_weights(): + model = CLIPVisionModel.from_pretrained( + pipeline_config.model_path, + subfolder="image_encoder", + local_files_only=True, + ) + return model + + model = CLIPVisionModel.from_pretrained( + pipeline_config.model_path, + subfolder="image_encoder", + dtype=torch.float32, + ) + model.to(device=pipeline_config.device) + return model + + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self.device + dtype = dtype or self.pipeline_config.text_encoder_dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # Duplicate text embeddings for each generation per prompt + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_image(self, image: PipelineImageInput, device: torch.device | None = None): + device = device or self.device + image = self.image_processor(images=image, return_tensors="pt").to(device) + image_embeds = self.image_encoder(**image, output_hidden_states=True) + return image_embeds.hidden_states[-2] + + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. + max_sequence_length (`int`, *optional*, defaults to 226): + Maximum sequence length for the text encoder. + device (`torch.device`, *optional*): + torch device + dtype (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self.device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + image, + pose_video, + face_video, + background_video, + mask_video, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + image_embeds=None, + callback_on_step_end_tensor_inputs=None, + mode=None, + prev_segment_conditioning_frames=None, + ): + if image is not None and image_embeds is not None: + raise ValueError( + f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to" + " only forward one of the two." + ) + if image is None and image_embeds is None: + raise ValueError( + "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined." + ) + if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): + raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") + if pose_video is None: + raise ValueError("Provide `pose_video`. Cannot leave `pose_video` undefined.") + if face_video is None: + raise ValueError("Provide `face_video`. Cannot leave `face_video` undefined.") + if not isinstance(pose_video, list) or not isinstance(face_video, list): + raise ValueError("`pose_video` and `face_video` must be lists of PIL images.") + if len(pose_video) == 0 or len(face_video) == 0: + raise ValueError("`pose_video` and `face_video` must contain at least one frame.") + if mode == "replace" and (background_video is None or mask_video is None): + raise ValueError( + "Provide `background_video` and `mask_video`. Cannot leave both `background_video` and `mask_video`" + " undefined when mode is `replace`." + ) + if mode == "replace" and (not isinstance(background_video, list) or not isinstance(mask_video, list)): + raise ValueError("`background_video` and `mask_video` must be lists of PIL images when mode is `replace`.") + + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found" + f" {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if mode is not None and (not isinstance(mode, str) or mode not in ("animate", "replace")): + raise ValueError( + f"`mode` has to be of type `str` and in ('animate', 'replace') but its type is {type(mode)} and value is {mode}" + ) + + if prev_segment_conditioning_frames is not None and ( + not isinstance(prev_segment_conditioning_frames, int) or prev_segment_conditioning_frames not in (1, 5) + ): + raise ValueError( + f"`prev_segment_conditioning_frames` has to be of type `int` and 1 or 5 but its type is" + f" {type(prev_segment_conditioning_frames)} and value is {prev_segment_conditioning_frames}" + ) + + def get_i2v_mask( + self, + batch_size: int, + latent_t: int, + latent_h: int, + latent_w: int, + mask_len: int = 1, + mask_pixel_values: torch.Tensor | None = None, + dtype: torch.dtype | None = None, + device: str | torch.device = "cuda", + ) -> torch.Tensor: + # mask_pixel_values shape (if supplied): [B, C = 1, T, latent_h, latent_w] + if mask_pixel_values is None: + mask_lat_size = torch.zeros( + batch_size, 1, (latent_t - 1) * 4 + 1, latent_h, latent_w, dtype=dtype, device=device + ) + else: + mask_lat_size = mask_pixel_values.clone().to(device=device, dtype=dtype) + mask_lat_size[:, :, :mask_len] = 1 + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:]], dim=2) + mask_lat_size = mask_lat_size.view( + batch_size, -1, self.vae_scale_factor_temporal, latent_h, latent_w + ).transpose(1, 2) + + return mask_lat_size + + def prepare_reference_image_latents( + self, + image: torch.Tensor, + batch_size: int = 1, + sample_mode: str = "argmax", + generator: torch.Generator | list[torch.Generator] | None = None, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + ) -> torch.Tensor: + # image shape: (B, C, H, W) or (B, C, T, H, W) + dtype = dtype or self.pipeline_config.vae_dtype + if image.ndim == 4: + image = image.unsqueeze(2) + + _, _, _, height, width = image.shape + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + image = image.to(device=device, dtype=dtype) + if isinstance(generator, list): + ref_image_latents = [ + retrieve_latents(self.vae.encode(image), generator=g, sample_mode=sample_mode) for g in generator + ] + ref_image_latents = torch.cat(ref_image_latents) + else: + ref_image_latents = retrieve_latents(self.vae.encode(image), generator, sample_mode) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(ref_image_latents.device, ref_image_latents.dtype) + ) + latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + ref_image_latents.device, ref_image_latents.dtype + ) + ref_image_latents = (ref_image_latents - latents_mean) * latents_recip_std + + if ref_image_latents.shape[0] == 1 and batch_size > 1: + ref_image_latents = ref_image_latents.expand(batch_size, -1, -1, -1, -1) + + reference_image_mask = self.get_i2v_mask(batch_size, 1, latent_height, latent_width, 1, None, dtype, device) + reference_image_latents = torch.cat([reference_image_mask, ref_image_latents], dim=1) + + return reference_image_latents + + def prepare_prev_segment_cond_latents( + self, + prev_segment_cond_video: torch.Tensor | None = None, + background_video: torch.Tensor | None = None, + mask_video: torch.Tensor | None = None, + batch_size: int = 1, + segment_frame_length: int = 77, + start_frame: int = 0, + height: int = 720, + width: int = 1280, + prev_segment_cond_frames: int = 1, + task: str = "animate", + interpolation_mode: str = "bicubic", + sample_mode: str = "argmax", + generator: torch.Generator | list[torch.Generator] | None = None, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + ) -> torch.Tensor: + # prev_segment_cond_video shape: (B, C, T, H, W) in pixel space if supplied + # background_video shape: (B, C, T, H, W) (same as prev_segment_cond_video shape) + # mask_video shape: (B, 1, T, H, W) (same as prev_segment_cond_video, but with only 1 channel) + dtype = dtype or self.pipeline_config.vae_dtype + if prev_segment_cond_video is None: + if task == "replace": + prev_segment_cond_video = background_video[:, :, :prev_segment_cond_frames].to(dtype) + else: + cond_frames_shape = (batch_size, 3, prev_segment_cond_frames, height, width) + prev_segment_cond_video = torch.zeros(cond_frames_shape, dtype=dtype, device=device) + + data_batch_size, channels, _, segment_height, segment_width = prev_segment_cond_video.shape + num_latent_frames = (segment_frame_length - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + if segment_height != height or segment_width != width: + logger.info( + f"Interpolating prev segment cond video from ({segment_width}, {segment_height}) to ({width}, {height})" + ) + prev_segment_cond_video = prev_segment_cond_video.transpose(1, 2).flatten(0, 1) + prev_segment_cond_video = F.interpolate( + prev_segment_cond_video, size=(height, width), mode=interpolation_mode + ) + prev_segment_cond_video = prev_segment_cond_video.unflatten(0, (batch_size, -1)).transpose(1, 2) + + if task == "replace": + remaining_segment = background_video[:, :, prev_segment_cond_frames:].to(dtype) + else: + remaining_segment_frames = segment_frame_length - prev_segment_cond_frames + remaining_segment = torch.zeros( + batch_size, channels, remaining_segment_frames, height, width, dtype=dtype, device=device + ) + + prev_segment_cond_video = prev_segment_cond_video.to(dtype=dtype) + full_segment_cond_video = torch.cat([prev_segment_cond_video, remaining_segment], dim=2) + + if isinstance(generator, list): + if data_batch_size == len(generator): + prev_segment_cond_latents = [ + retrieve_latents(self.vae.encode(full_segment_cond_video[i].unsqueeze(0)), g, sample_mode) + for i, g in enumerate(generator) + ] + elif data_batch_size == 1: + prev_segment_cond_latents = [ + retrieve_latents(self.vae.encode(full_segment_cond_video), g, sample_mode) for g in generator + ] + else: + raise ValueError( + f"The batch size of the prev segment video should be either {len(generator)} or 1 but is" + f" {data_batch_size}" + ) + prev_segment_cond_latents = torch.cat(prev_segment_cond_latents) + else: + prev_segment_cond_latents = retrieve_latents( + self.vae.encode(full_segment_cond_video), generator, sample_mode + ) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(prev_segment_cond_latents.device, prev_segment_cond_latents.dtype) + ) + latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + prev_segment_cond_latents.device, prev_segment_cond_latents.dtype + ) + prev_segment_cond_latents = (prev_segment_cond_latents - latents_mean) * latents_recip_std + + if task == "replace": + mask_video = 1 - mask_video + mask_video = mask_video.permute(0, 2, 1, 3, 4) + mask_video = mask_video.flatten(0, 1) + mask_video = F.interpolate(mask_video, size=(latent_height, latent_width), mode="nearest") + mask_pixel_values = mask_video.unflatten(0, (batch_size, -1)) + mask_pixel_values = mask_pixel_values.permute(0, 2, 1, 3, 4) # output shape: [B, C = 1, T, H_lat, W_lat] + else: + mask_pixel_values = None + prev_segment_cond_mask = self.get_i2v_mask( + batch_size, + num_latent_frames, + latent_height, + latent_width, + mask_len=prev_segment_cond_frames if start_frame > 0 else 0, + mask_pixel_values=mask_pixel_values, + dtype=dtype, + device=device, + ) + + prev_segment_cond_latents = torch.cat([prev_segment_cond_mask, prev_segment_cond_latents], dim=1) + return prev_segment_cond_latents + + def prepare_pose_latents( + self, + pose_video: torch.Tensor, + batch_size: int = 1, + sample_mode: str = "argmax", + generator: torch.Generator | list[torch.Generator] | None = None, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + ) -> torch.Tensor: + # pose_video shape: (B, C, T, H, W) + dtype = dtype if dtype is not None else self.pipeline_config.vae_dtype + pose_video = pose_video.to(device=device, dtype=dtype) + if isinstance(generator, list): + pose_latents = [ + retrieve_latents(self.vae.encode(pose_video), generator=g, sample_mode=sample_mode) for g in generator + ] + pose_latents = torch.cat(pose_latents) + else: + pose_latents = retrieve_latents(self.vae.encode(pose_video), generator, sample_mode) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(pose_latents.device, pose_latents.dtype) + ) + latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + pose_latents.device, pose_latents.dtype + ) + pose_latents = (pose_latents - latents_mean) * latents_recip_std + if pose_latents.shape[0] == 1 and batch_size > 1: + pose_latents = pose_latents.expand(batch_size, -1, -1, -1, -1) + return pose_latents + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 720, + width: int = 1280, + num_frames: int = 77, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + # +1 for the conditioning frame + shape = (batch_size, num_channels_latents, num_latent_frames + 1, latent_height, latent_width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents + + def pad_video_frames(self, frames: list[Any], num_target_frames: int) -> list[Any]: + """ + Pads an array-like video `frames` to `num_target_frames` using a "reflect"-like strategy. The frame dimension + is assumed to be the first dimension. In the 1D case, we can visualize this strategy as follows: + + pad_video_frames([1, 2, 3, 4, 5], 10) -> [1, 2, 3, 4, 5, 4, 3, 2, 1, 2] + """ + idx = 0 + flip = False + target_frames = [] + while len(target_frames) < num_target_frames: + target_frames.append(deepcopy(frames[idx])) + if flip: + idx -= 1 + else: + idx += 1 + if idx == 0 or idx == len(frames) - 1: + flip = not flip + + return target_frames + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + def _build_attn_metadata(self, attn_params): + if attn_params is None: + return None + + builder_cls = self.attn_backend.get_builder_cls() + builder = builder_cls() + attn_params_dict = attn_params.to_dict() + attn_metadata = builder.build(**attn_params_dict) + return attn_metadata + + def _predict_noise_with_cfg( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, + prompt_embeds: torch.Tensor, + negative_prompt_embeds: torch.Tensor, + image_embeds: torch.Tensor | None, + pose_latents: torch.Tensor, + face_video_segment: torch.Tensor, + motion_encode_batch_size: int | None, + attn_metadata, + apply_cfg: bool, + guidance_scale: float, + use_cfg_parallel: bool, + ): + """ + Predict noise with classifier-free guidance, supporting parallel CFG inference. + + For Wan Animate, the unconditional pass blanks out the face video (sets all pixels to -1) + to remove face conditioning. + + Args: + latent_model_input: The model input (latents concatenated with reference/conditioning latents). + timestep: Current timestep tensor. + prompt_embeds: Positive prompt embeddings tensor. + negative_prompt_embeds: Negative prompt embeddings tensor. + image_embeds: Image embeddings tensor for cross-attention. + pose_latents: Pose video latents. + face_video_segment: Face video segment in pixel space. + motion_encode_batch_size: Batch size for batched motion encoding. + attn_metadata: Attention metadata for set_forward_context. + apply_cfg: Whether to apply classifier-free guidance this step. + guidance_scale: The CFG scale factor. + use_cfg_parallel: Whether to use CFG parallelism across devices. + + Returns: + noise_pred: The predicted noise tensor. + """ + if not apply_cfg: + with set_forward_context(attn_metadata=attn_metadata): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_image=image_embeds, + pose_hidden_states=pose_latents, + face_pixel_values=face_video_segment, + motion_encode_batch_size=motion_encode_batch_size, + return_dict=False, + )[0] + return noise_pred.float() + + # CFG mode + cfg_group, cfg_rank = None, None + if use_cfg_parallel: + if not is_cfg_group_initialized(): + raise RuntimeError("CFG group must be initialized when use_cfg_parallel=True") + cfg_group = get_cfg_group() + cfg_rank = cfg_group.rank_in_group + + noise_pred_pos = torch.zeros_like(latent_model_input, dtype=torch.float32) + noise_pred_neg = torch.zeros_like(latent_model_input, dtype=torch.float32) + + # Positive prompt forward pass (conditional) + if not (use_cfg_parallel and cfg_rank != 0): + with set_forward_context(attn_metadata=attn_metadata): + noise_pred_pos = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_image=image_embeds, + pose_hidden_states=pose_latents, + face_pixel_values=face_video_segment, + motion_encode_batch_size=motion_encode_batch_size, + return_dict=False, + )[0].float() + + # Negative prompt forward pass (unconditional) - blank out face + face_pixel_values_uncond = face_video_segment * 0 - 1 + if not use_cfg_parallel or cfg_rank != 0: + with set_forward_context(attn_metadata=attn_metadata): + noise_pred_neg = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_image=image_embeds, + pose_hidden_states=pose_latents, + face_pixel_values=face_pixel_values_uncond, + motion_encode_batch_size=motion_encode_batch_size, + return_dict=False, + )[0].float() + + # All-reduce for CFG parallel + if use_cfg_parallel: + noise_pred_pos = cfg_group.all_reduce(noise_pred_pos) + noise_pred_neg = cfg_group.all_reduce(noise_pred_neg) + + # Apply CFG + noise_pred = noise_pred_neg + guidance_scale * (noise_pred_pos - noise_pred_neg) + return noise_pred + + @torch.no_grad() + def __call__( + self, + image: PipelineImageInput, + pose_video: list[PIL.Image.Image], + face_video: list[PIL.Image.Image], + background_video: list[PIL.Image.Image] | None = None, + mask_video: list[PIL.Image.Image] | None = None, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + height: int = 720, + width: int = 1280, + segment_frame_length: int = 77, + num_inference_steps: int = 20, + mode: str = "animate", + prev_segment_conditioning_frames: int = 1, + motion_encode_batch_size: int | None = None, + guidance_scale: float = 1.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + image_embeds: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int, dict], dict] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input character image to condition the generation on. Must be an image, a list of images or a + `torch.Tensor`. + pose_video (`list[PIL.Image.Image]`): + The input pose video to condition the generation on. Must be a list of PIL images. + face_video (`list[PIL.Image.Image]`): + The input face video to condition the generation on. Must be a list of PIL images. + background_video (`list[PIL.Image.Image]`, *optional*): + When mode is `"replace"`, the input background video to condition the generation on. Must be a list of + PIL images. + mask_video (`list[PIL.Image.Image]`, *optional*): + When mode is `"replace"`, the input mask video to condition the generation on. Must be a list of PIL + images. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, pass `prompt_embeds` instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to avoid during video generation. If not defined, pass `negative_prompt_embeds` + instead. Ignored when not using guidance (`guidance_scale` < `1`). + height (`int`, defaults to `720`): + The height in pixels of the generated video. + width (`int`, defaults to `1280`): + The width in pixels of the generated video. + segment_frame_length (`int`, defaults to `77`): + The number of frames in each generated video segment. The total frames of video generated will be equal + to the number of frames in `pose_video`; we will generate the video in segments until we have hit this + length. In general, should be 4N + 1, where N is a non-negative integer. + num_inference_steps (`int`, defaults to `20`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + mode (`str`, defaults to `"animate"`): + The mode of the generation. Choose between `"animate"` and `"replace"`. + prev_segment_conditioning_frames (`int`, defaults to `1`): + The number of frames from the previous video segment to be used for temporal guidance. + motion_encode_batch_size (`int`, *optional*): + The batch size for batched encoding of the face video via the motion encoder. This allows trading off + inference speed for lower memory usage by setting a smaller batch size. + guidance_scale (`float`, defaults to `1.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. By default, CFG is not used in Wan + Animate inference. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. If not provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated video. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a `WanPipelineOutput` instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + Attention kwargs dictionary. + callback_on_step_end (`Callable`, *optional*): + A function that is called at the end of each denoising step during the inference with the following + arguments: `callback_on_step_end(step: int, timestep: int, callback_kwargs: dict)`. `callback_kwargs` + will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. + + Returns: + `WanPipelineOutput` or `tuple`: + If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + # 1. Check inputs + self.check_inputs( + prompt, + negative_prompt, + image, + pose_video, + face_video, + background_video, + mask_video, + height, + width, + prompt_embeds, + negative_prompt_embeds, + image_embeds, + callback_on_step_end_tensor_inputs, + mode, + prev_segment_conditioning_frames, + ) + + if segment_frame_length % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`segment_frame_length - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the" + f" nearest number." + ) + segment_frame_length = ( + segment_frame_length // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + ) + segment_frame_length = max(segment_frame_length, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self.device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Compute segment layout + cond_video_frames = len(pose_video) + effective_segment_length = segment_frame_length - prev_segment_conditioning_frames + last_segment_frames = (cond_video_frames - prev_segment_conditioning_frames) % effective_segment_length + if last_segment_frames == 0: + num_padding_frames = 0 + else: + num_padding_frames = effective_segment_length - last_segment_frames + num_target_frames = cond_video_frames + num_padding_frames + num_segments = num_target_frames // effective_segment_length + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = self.pipeline_config.model_dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. Preprocess and encode the reference (character) image + image_height, image_width = self.vae_image_processor.get_default_height_width(image) + if image_height != height or image_width != width: + logger.warning(f"Reshaping reference image from ({image_width}, {image_height}) to ({width}, {height})") + image_pixels = self.vae_image_processor.preprocess(image, height=height, width=width, resize_mode="fill").to( + device, dtype=torch.float32 + ) + + # Get CLIP features from the reference image + if image_embeds is None: + image_embeds = self.encode_image(image, device) + image_embeds = image_embeds.repeat(batch_size * num_videos_per_prompt, 1, 1) + image_embeds = image_embeds.to(transformer_dtype) + + # 5. Encode conditioning videos (pose, face) + pose_video = self.pad_video_frames(pose_video, num_target_frames) + face_video = self.pad_video_frames(face_video, num_target_frames) + + pose_video_width, pose_video_height = pose_video[0].size + if pose_video_height != height or pose_video_width != width: + logger.warning( + f"Reshaping pose video from ({pose_video_width}, {pose_video_height}) to ({width}, {height})" + ) + pose_video = self.video_processor.preprocess_video(pose_video, height=height, width=width).to( + device, dtype=torch.float32 + ) + + face_video_width, face_video_height = face_video[0].size + expected_face_size = self.transformer.config.motion_encoder_size + if face_video_width != expected_face_size or face_video_height != expected_face_size: + logger.warning( + f"Reshaping face video from ({face_video_width}, {face_video_height}) to ({expected_face_size}," + f" {expected_face_size})" + ) + face_video = self.video_processor.preprocess_video( + face_video, height=expected_face_size, width=expected_face_size + ).to(device, dtype=torch.float32) + + if mode == "replace": + background_video = self.pad_video_frames(background_video, num_target_frames) + mask_video = self.pad_video_frames(mask_video, num_target_frames) + + background_video = self.video_processor.preprocess_video(background_video, height=height, width=width).to( + device, dtype=torch.float32 + ) + mask_video = self.video_processor_for_mask.preprocess_video(mask_video, height=height, width=width).to( + device, dtype=torch.float32 + ) + + # 6. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 7. Prepare latent variables which stay constant for all inference segments + num_channels_latents = self.vae.config.z_dim + + # Get VAE-encoded latents of the reference (character) image + reference_image_latents = self.prepare_reference_image_latents( + image_pixels, batch_size * num_videos_per_prompt, generator=generator, device=device + ) + + # 8. Loop over video inference segments + start = 0 + end = segment_frame_length + all_out_frames = [] + out_frames = None + actual_batch_size = batch_size * num_videos_per_prompt + + for _ in range(num_segments): + assert start + prev_segment_conditioning_frames < cond_video_frames + + # Sample noisy latents for the current inference segment + latents = self.prepare_latents( + actual_batch_size, + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames=segment_frame_length, + dtype=torch.float32, + device=device, + generator=generator, + latents=latents if start == 0 else None, + ) + + pose_video_segment = pose_video[:, :, start:end] + face_video_segment = face_video[:, :, start:end] + + face_video_segment = face_video_segment.expand(actual_batch_size, -1, -1, -1, -1) + face_video_segment = face_video_segment.to(dtype=transformer_dtype) + + if start > 0: + prev_segment_cond_video = out_frames[:, :, -prev_segment_conditioning_frames:].clone().detach() + else: + prev_segment_cond_video = None + + if mode == "replace": + background_video_segment = background_video[:, :, start:end] + mask_video_segment = mask_video[:, :, start:end] + + background_video_segment = background_video_segment.expand(actual_batch_size, -1, -1, -1, -1) + mask_video_segment = mask_video_segment.expand(actual_batch_size, -1, -1, -1, -1) + else: + background_video_segment = None + mask_video_segment = None + + pose_latents = self.prepare_pose_latents( + pose_video_segment, actual_batch_size, generator=generator, device=device + ) + pose_latents = pose_latents.to(dtype=transformer_dtype) + + prev_segment_cond_latents = self.prepare_prev_segment_cond_latents( + prev_segment_cond_video, + background_video=background_video_segment, + mask_video=mask_video_segment, + batch_size=actual_batch_size, + segment_frame_length=segment_frame_length, + start_frame=start, + height=height, + width=width, + prev_segment_cond_frames=prev_segment_conditioning_frames, + task=mode, + generator=generator, + device=device, + ) + + # Concatenate the reference latents in the frame dimension + reference_latents = torch.cat([reference_image_latents, prev_segment_cond_latents], dim=2) + + # 8.1 Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + # Concatenate the reference image + prev segment conditioning in the channel dim + latent_model_input = torch.cat([latents, reference_latents], dim=1).to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + attn_metadata = self._build_attn_metadata(self.pipeline_config.attn_params) + + noise_pred = self._predict_noise_with_cfg( + latent_model_input=latent_model_input, + timestep=timestep, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + image_embeds=image_embeds, + pose_latents=pose_latents, + face_video_segment=face_video_segment, + motion_encode_batch_size=motion_encode_batch_size, + attn_metadata=attn_metadata, + apply_cfg=self.do_classifier_free_guidance, + guidance_scale=guidance_scale, + use_cfg_parallel=self.pipeline_config.use_cfg_parallel, + ) + + # Compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + latents = latents.to(self.pipeline_config.vae_dtype) + # Destandardize latents in preparation for Wan VAE decoding + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view( + 1, self.vae.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents / latents_recip_std + latents_mean + # Skip the first latent frame (used for conditioning) + out_frames = self.vae.decode(latents[:, :, 1:], return_dict=False)[0] + + if start > 0: + out_frames = out_frames[:, :, prev_segment_conditioning_frames:] + all_out_frames.append(out_frames) + + start += effective_segment_length + end += effective_segment_length + + # Reset scheduler timesteps / state for next denoising loop + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + self._current_timestep = None + assert start + prev_segment_conditioning_frames >= cond_video_frames + + if not output_type == "latent": + video = torch.cat(all_out_frames, dim=2)[:, :, :cond_video_frames] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + if not return_dict: + return (video,) + + return WanPipelineOutput(frames=video) diff --git a/diffsynth_engine/pipelines/wan/pipeline_wan_i2v.py b/diffsynth_engine/pipelines/wan/pipeline_wan_i2v.py new file mode 100644 index 0000000..70a74b5 --- /dev/null +++ b/diffsynth_engine/pipelines/wan/pipeline_wan_i2v.py @@ -0,0 +1,978 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan_i2v.py + +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import html +import json +import os +from typing import TYPE_CHECKING, Any, Callable + +import PIL +import regex as re +import torch +from accelerate import init_empty_weights +from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput +from diffusers.schedulers import UniPCMultistepScheduler +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel + +from diffsynth_engine.configs.wan import WanPipelineConfig +from diffsynth_engine.distributed.parallel_state import ( + get_cfg_group, + is_cfg_group_initialized, +) +from diffsynth_engine.forward_context import set_forward_context +from diffsynth_engine.models.wan import AutoencoderKLWan, WanTransformer3DModel +from diffsynth_engine.pipelines.base import Pipeline +from diffsynth_engine.registry import get_attn_backend +from diffsynth_engine.utils import logging + +logger = logging.get_logger(__name__) + +if TYPE_CHECKING: + from diffusers.image_processor import PipelineImageInput + + +def basic_clean(text): + try: + import ftfy + + text = ftfy.fix_text(text) + except ImportError: + pass + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class WanImageToVideoPipeline(Pipeline): + r""" + Pipeline for image-to-video generation using Wan. + + Args: + pipeline_config (`WanPipelineConfig`): + Configuration for the pipeline. + tokenizer (`AutoTokenizer`): + Tokenizer from T5, specifically the google/umt5-xxl variant. + text_encoder (`UMT5EncoderModel`): + T5 text encoder, specifically the google/umt5-xxl variant. + image_encoder (`CLIPVisionModel`, *optional*): + CLIP vision model for encoding input images. + image_processor (`CLIPImageProcessor`, *optional*): + CLIP image processor for preprocessing input images. + vae (`AutoencoderKLWan`): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + scheduler (`UniPCMultistepScheduler`): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + transformer (`WanTransformer3DModel`, *optional*): + Conditional Transformer to denoise the input latents. + transformer_2 (`WanTransformer3DModel`, *optional*): + Conditional Transformer to denoise the input latents during the low-noise stage. If provided, enables + two-stage denoising where `transformer` handles high-noise stages and `transformer_2` handles low-noise + stages. If not provided, only `transformer` is used. + boundary_ratio (`float`, *optional*, defaults to `None`): + Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising. + The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided, + `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps < + boundary_timestep. If `None`, only `transformer` is used for the entire denoising process. + expand_timesteps (`bool`, defaults to `False`): + Whether to expand timesteps for Wan2.2 ti2v models. + """ + + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + pipeline_config: WanPipelineConfig, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + vae: AutoencoderKLWan, + scheduler: UniPCMultistepScheduler, + image_processor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModel = None, + transformer: WanTransformer3DModel = None, + transformer_2: WanTransformer3DModel = None, + boundary_ratio: float | None = None, + expand_timesteps: bool = False, + ): + super().__init__(pipeline_config) + + self.tokenizer = tokenizer + self.text_encoder = text_encoder + self.vae = vae + self.image_encoder = image_encoder + self.image_processor = image_processor + self.transformer = transformer + self.transformer_2 = transformer_2 + self.scheduler = scheduler + self.boundary_ratio = boundary_ratio + self.expand_timesteps = expand_timesteps + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if self.vae is not None else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if self.vae is not None else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + active_transformer = transformer if transformer is not None else transformer_2 + head_dim = active_transformer.config.attention_head_dim + self.attn_backend = get_attn_backend(pipeline_config.attn_type) + if not self.attn_backend.supports_head_size(head_dim): + raise ValueError(f"Attention backend {pipeline_config.attn_type!r} does not support head size {head_dim}.") + + @classmethod + def from_pretrained(cls, model_path_or_config: str | WanPipelineConfig): + """ + Load a WanImageToVideoPipeline from a pretrained model path or config. + + Args: + model_path_or_config: Either a string path to the model directory or a WanPipelineConfig instance. + + Returns: + WanImageToVideoPipeline: The loaded pipeline. + """ + if isinstance(model_path_or_config, str): + pipeline_config = WanPipelineConfig(model_path=model_path_or_config) + else: + pipeline_config = model_path_or_config + + if not os.path.exists(pipeline_config.model_path): + raise FileNotFoundError(f"Model path not found: {pipeline_config.model_path}") + + model_index_path = os.path.join(pipeline_config.model_path, "model_index.json") + model_index = {} + boundary_ratio = None + expand_timesteps = False + if os.path.exists(model_index_path): + with open(model_index_path, "r") as f: + model_index = json.load(f) + boundary_ratio = model_index.get("boundary_ratio", None) + expand_timesteps = model_index.get("expand_timesteps", False) + if boundary_ratio is not None: + logger.info(f"Loaded boundary_ratio={boundary_ratio} from model_index.json") + if expand_timesteps: + logger.info(f"Loaded expand_timesteps={expand_timesteps} from model_index.json") + + # Load transformer + transformer = cls.init_transformer(WanTransformer3DModel, pipeline_config) + + # Load transformer_2 + transformer_2 = None + if "transformer_2" in model_index and model_index["transformer_2"] is not None: + transformer_2_subfolder = "transformer_2" + if os.path.isdir(os.path.join(pipeline_config.model_path, transformer_2_subfolder)): + transformer_2 = cls.init_transformer( + WanTransformer3DModel, pipeline_config, subfolder=transformer_2_subfolder + ) + logger.info( + f"Loaded transformer_2 from `{transformer_2_subfolder}` subfolder of {pipeline_config.model_path}." + ) + else: + logger.warning( + f"transformer_2 declared in model_index.json but subfolder " + f"'{transformer_2_subfolder}' not found in {pipeline_config.model_path}. Skipping." + ) + + # Load scheduler + scheduler = UniPCMultistepScheduler.from_pretrained( + pipeline_config.model_path, + subfolder="scheduler", + ) + + # Load VAE + vae = cls.init_vae(AutoencoderKLWan, pipeline_config) + + # Load text encoder + text_encoder = cls.init_text_encoder(UMT5EncoderModel, pipeline_config, strict=False) + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained( + pipeline_config.model_path, + subfolder="tokenizer", + ) + + # Load image encoder + image_encoder = cls.init_image_encoder(pipeline_config) + + # Load image processor + image_processor = None + image_processor_path = os.path.join(pipeline_config.model_path, "image_processor") + if os.path.isdir(image_processor_path): + image_processor = CLIPImageProcessor.from_pretrained( + pipeline_config.model_path, + subfolder="image_processor", + ) + logger.info("Loaded image_processor from `image_processor` subfolder.") + + return cls( + pipeline_config=pipeline_config, + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + image_encoder=image_encoder, + image_processor=image_processor, + transformer=transformer, + transformer_2=transformer_2, + scheduler=scheduler, + boundary_ratio=boundary_ratio, + expand_timesteps=expand_timesteps, + ) + + @staticmethod + def init_image_encoder(pipeline_config: WanPipelineConfig, empty_weights: bool = False): + logger.info("Initializing image encoder...") + image_encoder_path = os.path.join(pipeline_config.model_path, "image_encoder") + if not os.path.isdir(image_encoder_path): + logger.warning(f"image_encoder not found in {pipeline_config.model_path}.") + return None + + if empty_weights: + with init_empty_weights(): + model = CLIPVisionModel.from_pretrained( + pipeline_config.model_path, + subfolder="image_encoder", + local_files_only=True, + ) + return model + + model = CLIPVisionModel.from_pretrained( + pipeline_config.model_path, + subfolder="image_encoder", + dtype=torch.float32, + ) + model.to(device=pipeline_config.device) + return model + + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self.device + dtype = dtype or self.pipeline_config.text_encoder_dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_image(self, image: PipelineImageInput, device: torch.device | None = None): + device = device or self.device + image = self.image_processor(images=image, return_tensors="pt").to(device) + image_embeds = self.image_encoder(**image, output_hidden_states=True) + return image_embeds.hidden_states[-2] + + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. + max_sequence_length (`int`, *optional*, defaults to 226): + Maximum sequence length for the text encoder. + device (`torch.device`, *optional*): + torch device + dtype (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self.device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + image: PipelineImageInput, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + image_embeds=None, + callback_on_step_end_tensor_inputs=None, + guidance_scale_2=None, + ): + if image is not None and image_embeds is not None: + raise ValueError( + f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to" + " only forward one of the two." + ) + if image is None and image_embeds is None: + raise ValueError( + "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined." + ) + if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): + raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found " + f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if self.boundary_ratio is None and guidance_scale_2 is not None: + raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.") + + if self.boundary_ratio is not None and image_embeds is not None: + raise ValueError("Cannot forward `image_embeds` when the pipeline's `boundary_ratio` is not configured.") + + def prepare_latents( + self, + image: PipelineImageInput, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + last_image: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + image = image.unsqueeze(2) # [batch_size, channels, 1, height, width] + + if self.expand_timesteps: + video_condition = image + elif last_image is None: + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 + ) + else: + last_image = last_image.unsqueeze(2) + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image], + dim=2, + ) + video_condition = video_condition.to(device=device, dtype=self.pipeline_config.vae_dtype) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + + if isinstance(generator, list): + latent_condition = [ + retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator + ] + latent_condition = torch.cat(latent_condition) + else: + latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") + latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) + + latent_condition = latent_condition.to(dtype) + latent_condition = (latent_condition - latents_mean) * latents_std + + if self.expand_timesteps: + first_frame_mask = torch.ones( + 1, 1, num_latent_frames, latent_height, latent_width, dtype=dtype, device=device + ) + first_frame_mask[:, :, 0] = 0 + return latents, latent_condition, first_frame_mask + + mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) + + if last_image is None: + mask_lat_size[:, :, list(range(1, num_frames))] = 0 + else: + mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0 + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(latent_condition.device) + + return latents, torch.concat([mask_lat_size, latent_condition], dim=1) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + def _build_attn_metadata(self, attn_params): + if attn_params is None: + return None + + builder_cls = self.attn_backend.get_builder_cls() + builder = builder_cls() + attn_params_dict = attn_params.to_dict() + attn_metadata = builder.build(**attn_params_dict) + return attn_metadata + + def _predict_noise_with_cfg( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, + prompt_embeds: torch.Tensor, + negative_prompt_embeds: torch.Tensor, + image_embeds: torch.Tensor | None, + attn_metadata, + apply_cfg: bool, + guidance_scale: float, + use_cfg_parallel: bool, + model: WanTransformer3DModel | None = None, + ): + """ + Predict noise with classifier-free guidance, supporting parallel CFG inference. + + Args: + latent_model_input: The model input (latents or latents + condition). + timestep: Current timestep tensor. + prompt_embeds: Positive prompt embeddings tensor. + negative_prompt_embeds: Negative prompt embeddings tensor. + image_embeds: Image embeddings tensor for I2V cross-attention. + attn_metadata: Attention metadata for set_forward_context. + apply_cfg: Whether to apply classifier-free guidance this step. + guidance_scale: The CFG scale factor. + use_cfg_parallel: Whether to use CFG parallelism across devices. + model: The transformer model to use. If None, defaults to self.transformer. + + Returns: + noise_pred: The predicted noise tensor. + """ + if model is None: + model = self.transformer + + if not apply_cfg: + with set_forward_context(attn_metadata=attn_metadata): + noise_pred = model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_image=image_embeds, + return_dict=False, + )[0] + return noise_pred.float() + + # CFG mode + cfg_group, cfg_rank = None, None + if use_cfg_parallel: + if not is_cfg_group_initialized(): + raise RuntimeError("CFG group must be initialized when use_cfg_parallel=True") + cfg_group = get_cfg_group() + cfg_rank = cfg_group.rank_in_group + + noise_pred_pos = torch.zeros_like(latent_model_input, dtype=torch.float32) + noise_pred_neg = torch.zeros_like(latent_model_input, dtype=torch.float32) + + # Positive prompt forward pass + if not (use_cfg_parallel and cfg_rank != 0): + with set_forward_context(attn_metadata=attn_metadata): + noise_pred_pos = model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_image=image_embeds, + return_dict=False, + )[0].float() + + # Negative prompt forward pass + if not use_cfg_parallel or cfg_rank != 0: + with set_forward_context(attn_metadata=attn_metadata): + noise_pred_neg = model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_image=image_embeds, + return_dict=False, + )[0].float() + + # All-reduce for CFG parallel + if use_cfg_parallel: + noise_pred_pos = cfg_group.all_reduce(noise_pred_pos) + noise_pred_neg = cfg_group.all_reduce(noise_pred_neg) + + # Apply CFG + noise_pred = noise_pred_neg + guidance_scale * (noise_pred_pos - noise_pred_neg) + return noise_pred + + @torch.no_grad() + def __call__( + self, + image: PipelineImageInput, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + guidance_scale_2: float | None = None, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + image_embeds: torch.Tensor | None = None, + last_image: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int, dict], dict] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, pass `prompt_embeds` instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to avoid during video generation. If not defined, pass `negative_prompt_embeds` + instead. Ignored when not using guidance (`guidance_scale` < `1`). + height (`int`, defaults to `480`): + The height in pixels of the generated video. + width (`int`, defaults to `832`): + The width in pixels of the generated video. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + guidance_scale_2 (`float`, *optional*, defaults to `None`): + Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's + `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2` + and the pipeline's `boundary_ratio` are not None. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. If not provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings. + last_image (`torch.Tensor`, *optional*): + Optional last frame image for video generation with start and end frames. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated video. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a `WanPipelineOutput` instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + Attention kwargs dictionary. + callback_on_step_end (`Callable`, *optional*): + A function that is called at the end of each denoising step during the inference with the following + arguments: `callback_on_step_end(step: int, timestep: int, callback_kwargs: dict)`. `callback_kwargs` + will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. + + Returns: + `WanPipelineOutput` or `tuple`: + If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + # 1. Check inputs + self.check_inputs( + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds, + negative_prompt_embeds, + image_embeds, + callback_on_step_end_tensor_inputs, + guidance_scale_2, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. " + "Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + patch_size = ( + self.transformer.config.patch_size if self.transformer is not None else self.transformer_2.config.patch_size + ) + h_multiple_of = self.vae_scale_factor_spatial * patch_size[1] + w_multiple_of = self.vae_scale_factor_spatial * patch_size[2] + calc_height = height // h_multiple_of * h_multiple_of + calc_width = width // w_multiple_of * w_multiple_of + if height != calc_height or width != calc_width: + logger.warning( + f"`height` and `width` must be multiples of ({h_multiple_of}, {w_multiple_of}) for proper " + f"patchification. Adjusting ({height}, {width}) -> ({calc_height}, {calc_width})." + ) + height, width = calc_height, calc_width + + if self.boundary_ratio is not None and guidance_scale_2 is None: + guidance_scale_2 = guidance_scale + + self._guidance_scale = guidance_scale + self._guidance_scale_2 = guidance_scale_2 + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self.device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + # Encode image embedding + transformer_dtype = self.pipeline_config.model_dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # only wan 2.1 i2v transformer accepts image_embeds + if self.transformer is not None and self.transformer.config.image_dim is not None: + if image_embeds is None: + if last_image is None: + image_embeds = self.encode_image(image, device) + else: + image_embeds = self.encode_image([image, last_image], device) + image_embeds = image_embeds.repeat(batch_size, 1, 1) + image_embeds = image_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.vae.config.z_dim + image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) + if last_image is not None: + last_image = self.video_processor.preprocess(last_image, height=height, width=width).to( + device, dtype=torch.float32 + ) + + latents_outputs = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + last_image, + ) + if self.expand_timesteps: + latents, condition, first_frame_mask = latents_outputs + else: + latents, condition = latents_outputs + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + if self.boundary_ratio is not None: + boundary_timestep = self.boundary_ratio * self.scheduler.config.num_train_timesteps + else: + boundary_timestep = None + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + if boundary_timestep is None or t >= boundary_timestep: + current_model = self.transformer + current_guidance_scale = guidance_scale + else: + current_model = self.transformer_2 + current_guidance_scale = guidance_scale_2 + + if self.expand_timesteps: + latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents + latent_model_input = latent_model_input.to(transformer_dtype) + + temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten() + timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) + else: + latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + attn_metadata = self._build_attn_metadata(self.pipeline_config.attn_params) + + noise_pred = self._predict_noise_with_cfg( + latent_model_input=latent_model_input, + timestep=timestep, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + image_embeds=image_embeds, + attn_metadata=attn_metadata, + apply_cfg=self.do_classifier_free_guidance, + guidance_scale=current_guidance_scale, + use_cfg_parallel=self.pipeline_config.use_cfg_parallel, + model=current_model, + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + self._current_timestep = None + + if self.expand_timesteps: + latents = (1 - first_frame_mask) * condition + first_frame_mask * latents + + if not output_type == "latent": + latents = latents.to(self.pipeline_config.vae_dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + if not return_dict: + return (video,) + + return WanPipelineOutput(frames=video) diff --git a/diffsynth_engine/pipelines/wan/pipeline_wan_t2v.py b/diffsynth_engine/pipelines/wan/pipeline_wan_t2v.py new file mode 100644 index 0000000..12100ab --- /dev/null +++ b/diffsynth_engine/pipelines/wan/pipeline_wan_t2v.py @@ -0,0 +1,806 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py + +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import html +import json +import os +from typing import Any, Callable + +import regex as re +import torch +from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput +from diffusers.schedulers import UniPCMultistepScheduler +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from transformers import AutoTokenizer, UMT5EncoderModel + +from diffsynth_engine.configs.wan import WanPipelineConfig +from diffsynth_engine.distributed.parallel_state import ( + get_cfg_group, + is_cfg_group_initialized, +) +from diffsynth_engine.forward_context import set_forward_context +from diffsynth_engine.models.wan import AutoencoderKLWan, WanTransformer3DModel +from diffsynth_engine.pipelines.base import Pipeline +from diffsynth_engine.registry import get_attn_backend +from diffsynth_engine.utils import logging + +logger = logging.get_logger(__name__) + + +def basic_clean(text): + try: + import ftfy + + text = ftfy.fix_text(text) + except ImportError: + pass + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +class WanTextToVideoPipeline(Pipeline): + r""" + Pipeline for text-to-video generation using Wan. + + Args: + pipeline_config (`WanPipelineConfig`): + Configuration for the pipeline. + tokenizer (`AutoTokenizer`): + Tokenizer from T5, specifically the google/umt5-xxl variant. + text_encoder (`UMT5EncoderModel`): + T5 text encoder, specifically the google/umt5-xxl variant. + vae (`AutoencoderKLWan`): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + scheduler (`UniPCMultistepScheduler`): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + transformer (`WanTransformer3DModel`, *optional*): + Conditional Transformer to denoise the input latents. + transformer_2 (`WanTransformer3DModel`, *optional*): + Conditional Transformer to denoise the input latents during the low-noise stage. If provided, enables + two-stage denoising where `transformer` handles high-noise stages and `transformer_2` handles low-noise + stages. If not provided, only `transformer` is used. + boundary_ratio (`float`, *optional*, defaults to `None`): + Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising. + The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided, + `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps < + boundary_timestep. If `None`, only `transformer` is used for the entire denoising process. + expand_timesteps (`bool`, defaults to `False`): + Whether to expand timesteps for Wan2.2 ti2v models. + """ + + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + pipeline_config: WanPipelineConfig, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + vae: AutoencoderKLWan, + scheduler: UniPCMultistepScheduler, + transformer: WanTransformer3DModel | None = None, + transformer_2: WanTransformer3DModel | None = None, + boundary_ratio: float | None = None, + expand_timesteps: bool = False, + ): + super().__init__(pipeline_config) + + self.tokenizer = tokenizer + self.text_encoder = text_encoder + self.vae = vae + self.transformer = transformer + self.transformer_2 = transformer_2 + self.scheduler = scheduler + self.boundary_ratio = boundary_ratio + self.expand_timesteps = expand_timesteps + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if self.vae is not None else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if self.vae is not None else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + active_transformer = transformer if transformer is not None else transformer_2 + head_dim = active_transformer.config.attention_head_dim + self.attn_backend = get_attn_backend(pipeline_config.attn_type) + if not self.attn_backend.supports_head_size(head_dim): + raise ValueError(f"Attention backend {pipeline_config.attn_type!r} does not support head size {head_dim}.") + + @classmethod + def from_pretrained(cls, model_path_or_config: str | WanPipelineConfig): + """ + Load a WanTextToVideoPipeline from a pretrained model path or config. + + Args: + model_path_or_config: Either a string path to the model directory or a WanPipelineConfig instance. + + Returns: + WanTextToVideoPipeline: The loaded pipeline. + """ + if isinstance(model_path_or_config, str): + pipeline_config = WanPipelineConfig(model_path=model_path_or_config) + else: + pipeline_config = model_path_or_config + + if not os.path.exists(pipeline_config.model_path): + raise FileNotFoundError(f"Model path not found: {pipeline_config.model_path}") + + model_index_path = os.path.join(pipeline_config.model_path, "model_index.json") + model_index = {} + boundary_ratio = None + expand_timesteps = False + if os.path.exists(model_index_path): + with open(model_index_path, "r") as f: + model_index = json.load(f) + boundary_ratio = model_index.get("boundary_ratio", None) + expand_timesteps = model_index.get("expand_timesteps", False) + if boundary_ratio is not None: + logger.info(f"Loaded boundary_ratio={boundary_ratio} from model_index.json") + if expand_timesteps: + logger.info(f"Loaded expand_timesteps={expand_timesteps} from model_index.json") + + # Load transformer + transformer = cls.init_transformer(WanTransformer3DModel, pipeline_config) + + # Load transformer_2 + transformer_2 = None + if "transformer_2" in model_index and model_index["transformer_2"] is not None: + transformer_2_subfolder = "transformer_2" + if os.path.isdir(os.path.join(pipeline_config.model_path, transformer_2_subfolder)): + transformer_2 = cls.init_transformer( + WanTransformer3DModel, pipeline_config, subfolder=transformer_2_subfolder + ) + logger.info( + f"Loaded transformer_2 from `{transformer_2_subfolder}` subfolder of {pipeline_config.model_path}." + ) + else: + logger.warning( + f"transformer_2 declared in model_index.json but subfolder " + f"'{transformer_2_subfolder}' not found in {pipeline_config.model_path}. Skipping." + ) + + # Load scheduler + scheduler = UniPCMultistepScheduler.from_pretrained( + pipeline_config.model_path, + subfolder="scheduler", + ) + + # Load VAE + vae = cls.init_vae(AutoencoderKLWan, pipeline_config) + + # Load text encoder + text_encoder = cls.init_text_encoder(UMT5EncoderModel, pipeline_config, strict=False) + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained( + pipeline_config.model_path, + subfolder="tokenizer", + ) + + return cls( + pipeline_config=pipeline_config, + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + transformer=transformer, + transformer_2=transformer_2, + scheduler=scheduler, + boundary_ratio=boundary_ratio, + expand_timesteps=expand_timesteps, + ) + + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self.device + dtype = dtype or self.pipeline_config.text_encoder_dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. + max_sequence_length (`int`, *optional*, defaults to 226): + Maximum sequence length for the text encoder. + device (`torch.device`, *optional*): + torch device + dtype (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self.device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + guidance_scale_2=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found " + f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if self.boundary_ratio is None and guidance_scale_2 is not None: + raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.") + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + def _build_attn_metadata(self, attn_params): + if attn_params is None: + return None + + builder_cls = self.attn_backend.get_builder_cls() + builder = builder_cls() + attn_params_dict = attn_params.to_dict() + attn_metadata = builder.build(**attn_params_dict) + return attn_metadata + + def _predict_noise_with_cfg( + self, + latents: torch.Tensor, + timestep: torch.Tensor, + prompt_embeds: torch.Tensor, + negative_prompt_embeds: torch.Tensor, + attn_metadata, + apply_cfg: bool, + guidance_scale: float, + use_cfg_parallel: bool, + model: WanTransformer3DModel | None = None, + ): + """ + Predict noise with classifier-free guidance, supporting parallel CFG inference. + + Args: + latents: Current noisy latents. + timestep: Current timestep tensor. + prompt_embeds: Positive prompt embeddings tensor. + negative_prompt_embeds: Negative prompt embeddings tensor. + attn_metadata: Attention metadata for set_forward_context. + apply_cfg: Whether to apply classifier-free guidance this step. + guidance_scale: The CFG scale factor. + use_cfg_parallel: Whether to use CFG parallelism across devices. + model: The transformer model to use. If None, defaults to self.transformer. + + Returns: + noise_pred: The predicted noise tensor. + """ + if model is None: + model = self.transformer + + transformer_dtype = self.pipeline_config.model_dtype + + if not apply_cfg: + latent_model_input = latents.to(transformer_dtype) + with set_forward_context(attn_metadata=attn_metadata): + noise_pred = model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + return_dict=False, + )[0] + return noise_pred.float() + + # CFG mode + cfg_group, cfg_rank = None, None + if use_cfg_parallel: + if not is_cfg_group_initialized(): + raise RuntimeError("CFG group must be initialized when use_cfg_parallel=True") + cfg_group = get_cfg_group() + cfg_rank = cfg_group.rank_in_group + + latent_model_input = latents.to(transformer_dtype) + + noise_pred_pos = torch.zeros_like(latents, dtype=torch.float32) + noise_pred_neg = torch.zeros_like(latents, dtype=torch.float32) + + # Positive prompt forward pass + if not (use_cfg_parallel and cfg_rank != 0): + with set_forward_context(attn_metadata=attn_metadata): + noise_pred_pos = model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + return_dict=False, + )[0].float() + + # Negative prompt forward pass + if not use_cfg_parallel or cfg_rank != 0: + with set_forward_context(attn_metadata=attn_metadata): + noise_pred_neg = model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + return_dict=False, + )[0].float() + + # All-reduce for CFG parallel + if use_cfg_parallel: + noise_pred_pos = cfg_group.all_reduce(noise_pred_pos) + noise_pred_neg = cfg_group.all_reduce(noise_pred_neg) + + # Apply CFG + noise_pred = noise_pred_neg + guidance_scale * (noise_pred_pos - noise_pred_neg) + return noise_pred + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + guidance_scale_2: float | None = None, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int, dict], dict] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, pass `prompt_embeds` instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to avoid during video generation. If not defined, pass `negative_prompt_embeds` + instead. Ignored when not using guidance (`guidance_scale` < `1`). + height (`int`, defaults to `480`): + The height in pixels of the generated video. + width (`int`, defaults to `832`): + The width in pixels of the generated video. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + guidance_scale_2 (`float`, *optional*, defaults to `None`): + Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's + `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2` + and the pipeline's `boundary_ratio` are not None. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. If not provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated video. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a `WanPipelineOutput` instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + Attention kwargs dictionary. + callback_on_step_end (`Callable`, *optional*): + A function that is called at the end of each denoising step during the inference with the following + arguments: `callback_on_step_end(step: int, timestep: int, callback_kwargs: dict)`. `callback_kwargs` + will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. + + Returns: + `WanPipelineOutput` or `tuple`: + If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + # 1. Check inputs + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + guidance_scale_2, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. " + "Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + patch_size = ( + self.transformer.config.patch_size if self.transformer is not None else self.transformer_2.config.patch_size + ) + h_multiple_of = self.vae_scale_factor_spatial * patch_size[1] + w_multiple_of = self.vae_scale_factor_spatial * patch_size[2] + calc_height = height // h_multiple_of * h_multiple_of + calc_width = width // w_multiple_of * w_multiple_of + if height != calc_height or width != calc_width: + logger.warning( + f"`height` and `width` must be multiples of ({h_multiple_of}, {w_multiple_of}) for proper " + f"patchification. Adjusting ({height}, {width}) -> ({calc_height}, {calc_width})." + ) + height, width = calc_height, calc_width + + if self.boundary_ratio is not None and guidance_scale_2 is None: + guidance_scale_2 = guidance_scale + + self._guidance_scale = guidance_scale + self._guidance_scale_2 = guidance_scale_2 + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self.device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = self.pipeline_config.model_dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = ( + self.transformer.config.in_channels + if self.transformer is not None + else self.transformer_2.config.in_channels + ) + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + mask = torch.ones(latents.shape, dtype=torch.float32, device=device) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + # We set the index here to remove DtoH sync, helpful especially during compilation. + # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + + if self.boundary_ratio is not None: + boundary_timestep = self.boundary_ratio * self.scheduler.config.num_train_timesteps + else: + boundary_timestep = None + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + if boundary_timestep is None or t >= boundary_timestep: + # wan2.1 or high-noise stage in wan2.2 + current_model = self.transformer + current_guidance_scale = guidance_scale + else: + # low-noise stage in wan2.2 + current_model = self.transformer_2 + current_guidance_scale = guidance_scale_2 + + if self.expand_timesteps: + # seq_len: num_latent_frames * latent_height//2 * latent_width//2 + temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten() + # batch_size, seq_len + timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) + else: + timestep = t.expand(latents.shape[0]) + + attn_metadata = self._build_attn_metadata(self.pipeline_config.attn_params) + + noise_pred = self._predict_noise_with_cfg( + latents=latents, + timestep=timestep, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + attn_metadata=attn_metadata, + apply_cfg=self.do_classifier_free_guidance, + guidance_scale=current_guidance_scale, + use_cfg_parallel=self.pipeline_config.use_cfg_parallel, + model=current_model, + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.pipeline_config.vae_dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + if not return_dict: + return (video,) + + return WanPipelineOutput(frames=video) diff --git a/diffsynth_engine/pipelines/wan/pipeline_wan_vace.py b/diffsynth_engine/pipelines/wan/pipeline_wan_vace.py new file mode 100644 index 0000000..89e7180 --- /dev/null +++ b/diffsynth_engine/pipelines/wan/pipeline_wan_vace.py @@ -0,0 +1,1113 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan_vace.py + +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import html +import json +import os +from typing import TYPE_CHECKING, Any, Callable + +import PIL.Image +import regex as re +import torch +from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput +from diffusers.schedulers import UniPCMultistepScheduler +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from transformers import AutoTokenizer, UMT5EncoderModel + +from diffsynth_engine.configs.wan import WanPipelineConfig +from diffsynth_engine.distributed.parallel_state import ( + get_cfg_group, + is_cfg_group_initialized, +) +from diffsynth_engine.forward_context import set_forward_context +from diffsynth_engine.models.wan import AutoencoderKLWan, WanVACETransformer3DModel +from diffsynth_engine.pipelines.base import Pipeline +from diffsynth_engine.registry import get_attn_backend +from diffsynth_engine.utils import logging + +logger = logging.get_logger(__name__) + +if TYPE_CHECKING: + from diffusers.image_processor import PipelineImageInput + + +def basic_clean(text): + try: + import ftfy + + text = ftfy.fix_text(text) + except ImportError: + pass + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class WanVACEPipeline(Pipeline): + r""" + Pipeline for controllable generation using Wan. + + Args: + pipeline_config (`WanPipelineConfig`): + Configuration for the pipeline. + tokenizer (`AutoTokenizer`): + Tokenizer from T5, specifically the google/umt5-xxl variant. + text_encoder (`UMT5EncoderModel`): + T5 text encoder, specifically the google/umt5-xxl variant. + vae (`AutoencoderKLWan`): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + scheduler (`UniPCMultistepScheduler`): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + transformer (`WanVACETransformer3DModel`, *optional*): + Conditional Transformer to denoise the input latents. + transformer_2 (`WanVACETransformer3DModel`, *optional*): + Conditional Transformer to denoise the input latents during the low-noise stage. If provided, enables + two-stage denoising where `transformer` handles high-noise stages and `transformer_2` handles low-noise + stages. If not provided, only `transformer` is used. + boundary_ratio (`float`, *optional*, defaults to `None`): + Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising. + The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided, + `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps < + boundary_timestep. If `None`, only `transformer` is used for the entire denoising process. + """ + + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + pipeline_config: WanPipelineConfig, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + vae: AutoencoderKLWan, + scheduler: UniPCMultistepScheduler, + transformer: WanVACETransformer3DModel = None, + transformer_2: WanVACETransformer3DModel = None, + boundary_ratio: float | None = None, + ): + super().__init__(pipeline_config) + + self.tokenizer = tokenizer + self.text_encoder = text_encoder + self.vae = vae + self.transformer = transformer + self.transformer_2 = transformer_2 + self.scheduler = scheduler + self.boundary_ratio = boundary_ratio + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if self.vae is not None else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if self.vae is not None else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + active_transformer = transformer if transformer is not None else transformer_2 + head_dim = active_transformer.config.attention_head_dim + self.attn_backend = get_attn_backend(pipeline_config.attn_type) + if not self.attn_backend.supports_head_size(head_dim): + raise ValueError(f"Attention backend {pipeline_config.attn_type!r} does not support head size {head_dim}.") + + @classmethod + def from_pretrained(cls, model_path_or_config: str | WanPipelineConfig): + """ + Load a WanVACEPipeline from a pretrained model path or config. + + Args: + model_path_or_config: Either a string path to the model directory or a WanPipelineConfig instance. + + Returns: + WanVACEPipeline: The loaded pipeline. + """ + if isinstance(model_path_or_config, str): + pipeline_config = WanPipelineConfig(model_path=model_path_or_config) + else: + pipeline_config = model_path_or_config + + if not os.path.exists(pipeline_config.model_path): + raise FileNotFoundError(f"Model path not found: {pipeline_config.model_path}") + + model_index_path = os.path.join(pipeline_config.model_path, "model_index.json") + model_index = {} + boundary_ratio = None + if os.path.exists(model_index_path): + with open(model_index_path, "r") as f: + model_index = json.load(f) + boundary_ratio = model_index.get("boundary_ratio", None) + if boundary_ratio is not None: + logger.info(f"Loaded boundary_ratio={boundary_ratio} from model_index.json") + + # Load transformer + transformer = cls.init_transformer(WanVACETransformer3DModel, pipeline_config) + + # Load transformer_2 + transformer_2 = None + if "transformer_2" in model_index and model_index["transformer_2"] is not None: + transformer_2_subfolder = "transformer_2" + if os.path.isdir(os.path.join(pipeline_config.model_path, transformer_2_subfolder)): + transformer_2 = cls.init_transformer( + WanVACETransformer3DModel, pipeline_config, subfolder=transformer_2_subfolder + ) + logger.info( + f"Loaded transformer_2 from `{transformer_2_subfolder}` subfolder of {pipeline_config.model_path}." + ) + else: + logger.warning( + f"transformer_2 declared in model_index.json but subfolder " + f"'{transformer_2_subfolder}' not found in {pipeline_config.model_path}. Skipping." + ) + + # Load scheduler + scheduler = UniPCMultistepScheduler.from_pretrained(pipeline_config.model_path, subfolder="scheduler") + + # Load VAE + vae = cls.init_vae(AutoencoderKLWan, pipeline_config) + + # Load text encoder + text_encoder = cls.init_text_encoder(UMT5EncoderModel, pipeline_config, strict=False) + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained(pipeline_config.model_path, subfolder="tokenizer") + + return cls( + pipeline_config=pipeline_config, + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + transformer=transformer, + transformer_2=transformer_2, + scheduler=scheduler, + boundary_ratio=boundary_ratio, + ) + + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self.device + dtype = dtype or self.pipeline_config.text_encoder_dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. + max_sequence_length (`int`, *optional*, defaults to 226): + Maximum sequence length for the text encoder. + device (`torch.device`, *optional*): + torch device + dtype (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self.device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + video=None, + mask=None, + reference_images=None, + guidance_scale_2=None, + ): + if self.transformer is not None: + base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1] + elif self.transformer_2 is not None: + base = self.vae_scale_factor_spatial * self.transformer_2.config.patch_size[1] + else: + raise ValueError( + "`transformer` or `transformer_2` component must be set in order to run inference with this pipeline" + ) + + if height % base != 0 or width % base != 0: + raise ValueError(f"`height` and `width` have to be divisible by {base} but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found " + f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if self.boundary_ratio is None and guidance_scale_2 is not None: + raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: " + f"{negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if video is not None: + if mask is not None: + if len(video) != len(mask): + raise ValueError( + f"Length of `video` {len(video)} and `mask` {len(mask)} do not match. Please make sure that" + " they have the same length." + ) + if reference_images is not None: + is_pil_image = isinstance(reference_images, PIL.Image.Image) + is_list_of_pil_images = isinstance(reference_images, list) and all( + isinstance(ref_img, PIL.Image.Image) for ref_img in reference_images + ) + is_list_of_list_of_pil_images = isinstance(reference_images, list) and all( + isinstance(ref_img, list) and all(isinstance(r, PIL.Image.Image) for r in ref_img) + for ref_img in reference_images + ) + if not (is_pil_image or is_list_of_pil_images or is_list_of_list_of_pil_images): + raise ValueError( + "`reference_images` has to be of type `PIL.Image.Image` or `list` of `PIL.Image.Image`, or " + f"`list` of `list` of `PIL.Image.Image`, but is {type(reference_images)}" + ) + if is_list_of_list_of_pil_images and len(reference_images) != 1: + raise ValueError( + "The pipeline only supports generating one video at a time at the moment. When passing a list " + "of list of reference images, where the outer list corresponds to the batch size and the inner " + "list corresponds to list of conditioning images per video, please make sure to only pass " + "one inner list of reference images (i.e., `[[, , ...]]`" + ) + elif mask is not None: + raise ValueError("`mask` can only be passed if `video` is passed as well.") + + def preprocess_conditions( + self, + video: list[PipelineImageInput] | None = None, + mask: list[PipelineImageInput] | None = None, + reference_images: PIL.Image.Image | list[PIL.Image.Image] | list[list[PIL.Image.Image]] | None = None, + batch_size: int = 1, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + ): + if video is not None: + base = self.vae_scale_factor_spatial * ( + self.transformer.config.patch_size[1] + if self.transformer is not None + else self.transformer_2.config.patch_size[1] + ) + video_height, video_width = self.video_processor.get_default_height_width(video[0]) + + if video_height * video_width > height * width: + scale = min(width / video_width, height / video_height) + video_height, video_width = int(video_height * scale), int(video_width * scale) + + if video_height % base != 0 or video_width % base != 0: + logger.warning( + f"Video height and width should be divisible by {base}, but got {video_height} and {video_width}." + ) + video_height = (video_height // base) * base + video_width = (video_width // base) * base + + assert video_height * video_width <= height * width + + video = self.video_processor.preprocess_video(video, video_height, video_width) + image_size = (video_height, video_width) + else: + video = torch.zeros(batch_size, 3, num_frames, height, width, dtype=dtype, device=device) + image_size = (height, width) + + if mask is not None: + mask = self.video_processor.preprocess_video(mask, image_size[0], image_size[1]) + mask = torch.clamp((mask + 1) / 2, min=0, max=1) + else: + mask = torch.ones_like(video) + + video = video.to(dtype=dtype, device=device) + mask = mask.to(dtype=dtype, device=device) + + # Make a list of list of images where the outer list corresponds to video batch size and the inner list + # corresponds to list of conditioning images per video + if reference_images is None or isinstance(reference_images, PIL.Image.Image): + reference_images = [[reference_images] for _ in range(video.shape[0])] + elif isinstance(reference_images, (list, tuple)) and isinstance(next(iter(reference_images)), PIL.Image.Image): + reference_images = [reference_images] + elif ( + isinstance(reference_images, (list, tuple)) + and isinstance(next(iter(reference_images)), list) + and isinstance(next(iter(reference_images[0])), PIL.Image.Image) + ): + reference_images = reference_images + else: + raise ValueError( + "`reference_images` has to be of type `PIL.Image.Image` or `list` of `PIL.Image.Image`, or " + f"`list` of `list` of `PIL.Image.Image`, but is {type(reference_images)}" + ) + + if video.shape[0] != len(reference_images): + raise ValueError( + f"Batch size of `video` {video.shape[0]} and length of `reference_images` " + f"{len(reference_images)} does not match." + ) + + ref_images_lengths = [len(batch) for batch in reference_images] + if any(length != ref_images_lengths[0] for length in ref_images_lengths): + raise ValueError( + f"All batches of `reference_images` should have the same length, but got {ref_images_lengths}." + ) + + reference_images_preprocessed = [] + for reference_images_batch in reference_images: + preprocessed_images = [] + for image in reference_images_batch: + if image is None: + continue + image = self.video_processor.preprocess(image, None, None) + img_height, img_width = image.shape[-2:] + scale = min(image_size[0] / img_height, image_size[1] / img_width) + new_height, new_width = int(img_height * scale), int(img_width * scale) + resized_image = torch.nn.functional.interpolate( + image, size=(new_height, new_width), mode="bilinear", align_corners=False + ).squeeze(0) # [C, H, W] + top = (image_size[0] - new_height) // 2 + left = (image_size[1] - new_width) // 2 + canvas = torch.ones(3, *image_size, device=device, dtype=dtype) + canvas[:, top : top + new_height, left : left + new_width] = resized_image + preprocessed_images.append(canvas) + reference_images_preprocessed.append(preprocessed_images) + + return video, mask, reference_images_preprocessed + + def prepare_video_latents( + self, + video: torch.Tensor, + mask: torch.Tensor, + reference_images: list[list[torch.Tensor]] | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + device: torch.device | None = None, + ) -> torch.Tensor: + device = device or self.device + + if isinstance(generator, list): + raise ValueError("Passing a list of generators is not yet supported.") + + if reference_images is None: + reference_images = [[None] for _ in range(video.shape[0])] + else: + if video.shape[0] != len(reference_images): + raise ValueError( + f"Batch size of `video` {video.shape[0]} and length of `reference_images` " + f"{len(reference_images)} does not match." + ) + + if video.shape[0] != 1: + raise ValueError("Generating with more than one video is not yet supported.") + + vae_dtype = self.pipeline_config.vae_dtype + video = video.to(dtype=vae_dtype) + + latents_mean = torch.tensor(self.vae.config.latents_mean, device=device, dtype=torch.float32).view( + 1, self.vae.config.z_dim, 1, 1, 1 + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=device, dtype=torch.float32).view( + 1, self.vae.config.z_dim, 1, 1, 1 + ) + + if mask is None: + latents = retrieve_latents(self.vae.encode(video), generator, sample_mode="argmax").unbind(0) + latents = ((latents.float() - latents_mean) * latents_std).to(vae_dtype) + else: + mask = torch.where(mask > 0.5, 1.0, 0.0).to(dtype=vae_dtype) + inactive = video * (1 - mask) + reactive = video * mask + inactive = retrieve_latents(self.vae.encode(inactive), generator, sample_mode="argmax") + reactive = retrieve_latents(self.vae.encode(reactive), generator, sample_mode="argmax") + inactive = ((inactive.float() - latents_mean) * latents_std).to(vae_dtype) + reactive = ((reactive.float() - latents_mean) * latents_std).to(vae_dtype) + latents = torch.cat([inactive, reactive], dim=1) + + latent_list = [] + for latent, reference_images_batch in zip(latents, reference_images): + for reference_image in reference_images_batch: + assert reference_image.ndim == 3 + reference_image = reference_image.to(dtype=vae_dtype) + reference_image = reference_image[None, :, None, :, :] # [1, C, 1, H, W] + reference_latent = retrieve_latents(self.vae.encode(reference_image), generator, sample_mode="argmax") + reference_latent = ((reference_latent.float() - latents_mean) * latents_std).to(vae_dtype) + reference_latent = reference_latent.squeeze(0) # [C, 1, H, W] + reference_latent = torch.cat([reference_latent, torch.zeros_like(reference_latent)], dim=0) + latent = torch.cat([reference_latent.squeeze(0), latent], dim=1) + latent_list.append(latent) + return torch.stack(latent_list) + + def prepare_masks( + self, + mask: torch.Tensor, + reference_images: list[list[torch.Tensor]] | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + ) -> torch.Tensor: + if isinstance(generator, list): + raise ValueError("Passing a list of generators is not yet supported.") + + if reference_images is None: + reference_images = [[None] for _ in range(mask.shape[0])] + else: + if mask.shape[0] != len(reference_images): + raise ValueError( + f"Batch size of `mask` {mask.shape[0]} and length of `reference_images` " + f"{len(reference_images)} does not match." + ) + + if mask.shape[0] != 1: + raise ValueError("Generating with more than one video is not yet supported.") + + transformer_patch_size = ( + self.transformer.config.patch_size[1] + if self.transformer is not None + else self.transformer_2.config.patch_size[1] + ) + + mask_list = [] + for mask_, reference_images_batch in zip(mask, reference_images): + num_channels, num_frames, height, width = mask_.shape + new_num_frames = (num_frames + self.vae_scale_factor_temporal - 1) // self.vae_scale_factor_temporal + new_height = height // (self.vae_scale_factor_spatial * transformer_patch_size) * transformer_patch_size + new_width = width // (self.vae_scale_factor_spatial * transformer_patch_size) * transformer_patch_size + mask_ = mask_[0, :, :, :] + mask_ = mask_.view( + num_frames, new_height, self.vae_scale_factor_spatial, new_width, self.vae_scale_factor_spatial + ) + mask_ = mask_.permute(2, 4, 0, 1, 3).flatten(0, 1) # [8x8, num_frames, new_height, new_width] + mask_ = torch.nn.functional.interpolate( + mask_.unsqueeze(0), size=(new_num_frames, new_height, new_width), mode="nearest-exact" + ).squeeze(0) + num_ref_images = len(reference_images_batch) + if num_ref_images > 0: + mask_padding = torch.zeros_like(mask_[:, :num_ref_images, :, :]) + mask_ = torch.cat([mask_padding, mask_], dim=1) + mask_list.append(mask_) + return torch.stack(mask_list) + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + def _build_attn_metadata(self, attn_params): + if attn_params is None: + return None + + builder_cls = self.attn_backend.get_builder_cls() + builder = builder_cls() + attn_params_dict = attn_params.to_dict() + attn_metadata = builder.build(**attn_params_dict) + return attn_metadata + + def _predict_noise_with_cfg( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, + prompt_embeds: torch.Tensor, + negative_prompt_embeds: torch.Tensor, + control_hidden_states: torch.Tensor, + control_hidden_states_scale: torch.Tensor, + attn_metadata, + apply_cfg: bool, + guidance_scale: float, + use_cfg_parallel: bool, + model: WanVACETransformer3DModel | None = None, + ): + """ + Predict noise with classifier-free guidance, supporting parallel CFG inference. + + Args: + latent_model_input: The model input latents. + timestep: Current timestep tensor. + prompt_embeds: Positive prompt embeddings tensor. + negative_prompt_embeds: Negative prompt embeddings tensor. + control_hidden_states: VACE conditioning latents. + control_hidden_states_scale: Per-layer scale for VACE conditioning. + attn_metadata: Attention metadata for set_forward_context. + apply_cfg: Whether to apply classifier-free guidance this step. + guidance_scale: The CFG scale factor. + use_cfg_parallel: Whether to use CFG parallelism across devices. + model: The transformer model to use. If None, defaults to self.transformer. + + Returns: + noise_pred: The predicted noise tensor. + """ + if model is None: + model = self.transformer + + if not apply_cfg: + with set_forward_context(attn_metadata=attn_metadata): + noise_pred = model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + control_hidden_states=control_hidden_states, + control_hidden_states_scale=control_hidden_states_scale, + return_dict=False, + )[0] + return noise_pred.float() + + # CFG mode + cfg_group, cfg_rank = None, None + if use_cfg_parallel: + if not is_cfg_group_initialized(): + raise RuntimeError("CFG group must be initialized when use_cfg_parallel=True") + cfg_group = get_cfg_group() + cfg_rank = cfg_group.rank_in_group + + noise_pred_pos = torch.zeros_like(latent_model_input, dtype=torch.float32) + noise_pred_neg = torch.zeros_like(latent_model_input, dtype=torch.float32) + + # Positive prompt forward pass + if not (use_cfg_parallel and cfg_rank != 0): + with set_forward_context(attn_metadata=attn_metadata): + noise_pred_pos = model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + control_hidden_states=control_hidden_states, + control_hidden_states_scale=control_hidden_states_scale, + return_dict=False, + )[0].float() + + # Negative prompt forward pass + if not use_cfg_parallel or cfg_rank != 0: + with set_forward_context(attn_metadata=attn_metadata): + noise_pred_neg = model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + control_hidden_states=control_hidden_states, + control_hidden_states_scale=control_hidden_states_scale, + return_dict=False, + )[0].float() + + # All-reduce for CFG parallel + if use_cfg_parallel: + noise_pred_pos = cfg_group.all_reduce(noise_pred_pos) + noise_pred_neg = cfg_group.all_reduce(noise_pred_neg) + + # Apply CFG + noise_pred = noise_pred_neg + guidance_scale * (noise_pred_pos - noise_pred_neg) + return noise_pred + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + video: list[PipelineImageInput] | None = None, + mask: list[PipelineImageInput] | None = None, + reference_images: list[PipelineImageInput] | None = None, + conditioning_scale: float | list[float] | torch.Tensor = 1.0, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + guidance_scale_2: float | None = None, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int, dict], dict] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, pass `prompt_embeds` instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to avoid during video generation. If not defined, pass `negative_prompt_embeds` + instead. Ignored when not using guidance (`guidance_scale` < `1`). + video (`list[PIL.Image.Image]`, *optional*): + The input video or videos to be used as a starting point for the generation. The video should be a list + of PIL images, a numpy array, or a torch tensor. Currently, the pipeline only supports generating one + video at a time. + mask (`list[PIL.Image.Image]`, *optional*): + The input mask defines which video regions to condition on and which to generate. Black areas in the + mask indicate conditioning regions, while white areas indicate regions for generation. The mask should + be a list of PIL images, a numpy array, or a torch tensor. Currently supports generating a single video + at a time. + reference_images (`list[PIL.Image.Image]`, *optional*): + A list of one or more reference images as extra conditioning for the generation. For example, if you + are trying to inpaint a video to change the character, you can pass reference images of the new + character here. Refer to the Diffusers [examples](https://github.com/huggingface/diffusers/pull/11582) + and original [user + guide](https://github.com/ali-vilab/VACE/blob/0897c6d055d7d9ea9e191dce763006664d9780f8/UserGuide.md) + for a full list of supported tasks and use cases. + conditioning_scale (`float`, `list[float]`, `torch.Tensor`, defaults to `1.0`): + The conditioning scale to be applied when adding the control conditioning latent stream to the + denoising latent stream in each control layer of the model. If a float is provided, it will be applied + uniformly to all layers. If a list or tensor is provided, it should have the same length as the number + of control layers in the model (`len(transformer.config.vace_layers)`). + height (`int`, defaults to `480`): + The height in pixels of the generated video. + width (`int`, defaults to `832`): + The width in pixels of the generated video. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + guidance_scale_2 (`float`, *optional*, defaults to `None`): + Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's + `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2` + and the pipeline's `boundary_ratio` are not None. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. If not provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated video. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a `WanPipelineOutput` instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + Attention kwargs dictionary. + callback_on_step_end (`Callable`, *optional*): + A function that is called at the end of each denoising step during the inference with the following + arguments: `callback_on_step_end(step: int, timestep: int, callback_kwargs: dict)`. `callback_kwargs` + will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. + + Returns: + `WanPipelineOutput` or `tuple`: + If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + # 1. Check inputs + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + video, + mask, + reference_images, + guidance_scale_2, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. " + "Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + if self.boundary_ratio is not None and guidance_scale_2 is None: + guidance_scale_2 = guidance_scale + + self._guidance_scale = guidance_scale + self._guidance_scale_2 = guidance_scale_2 + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self.device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + transformer_dtype = self.pipeline_config.model_dtype + + vace_layers = ( + self.transformer.config.vace_layers + if self.transformer is not None + else self.transformer_2.config.vace_layers + ) + if isinstance(conditioning_scale, (int, float)): + conditioning_scale = [conditioning_scale] * len(vace_layers) + if isinstance(conditioning_scale, list): + if len(conditioning_scale) != len(vace_layers): + raise ValueError( + f"Length of `conditioning_scale` {len(conditioning_scale)} does not match " + f"number of layers {len(vace_layers)}." + ) + conditioning_scale = torch.tensor(conditioning_scale) + if isinstance(conditioning_scale, torch.Tensor): + if conditioning_scale.size(0) != len(vace_layers): + raise ValueError( + f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match " + f"number of layers {len(vace_layers)}." + ) + conditioning_scale = conditioning_scale.to(device=device, dtype=transformer_dtype) + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + video, mask, reference_images = self.preprocess_conditions( + video, + mask, + reference_images, + batch_size, + height, + width, + num_frames, + torch.float32, + device, + ) + num_reference_images = len(reference_images[0]) + + conditioning_latents = self.prepare_video_latents(video, mask, reference_images, generator, device) + mask = self.prepare_masks(mask, reference_images, generator) + conditioning_latents = torch.cat([conditioning_latents, mask], dim=1) + conditioning_latents = conditioning_latents.to(transformer_dtype) + + num_channels_latents = ( + self.transformer.config.in_channels + if self.transformer is not None + else self.transformer_2.config.in_channels + ) + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames + num_reference_images * self.vae_scale_factor_temporal, + torch.float32, + device, + generator, + latents, + ) + + if conditioning_latents.shape[2] != latents.shape[2]: + logger.warning( + "The number of frames in the conditioning latents does not match the number of frames " + "to be generated. Generation quality may be affected." + ) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + if self.boundary_ratio is not None: + boundary_timestep = self.boundary_ratio * self.scheduler.config.num_train_timesteps + else: + boundary_timestep = None + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + if boundary_timestep is None or t >= boundary_timestep: + current_model = self.transformer + current_guidance_scale = guidance_scale + else: + current_model = self.transformer_2 + current_guidance_scale = guidance_scale_2 + + latent_model_input = latents.to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + attn_metadata = self._build_attn_metadata(self.pipeline_config.attn_params) + + noise_pred = self._predict_noise_with_cfg( + latent_model_input=latent_model_input, + timestep=timestep, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + control_hidden_states=conditioning_latents, + control_hidden_states_scale=conditioning_scale, + attn_metadata=attn_metadata, + apply_cfg=self.do_classifier_free_guidance, + guidance_scale=current_guidance_scale, + use_cfg_parallel=self.pipeline_config.use_cfg_parallel, + model=current_model, + ) + + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents[:, :, num_reference_images:] + latents = latents.to(self.pipeline_config.vae_dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + if not return_dict: + return (video,) + + return WanPipelineOutput(frames=video) diff --git a/diffsynth_engine/registry.py b/diffsynth_engine/registry.py index 9cd9b4a..8a9709c 100644 --- a/diffsynth_engine/registry.py +++ b/diffsynth_engine/registry.py @@ -19,6 +19,10 @@ "QwenImageEditPipeline": "diffsynth_engine.pipelines.qwen_image.pipeline_qwenimage_edit:QwenImageEditPipeline", "QwenImageEditPlusPipeline": "diffsynth_engine.pipelines.qwen_image.pipeline_qwenimage_edit_plus:QwenImageEditPlusPipeline", "QwenImageLayeredPipeline": "diffsynth_engine.pipelines.qwen_image.pipeline_qwenimage_layered:QwenImageLayeredPipeline", + "WanAnimatePipeline": "diffsynth_engine.pipelines.wan.pipeline_wan_animate:WanAnimatePipeline", + "WanImageToVideoPipeline": "diffsynth_engine.pipelines.wan.pipeline_wan_i2v:WanImageToVideoPipeline", + "WanTextToVideoPipeline": "diffsynth_engine.pipelines.wan.pipeline_wan_t2v:WanTextToVideoPipeline", + "WanVACEPipeline": "diffsynth_engine.pipelines.wan.pipeline_wan_vace:WanVACEPipeline", } _ATTENTION_BACKENDS: dict[str, str] = { diff --git a/diffsynth_engine/utils/video.py b/diffsynth_engine/utils/video.py index ce00a36..5eceebf 100644 --- a/diffsynth_engine/utils/video.py +++ b/diffsynth_engine/utils/video.py @@ -35,11 +35,15 @@ def save_video(frames, save_path, fps=15): elif save_path.endswith(".mp4"): codec = "libx264" - frames = [np.array(img) for img in frames] + converted_frames = [] + for img in frames: + arr = np.array(img) + if arr.dtype != np.uint8: + arr = np.clip(arr * 255.0, 0, 255).astype(np.uint8) + converted_frames.append(arr) - # 使用 imageio 写入 .webm 文件 with iio.imopen(save_path, "w", plugin="FFMPEG") as writer: - writer.write(frames, fps=fps, codec=codec) + writer.write(converted_frames, fps=fps, codec=codec) def read_n_frames( diff --git a/examples/input/wan_22_animate_face.mp4 b/examples/input/wan_22_animate_face.mp4 new file mode 100644 index 0000000..622d6df Binary files /dev/null and b/examples/input/wan_22_animate_face.mp4 differ diff --git a/examples/input/wan_22_animate_input.png b/examples/input/wan_22_animate_input.png new file mode 100644 index 0000000..a1b85f7 Binary files /dev/null and b/examples/input/wan_22_animate_input.png differ diff --git a/examples/input/wan_22_animate_pose.mp4 b/examples/input/wan_22_animate_pose.mp4 new file mode 100644 index 0000000..af36f5e Binary files /dev/null and b/examples/input/wan_22_animate_pose.mp4 differ diff --git a/examples/input/wan_22_i2v_input.png b/examples/input/wan_22_i2v_input.png new file mode 100644 index 0000000..6a558eb Binary files /dev/null and b/examples/input/wan_22_i2v_input.png differ diff --git a/examples/input/wan_vace_first_frame.png b/examples/input/wan_vace_first_frame.png new file mode 100644 index 0000000..032cd5c Binary files /dev/null and b/examples/input/wan_vace_first_frame.png differ diff --git a/examples/input/wan_vace_last_frame.png b/examples/input/wan_vace_last_frame.png new file mode 100644 index 0000000..83ac8c5 Binary files /dev/null and b/examples/input/wan_vace_last_frame.png differ diff --git a/examples/wan/wan_22_animate.py b/examples/wan/wan_22_animate.py new file mode 100644 index 0000000..c41fc5d --- /dev/null +++ b/examples/wan/wan_22_animate.py @@ -0,0 +1,60 @@ +import torch +from diffusers.utils import export_to_video, load_video +from PIL import Image + +from diffsynth_engine.pipelines.wan import WanAnimatePipeline +from diffsynth_engine.utils.download import fetch_model + +if __name__ == "__main__": + model_path = fetch_model("Wan-AI/Wan2.2-Animate-14B-Diffusers") + pipe = WanAnimatePipeline.from_pretrained(model_path) + + # Load the reference character image + image = Image.open("examples/input/wan_22_animate_input.png") + + # Load pose and face conditioning videos (preprocessed from a reference video) + pose_video = load_video("examples/input/wan_22_animate_pose.mp4") + face_video = load_video("examples/input/wan_22_animate_face.mp4") + + prompt = "People in the video are doing actions." + + # ---- Animate mode ---- + video = pipe( + image=image, + pose_video=pose_video, + face_video=face_video, + prompt=prompt, + mode="animate", + segment_frame_length=77, + prev_segment_conditioning_frames=1, + guidance_scale=1.0, + num_inference_steps=20, + generator=torch.Generator(device="cpu").manual_seed(42), + ) + + export_to_video(video.frames[0], "animated_output.mp4", fps=30) + + # ---- Replace mode (optional) ---- + # In replace mode, an additional background_video and mask_video are required. + # background_video: the original video whose character will be replaced. + # mask_video: grayscale masks indicating the region to replace (white = replace). + # + # background_video = load_video("examples/input/wan_22_animate_background.mp4") + # mask_video = load_video("examples/input/wan_22_animate_mask.mp4") + # + # video_replace = pipe( + # image=image, + # pose_video=pose_video, + # face_video=face_video, + # background_video=background_video, + # mask_video=mask_video, + # prompt=prompt, + # mode="replace", + # segment_frame_length=77, + # prev_segment_conditioning_frames=1, + # guidance_scale=1.0, + # num_inference_steps=20, + # generator=torch.Generator(device="cpu").manual_seed(42), + # ) + # + # export_to_video(video_replace.frames[0], "animated_output_replace.mp4", fps=30) diff --git a/examples/wan/wan_22_image_to_video.py b/examples/wan/wan_22_image_to_video.py new file mode 100644 index 0000000..ac50229 --- /dev/null +++ b/examples/wan/wan_22_image_to_video.py @@ -0,0 +1,36 @@ +import numpy as np +import torch +from diffusers.utils import export_to_video +from PIL import Image + +from diffsynth_engine.pipelines.wan import WanImageToVideoPipeline +from diffsynth_engine.utils.download import fetch_model + +if __name__ == "__main__": + model_path = fetch_model("Wan-AI/Wan2.2-I2V-A14B-Diffusers") + pipe = WanImageToVideoPipeline.from_pretrained(model_path) + + image = Image.open("examples/input/wan_22_i2v_input.png") + max_area = 480 * 832 + aspect_ratio = image.height / image.width + mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1] + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + image = image.resize((width, height)) + + prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." + negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" + + video = pipe( + image=image, + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=81, + guidance_scale=3.5, + num_inference_steps=40, + generator=torch.Generator(device="cpu").manual_seed(42), + ) + + export_to_video(video.frames[0], "wan_22_i2v.mp4", fps=16) diff --git a/examples/wan/wan_22_text_to_video.py b/examples/wan/wan_22_text_to_video.py new file mode 100644 index 0000000..9e62928 --- /dev/null +++ b/examples/wan/wan_22_text_to_video.py @@ -0,0 +1,23 @@ +import torch +from diffusers.utils import export_to_video + +from diffsynth_engine.pipelines.wan import WanTextToVideoPipeline +from diffsynth_engine.utils.download import fetch_model + +if __name__ == "__main__": + model_path = fetch_model("Wan-AI/Wan2.2-T2V-A14B-Diffusers") + pipe = WanTextToVideoPipeline.from_pretrained(model_path) + + video = pipe( + prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + num_frames=81, + width=1280, + height=720, + guidance_scale=4.0, + guidance_scale_2=3.0, + num_inference_steps=40, + generator=torch.Generator(device="cpu").manual_seed(42), + ) + + export_to_video(video.frames[0], "wan_22_t2v.mp4", fps=16) diff --git a/examples/wan/wan_vace.py b/examples/wan/wan_vace.py new file mode 100644 index 0000000..eaf4648 --- /dev/null +++ b/examples/wan/wan_vace.py @@ -0,0 +1,73 @@ +import PIL.Image +import torch +from diffusers.schedulers import UniPCMultistepScheduler +from diffusers.utils import export_to_video, load_image + +from diffsynth_engine.pipelines.wan import WanVACEPipeline +from diffsynth_engine.utils.download import fetch_model + + +def prepare_video_and_mask( + first_img: PIL.Image.Image, + last_img: PIL.Image.Image, + height: int, + width: int, + num_frames: int, +): + first_img = first_img.resize((width, height)) + last_img = last_img.resize((width, height)) + frames = [first_img] + frames.extend([PIL.Image.new("RGB", (width, height), (128, 128, 128))] * (num_frames - 2)) + frames.append(last_img) + mask_black = PIL.Image.new("L", (width, height), 0) + mask_white = PIL.Image.new("L", (width, height), 255) + mask = [mask_black, *[mask_white] * (num_frames - 2), mask_black] + return frames, mask + + +if __name__ == "__main__": + model_path = fetch_model("Wan-AI/Wan2.1-VACE-14B-diffusers") + pipe = WanVACEPipeline.from_pretrained(model_path) + + # Set flow_shift to 5.0 for 720P (use 3.0 for 480P) + flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P + pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift) + + # Load the first and last frame images + first_frame = load_image("examples/input/wan_vace_first_frame.png") + last_frame = load_image("examples/input/wan_vace_last_frame.png") + + prompt = ( + "CG animation style, a small blue bird takes off from the ground, flapping its wings. " + "The bird's feathers are delicate, with a unique pattern on its chest. " + "The background shows a blue sky with white clouds under bright sunshine. " + "The camera follows the bird upward, capturing its flight and the vastness of the sky " + "from a close-up, low-angle perspective." + ) + negative_prompt = ( + "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, " + "images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, " + "incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, " + "misshapen limbs, fused fingers, still picture, messy background, three legs, many people " + "in the background, walking backwards" + ) + + height = 512 + width = 512 + num_frames = 81 + video, mask = prepare_video_and_mask(first_frame, last_frame, height, width, num_frames) + + output = pipe( + video=video, + mask=mask, + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=30, + guidance_scale=5.0, + generator=torch.Generator().manual_seed(42), + ) + + export_to_video(output.frames[0], "wan_vace_output.mp4", fps=16) diff --git a/tests/common/test_case.py b/tests/common/test_case.py index aa97998..13528ff 100644 --- a/tests/common/test_case.py +++ b/tests/common/test_case.py @@ -9,7 +9,7 @@ from diffsynth_engine.utils.load_utils import load_file from diffsynth_engine.utils.video import VideoReader, load_video, save_video -from tests.common.utils import compute_normalized_ssim +from tests.common.utils import compute_normalized_ssim, compute_video_ms_ssim TEST_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # test flags @@ -109,3 +109,22 @@ def assertVideoEqualAndSaveFailed( name = expect_video_path.split("/")[-1] self.save_video(input_video, name, fps=fps) raise e + + def assertVideoMsSsimEqual(self, input_video: List[Image.Image], expect_video: List[Image.Image], threshold=0.95): + ms_ssim_score = compute_video_ms_ssim(input_video, expect_video) + self.assertGreaterEqual(ms_ssim_score, threshold) + + def assertVideoMsSsimEqualAndSaveFailed( + self, input_video: List[Image.Image], expect_video_path: str, threshold=0.95, fps: int = 15 + ): + """ + 比较input_video和testdata/expect/{name}的MS-SSIM相似度,如果失败则保存input_video到当前工作目录 + """ + try: + expect_video = self.get_expect_video(expect_video_path) + expect_frames = [expect_video[i] for i in range(len(expect_video))] + self.assertVideoMsSsimEqual(input_video, expect_frames, threshold=threshold) + except Exception as e: + name = expect_video_path.split("/")[-1] + self.save_video(input_video, name, fps=fps) + raise e diff --git a/tests/common/utils.py b/tests/common/utils.py index f4d5ba3..00f2e01 100644 --- a/tests/common/utils.py +++ b/tests/common/utils.py @@ -1,4 +1,7 @@ +from typing import List + import numpy as np +import torch from PIL import Image from skimage.metrics import structural_similarity @@ -14,3 +17,41 @@ def compute_normalized_ssim(image1: Image.Image, image2: Image.Image): ssim_normalized = (ssim + 1) / 2 return ssim_normalized + + +def compute_video_ms_ssim( + input_frames: List[Image.Image], + expect_frames: List[Image.Image], +) -> float: + """Compute the mean MS-SSIM score between two frame sequences. + + Each frame is converted to a ``[1, C, H, W]`` float tensor in ``[0, 1]`` + and scored with ``MultiScaleStructuralSimilarityIndexMeasure``. The + returned value is the average MS-SSIM across all frame pairs. + """ + from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure + + ms_ssim_metric = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0) + + scores: List[float] = [] + for pred_frame, target_frame in zip(input_frames, expect_frames): + pred_array = np.array(pred_frame).astype(np.float32) + target_array = np.array(target_frame).astype(np.float32) + + # Normalize to [0, 1]: only divide by 255 when the data is in uint8 range + if pred_array.max() > 1.0: + pred_array = pred_array / 255.0 + if target_array.max() > 1.0: + target_array = target_array / 255.0 + + pred_tensor = torch.from_numpy(pred_array) + target_tensor = torch.from_numpy(target_array) + + # [H, W, C] -> [1, C, H, W] + pred_tensor = pred_tensor.permute(2, 0, 1).unsqueeze(0) + target_tensor = target_tensor.permute(2, 0, 1).unsqueeze(0) + + score = ms_ssim_metric(pred_tensor, target_tensor) + scores.append(score.item()) + + return float(np.mean(scores)) diff --git a/tests/data/expect/wan/wan_22_animate.mp4 b/tests/data/expect/wan/wan_22_animate.mp4 new file mode 100644 index 0000000..6f5e809 Binary files /dev/null and b/tests/data/expect/wan/wan_22_animate.mp4 differ diff --git a/tests/data/expect/wan/wan_22_i2v.mp4 b/tests/data/expect/wan/wan_22_i2v.mp4 new file mode 100644 index 0000000..4c064a3 Binary files /dev/null and b/tests/data/expect/wan/wan_22_i2v.mp4 differ diff --git a/tests/data/expect/wan/wan_22_t2v.mp4 b/tests/data/expect/wan/wan_22_t2v.mp4 new file mode 100644 index 0000000..976ded2 Binary files /dev/null and b/tests/data/expect/wan/wan_22_t2v.mp4 differ diff --git a/tests/data/expect/wan/wan_vace.mp4 b/tests/data/expect/wan/wan_vace.mp4 new file mode 100644 index 0000000..94779cc Binary files /dev/null and b/tests/data/expect/wan/wan_vace.mp4 differ diff --git a/tests/data/input/wan_22_animate_face.mp4 b/tests/data/input/wan_22_animate_face.mp4 new file mode 100644 index 0000000..622d6df Binary files /dev/null and b/tests/data/input/wan_22_animate_face.mp4 differ diff --git a/tests/data/input/wan_22_animate_input.png b/tests/data/input/wan_22_animate_input.png new file mode 100644 index 0000000..a1b85f7 Binary files /dev/null and b/tests/data/input/wan_22_animate_input.png differ diff --git a/tests/data/input/wan_22_animate_pose.mp4 b/tests/data/input/wan_22_animate_pose.mp4 new file mode 100644 index 0000000..af36f5e Binary files /dev/null and b/tests/data/input/wan_22_animate_pose.mp4 differ diff --git a/tests/data/input/wan_22_i2v_input.png b/tests/data/input/wan_22_i2v_input.png new file mode 100644 index 0000000..6a558eb Binary files /dev/null and b/tests/data/input/wan_22_i2v_input.png differ diff --git a/tests/data/input/wan_vace_first_frame.png b/tests/data/input/wan_vace_first_frame.png new file mode 100644 index 0000000..032cd5c Binary files /dev/null and b/tests/data/input/wan_vace_first_frame.png differ diff --git a/tests/data/input/wan_vace_last_frame.png b/tests/data/input/wan_vace_last_frame.png new file mode 100644 index 0000000..83ac8c5 Binary files /dev/null and b/tests/data/input/wan_vace_last_frame.png differ diff --git a/tests/test_pipelines/test_wan_21_vace.py b/tests/test_pipelines/test_wan_21_vace.py new file mode 100644 index 0000000..c795942 --- /dev/null +++ b/tests/test_pipelines/test_wan_21_vace.py @@ -0,0 +1,84 @@ +import unittest + +import PIL.Image +import torch +from diffusers.schedulers import UniPCMultistepScheduler + +from diffsynth_engine.pipelines.wan import WanVACEPipeline +from diffsynth_engine.utils.download import fetch_model +from tests.common.test_case import VideoTestCase + + +def prepare_video_and_mask( + first_img: PIL.Image.Image, + last_img: PIL.Image.Image, + height: int, + width: int, + num_frames: int, +): + first_img = first_img.resize((width, height)) + last_img = last_img.resize((width, height)) + frames = [first_img] + frames.extend([PIL.Image.new("RGB", (width, height), (128, 128, 128))] * (num_frames - 2)) + frames.append(last_img) + mask_black = PIL.Image.new("L", (width, height), 0) + mask_white = PIL.Image.new("L", (width, height), 255) + mask = [mask_black, *[mask_white] * (num_frames - 2), mask_black] + return frames, mask + + +class TestWanVACEPipeline(VideoTestCase): + @classmethod + def setUpClass(cls): + model_path = fetch_model("Wan-AI/Wan2.1-VACE-14B-diffusers") + cls.pipe = WanVACEPipeline.from_pretrained(model_path) + flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P + cls.pipe.scheduler = UniPCMultistepScheduler.from_config(cls.pipe.scheduler.config, flow_shift=flow_shift) + + @classmethod + def tearDownClass(cls): + del cls.pipe + + def test_vace(self): + first_frame = self.get_input_image("wan_vace_first_frame.png") + last_frame = self.get_input_image("wan_vace_last_frame.png") + + prompt = ( + "CG animation style, a small blue bird takes off from the ground, flapping its wings. " + "The bird's feathers are delicate, with a unique pattern on its chest. " + "The background shows a blue sky with white clouds under bright sunshine. " + "The camera follows the bird upward, capturing its flight and the vastness of the sky " + "from a close-up, low-angle perspective." + ) + negative_prompt = ( + "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, " + "images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, " + "incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, " + "misshapen limbs, fused fingers, still picture, messy background, three legs, many people " + "in the background, walking backwards" + ) + + height = 512 + width = 512 + num_frames = 81 + video, mask = prepare_video_and_mask(first_frame, last_frame, height, width, num_frames) + + result = self.pipe( + video=video, + mask=mask, + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=30, + guidance_scale=5.0, + generator=torch.Generator(device="cpu").manual_seed(42), + ) + + output_frames = result.frames[0] + self.assertVideoMsSsimEqualAndSaveFailed(output_frames, "wan/wan_vace.mp4", threshold=0.93, fps=16) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_pipelines/test_wan_22_animate.py b/tests/test_pipelines/test_wan_22_animate.py new file mode 100644 index 0000000..b5f9f6e --- /dev/null +++ b/tests/test_pipelines/test_wan_22_animate.py @@ -0,0 +1,48 @@ +import unittest + +import torch + +from diffsynth_engine.pipelines.wan import WanAnimatePipeline +from diffsynth_engine.utils.download import fetch_model +from tests.common.test_case import VideoTestCase + + +class TestWan22AnimatePipeline(VideoTestCase): + @classmethod + def setUpClass(cls): + model_path = fetch_model("Wan-AI/Wan2.2-Animate-14B-Diffusers") + cls.pipe = WanAnimatePipeline.from_pretrained(model_path) + + @classmethod + def tearDownClass(cls): + del cls.pipe + + def test_animate(self): + image = self.get_input_image("wan_22_animate_input.png") + + pose_video_reader = self.get_input_video("wan_22_animate_pose.mp4") + face_video_reader = self.get_input_video("wan_22_animate_face.mp4") + pose_video = [pose_video_reader[i] for i in range(len(pose_video_reader))] + face_video = [face_video_reader[i] for i in range(len(face_video_reader))] + + prompt = "People in the video are doing actions." + + video = self.pipe( + image=image, + pose_video=pose_video, + face_video=face_video, + prompt=prompt, + mode="animate", + segment_frame_length=77, + prev_segment_conditioning_frames=1, + guidance_scale=1.0, + num_inference_steps=20, + generator=torch.Generator(device="cpu").manual_seed(42), + ) + + output_frames = video.frames[0] + self.assertVideoMsSsimEqualAndSaveFailed(output_frames, "wan/wan_22_animate.mp4", threshold=0.98, fps=30) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_pipelines/test_wan_22_image_to_video.py b/tests/test_pipelines/test_wan_22_image_to_video.py new file mode 100644 index 0000000..86d0960 --- /dev/null +++ b/tests/test_pipelines/test_wan_22_image_to_video.py @@ -0,0 +1,50 @@ +import unittest + +import numpy as np +import torch + +from diffsynth_engine.pipelines.wan import WanImageToVideoPipeline +from diffsynth_engine.utils.download import fetch_model +from tests.common.test_case import VideoTestCase + + +class TestWan22ImageToVideoPipeline(VideoTestCase): + @classmethod + def setUpClass(cls): + model_path = fetch_model("Wan-AI/Wan2.2-I2V-A14B-Diffusers") + cls.pipe = WanImageToVideoPipeline.from_pretrained(model_path) + + @classmethod + def tearDownClass(cls): + del cls.pipe + + def test_image_to_video(self): + image = self.get_input_image("wan_22_i2v_input.png") + max_area = 480 * 832 + aspect_ratio = image.height / image.width + mod_value = self.pipe.vae_scale_factor_spatial * self.pipe.transformer.config.patch_size[1] + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + image = image.resize((width, height)) + + prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." + negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" + + video = self.pipe( + image=image, + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=81, + guidance_scale=3.5, + num_inference_steps=40, + generator=torch.Generator(device="cpu").manual_seed(42), + ) + + output_frames = video.frames[0] + self.assertVideoMsSsimEqualAndSaveFailed(output_frames, "wan/wan_22_i2v.mp4", threshold=0.98, fps=16) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_pipelines/test_wan_22_text_to_video.py b/tests/test_pipelines/test_wan_22_text_to_video.py new file mode 100644 index 0000000..1b5a962 --- /dev/null +++ b/tests/test_pipelines/test_wan_22_text_to_video.py @@ -0,0 +1,43 @@ +import unittest + +import torch + +from diffsynth_engine.pipelines.wan import WanTextToVideoPipeline +from diffsynth_engine.utils.download import fetch_model +from tests.common.test_case import VideoTestCase + + +class TestWan22TextToVideoPipeline(VideoTestCase): + @classmethod + def setUpClass(cls): + model_path = fetch_model("Wan-AI/Wan2.2-T2V-A14B-Diffusers") + cls.pipe = WanTextToVideoPipeline.from_pretrained(model_path) + + @classmethod + def tearDownClass(cls): + del cls.pipe + + def test_text_to_video(self): + prompt = ( + "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." + ) + negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" + + video = self.pipe( + prompt=prompt, + negative_prompt=negative_prompt, + num_frames=81, + width=1280, + height=720, + guidance_scale=4.0, + guidance_scale_2=3.0, + num_inference_steps=40, + generator=torch.Generator(device="cpu").manual_seed(42), + ) + + output_frames = video.frames[0] + self.assertVideoMsSsimEqualAndSaveFailed(output_frames, "wan/wan_22_t2v.mp4", threshold=0.98, fps=16) + + +if __name__ == "__main__": + unittest.main()