Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions diffsynth_engine/configs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from .base import PipelineConfig
from .qwen_image import QwenImagePipelineConfig
from .wan import WanPipelineConfig

__all__ = [
"PipelineConfig",
"QwenImagePipelineConfig",
"WanPipelineConfig",
]
8 changes: 8 additions & 0 deletions diffsynth_engine/configs/wan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from dataclasses import dataclass

from diffsynth_engine.configs.base import PipelineConfig


@dataclass
class WanPipelineConfig(PipelineConfig):
pass
6 changes: 4 additions & 2 deletions diffsynth_engine/forward_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
Expand All @@ -13,7 +15,7 @@

@dataclass
class ForwardContext:
attn_metadata: Optional["AttentionMetadata"] = None
attn_metadata: Optional[AttentionMetadata] = None
attn_type: Optional[str] = None


Expand Down Expand Up @@ -42,7 +44,7 @@ def override_forward_context(forward_context: Optional[ForwardContext] = None):

@contextmanager
def set_forward_context(
attn_metadata: Optional["AttentionMetadata"] = None,
attn_metadata: Optional[AttentionMetadata] = None,
attn_type: Optional[str] = None,
):
"""A context manager to that stores the current forward context."""
Expand Down
29 changes: 27 additions & 2 deletions diffsynth_engine/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
class DiffusionModel(nn.Module, ConfigMixin):
config_name = CONFIG_NAME

_keep_in_fp32_modules: list[str] | None = None
_keys_to_ignore_on_load_unexpected: list[str] | None = None

@property
def dtype(self) -> torch.dtype:
param = next(self.parameters(), None)
Expand All @@ -37,8 +40,30 @@ def from_pretrained(
with init_empty_weights():
model = cls.from_config(config_dict)

# load model weights
state_dict = load_model_weights(model_path, subfolder, device, dtype)
# avoid precision loss
if dtype is not None and dtype != torch.float32 and cls._keep_in_fp32_modules:
state_dict = load_model_weights(model_path, subfolder, device, dtype=None)
for key in state_dict:
if any(m in key.split(".") for m in cls._keep_in_fp32_modules):
state_dict[key] = state_dict[key].to(device=device, dtype=torch.float32)
else:
state_dict[key] = state_dict[key].to(device=device, dtype=dtype)
else:
state_dict = load_model_weights(model_path, subfolder, device, dtype)

# filter unexpected keys that the model explicitly ignores
if cls._keys_to_ignore_on_load_unexpected:
keys_to_remove = [
key for key in state_dict if any(pattern in key for pattern in cls._keys_to_ignore_on_load_unexpected)
]
for key in keys_to_remove:
del state_dict[key]
if keys_to_remove:
logger.info(
f"Dropped {len(keys_to_remove)} unexpected key(s) matching "
f"{cls._keys_to_ignore_on_load_unexpected} from state_dict."
)

model.load_state_dict(state_dict, strict=True, assign=True)
model.to(device=device)
return model
Expand Down
6 changes: 6 additions & 0 deletions diffsynth_engine/models/wan/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .autoencoder_kl_wan import AutoencoderKLWan
from .transformer_wan import WanTransformer3DModel
from .transformer_wan_animate import WanAnimateTransformer3DModel
from .transformer_wan_vace import WanVACETransformer3DModel

__all__ = ["AutoencoderKLWan", "WanTransformer3DModel", "WanAnimateTransformer3DModel", "WanVACETransformer3DModel"]
Loading