mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-01 06:10:36 +08:00
Compare commits
3 Commits
save-autom
...
custom-dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
96f97f8214 | ||
|
|
f8b95ff263 | ||
|
|
cfff46069d |
@@ -97,32 +97,5 @@ If the custom model inherits from the [`ModelMixin`] class, it gets access to th
|
||||
> )
|
||||
> ```
|
||||
|
||||
### Saving custom models
|
||||
|
||||
Use [`~ConfigMixin.register_for_auto_class`] to add the `auto_map` entry to `config.json` automatically when saving. This avoids having to manually edit the config file.
|
||||
|
||||
```py
|
||||
# my_model.py
|
||||
from diffusers import ModelMixin, ConfigMixin
|
||||
|
||||
class MyCustomModel(ModelMixin, ConfigMixin):
|
||||
...
|
||||
|
||||
MyCustomModel.register_for_auto_class("AutoModel")
|
||||
|
||||
model = MyCustomModel(...)
|
||||
model.save_pretrained("./my_model")
|
||||
```
|
||||
|
||||
The saved `config.json` will include the `auto_map` field.
|
||||
|
||||
```json
|
||||
{
|
||||
"auto_map": {
|
||||
"AutoModel": "my_model.MyCustomModel"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> Learn more about implementing custom models in the [Community components](../using-diffusers/custom_pipeline_overview#community-components) guide.
|
||||
@@ -107,38 +107,6 @@ class ConfigMixin:
|
||||
has_compatibles = False
|
||||
|
||||
_deprecated_kwargs = []
|
||||
_auto_class = None
|
||||
|
||||
@classmethod
|
||||
def register_for_auto_class(cls, auto_class="AutoModel"):
|
||||
"""
|
||||
Register this class with the given auto class so that it can be loaded with `AutoModel.from_pretrained(...,
|
||||
trust_remote_code=True)`.
|
||||
|
||||
When the config is saved, the resulting `config.json` will include an `auto_map` entry mapping the auto class
|
||||
to this class's module and class name.
|
||||
|
||||
Args:
|
||||
auto_class (`str` or type, *optional*, defaults to `"AutoModel"`):
|
||||
The auto class to register this class with. Can be a string (e.g. `"AutoModel"`) or the class itself.
|
||||
Currently only `"AutoModel"` is supported.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
from diffusers import ModelMixin, ConfigMixin
|
||||
|
||||
|
||||
class MyCustomModel(ModelMixin, ConfigMixin): ...
|
||||
|
||||
|
||||
MyCustomModel.register_for_auto_class("AutoModel")
|
||||
```
|
||||
"""
|
||||
if auto_class != "AutoModel":
|
||||
raise ValueError(f"Only 'AutoModel' is supported, got '{auto_class}'.")
|
||||
|
||||
cls._auto_class = auto_class
|
||||
|
||||
def register_to_config(self, **kwargs):
|
||||
if self.config_name is None:
|
||||
@@ -653,12 +621,6 @@ class ConfigMixin:
|
||||
# pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
|
||||
_ = config_dict.pop("_pre_quantization_dtype", None)
|
||||
|
||||
if getattr(self, "_auto_class", None) is not None:
|
||||
module = self.__class__.__module__.split(".")[-1]
|
||||
auto_map = config_dict.get("auto_map", {})
|
||||
auto_map[self._auto_class] = f"{module}.{self.__class__.__name__}"
|
||||
config_dict["auto_map"] = auto_map
|
||||
|
||||
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
||||
|
||||
def to_json_file(self, json_file_path: str | os.PathLike):
|
||||
|
||||
@@ -60,6 +60,16 @@ class ContextParallelConfig:
|
||||
rotate_method (`str`, *optional*, defaults to `"allgather"`):
|
||||
Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
|
||||
is supported.
|
||||
ulysses_anything (`bool`, *optional*, defaults to `False`):
|
||||
Whether to enable "Ulysses Anything" mode, which supports arbitrary sequence lengths and head counts that
|
||||
are not evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than 1 and
|
||||
`ring_degree` must be 1.
|
||||
mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
|
||||
A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of
|
||||
creating a new one. This is useful when combining context parallelism with other parallelism strategies
|
||||
(e.g., FSDP, tensor parallelism) that share the same device mesh. The mesh must have both "ring" and
|
||||
"ulysses" dimensions. Use size 1 for dimensions not being used (e.g., `mesh_shape=(2, 1, 4)` with
|
||||
`mesh_dim_names=("ring", "ulysses", "fsdp")` for ring attention only with FSDP).
|
||||
|
||||
"""
|
||||
|
||||
@@ -68,6 +78,7 @@ class ContextParallelConfig:
|
||||
convert_to_fp32: bool = True
|
||||
# TODO: support alltoall
|
||||
rotate_method: Literal["allgather", "alltoall"] = "allgather"
|
||||
mesh: torch.distributed.device_mesh.DeviceMesh | None = None
|
||||
# Whether to enable ulysses anything attention to support
|
||||
# any sequence lengths and any head numbers.
|
||||
ulysses_anything: bool = False
|
||||
@@ -124,7 +135,7 @@ class ContextParallelConfig:
|
||||
f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
|
||||
)
|
||||
|
||||
self._flattened_mesh = self._mesh._flatten()
|
||||
self._flattened_mesh = self._mesh["ring", "ulysses"]._flatten()
|
||||
self._ring_mesh = self._mesh["ring"]
|
||||
self._ulysses_mesh = self._mesh["ulysses"]
|
||||
self._ring_local_rank = self._ring_mesh.get_local_rank()
|
||||
|
||||
@@ -62,8 +62,6 @@ _REQUIRED_FLEX_VERSION = "2.5.0"
|
||||
_REQUIRED_XLA_VERSION = "2.2"
|
||||
_REQUIRED_XFORMERS_VERSION = "0.0.29"
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
|
||||
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
|
||||
_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
|
||||
@@ -75,18 +73,8 @@ _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _
|
||||
|
||||
|
||||
if _CAN_USE_FLASH_ATTN:
|
||||
try:
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
# Handle ABI mismatch or other import failures gracefully.
|
||||
# This can happen when flash_attn was compiled against a different PyTorch version.
|
||||
logger.warning(f"flash_attn is installed but failed to import: {e}. Falling back to native PyTorch attention.")
|
||||
_CAN_USE_FLASH_ATTN = False
|
||||
flash_attn_func = None
|
||||
flash_attn_varlen_func = None
|
||||
_wrapped_flash_attn_backward = None
|
||||
_wrapped_flash_attn_forward = None
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
|
||||
else:
|
||||
flash_attn_func = None
|
||||
flash_attn_varlen_func = None
|
||||
@@ -95,47 +83,26 @@ else:
|
||||
|
||||
|
||||
if _CAN_USE_FLASH_ATTN_3:
|
||||
try:
|
||||
from flash_attn_interface import flash_attn_func as flash_attn_3_func
|
||||
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"flash_attn_3 failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_FLASH_ATTN_3 = False
|
||||
flash_attn_3_func = None
|
||||
flash_attn_3_varlen_func = None
|
||||
from flash_attn_interface import flash_attn_func as flash_attn_3_func
|
||||
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
|
||||
else:
|
||||
flash_attn_3_func = None
|
||||
flash_attn_3_varlen_func = None
|
||||
|
||||
if _CAN_USE_AITER_ATTN:
|
||||
try:
|
||||
from aiter import flash_attn_func as aiter_flash_attn_func
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"aiter failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_AITER_ATTN = False
|
||||
aiter_flash_attn_func = None
|
||||
from aiter import flash_attn_func as aiter_flash_attn_func
|
||||
else:
|
||||
aiter_flash_attn_func = None
|
||||
|
||||
if _CAN_USE_SAGE_ATTN:
|
||||
try:
|
||||
from sageattention import (
|
||||
sageattn,
|
||||
sageattn_qk_int8_pv_fp8_cuda,
|
||||
sageattn_qk_int8_pv_fp8_cuda_sm90,
|
||||
sageattn_qk_int8_pv_fp16_cuda,
|
||||
sageattn_qk_int8_pv_fp16_triton,
|
||||
sageattn_varlen,
|
||||
)
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"sageattention failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_SAGE_ATTN = False
|
||||
sageattn = None
|
||||
sageattn_qk_int8_pv_fp8_cuda = None
|
||||
sageattn_qk_int8_pv_fp8_cuda_sm90 = None
|
||||
sageattn_qk_int8_pv_fp16_cuda = None
|
||||
sageattn_qk_int8_pv_fp16_triton = None
|
||||
sageattn_varlen = None
|
||||
from sageattention import (
|
||||
sageattn,
|
||||
sageattn_qk_int8_pv_fp8_cuda,
|
||||
sageattn_qk_int8_pv_fp8_cuda_sm90,
|
||||
sageattn_qk_int8_pv_fp16_cuda,
|
||||
sageattn_qk_int8_pv_fp16_triton,
|
||||
sageattn_varlen,
|
||||
)
|
||||
else:
|
||||
sageattn = None
|
||||
sageattn_qk_int8_pv_fp16_cuda = None
|
||||
@@ -146,48 +113,26 @@ else:
|
||||
|
||||
|
||||
if _CAN_USE_FLEX_ATTN:
|
||||
try:
|
||||
# We cannot import the flex_attention function from the package directly because it is expected (from the
|
||||
# pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
|
||||
# compiled function.
|
||||
import torch.nn.attention.flex_attention as flex_attention
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"flex_attention failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_FLEX_ATTN = False
|
||||
flex_attention = None
|
||||
else:
|
||||
flex_attention = None
|
||||
# We cannot import the flex_attention function from the package directly because it is expected (from the
|
||||
# pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
|
||||
# compiled function.
|
||||
import torch.nn.attention.flex_attention as flex_attention
|
||||
|
||||
|
||||
if _CAN_USE_NPU_ATTN:
|
||||
try:
|
||||
from torch_npu import npu_fusion_attention
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"torch_npu failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_NPU_ATTN = False
|
||||
npu_fusion_attention = None
|
||||
from torch_npu import npu_fusion_attention
|
||||
else:
|
||||
npu_fusion_attention = None
|
||||
|
||||
|
||||
if _CAN_USE_XLA_ATTN:
|
||||
try:
|
||||
from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"torch_xla failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_XLA_ATTN = False
|
||||
xla_flash_attention = None
|
||||
from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
|
||||
else:
|
||||
xla_flash_attention = None
|
||||
|
||||
|
||||
if _CAN_USE_XFORMERS_ATTN:
|
||||
try:
|
||||
import xformers.ops as xops
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"xformers failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_XFORMERS_ATTN = False
|
||||
xops = None
|
||||
import xformers.ops as xops
|
||||
else:
|
||||
xops = None
|
||||
|
||||
@@ -213,6 +158,8 @@ else:
|
||||
_register_fake = register_fake_no_op
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# TODO(aryan): Add support for the following:
|
||||
# - Sage Attention++
|
||||
# - block sparse, radial and other attention methods
|
||||
|
||||
@@ -1567,7 +1567,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
mesh = None
|
||||
if config.context_parallel_config is not None:
|
||||
cp_config = config.context_parallel_config
|
||||
mesh = torch.distributed.device_mesh.init_device_mesh(
|
||||
mesh = cp_config.mesh or torch.distributed.device_mesh.init_device_mesh(
|
||||
device_type=device_type,
|
||||
mesh_shape=cp_config.mesh_shape,
|
||||
mesh_dim_names=cp_config.mesh_dim_names,
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from transformers import CLIPTextModel, LongformerModel
|
||||
|
||||
from diffusers import ConfigMixin
|
||||
from diffusers.models import AutoModel, UNet2DConditionModel
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
|
||||
|
||||
class TestAutoModel(unittest.TestCase):
|
||||
@@ -105,51 +100,3 @@ class TestAutoModelFromConfig(unittest.TestCase):
|
||||
def test_from_config_raises_on_none(self):
|
||||
with self.assertRaises(ValueError, msg="Please provide a `pretrained_model_name_or_path_or_dict`"):
|
||||
AutoModel.from_config(None)
|
||||
|
||||
|
||||
class TestRegisterForAutoClass(unittest.TestCase):
|
||||
def test_register_for_auto_class_sets_attribute(self):
|
||||
class DummyModel(ModelMixin, ConfigMixin):
|
||||
config_name = "config.json"
|
||||
|
||||
DummyModel.register_for_auto_class("AutoModel")
|
||||
self.assertEqual(DummyModel._auto_class, "AutoModel")
|
||||
|
||||
def test_register_for_auto_class_rejects_unsupported(self):
|
||||
class DummyModel(ModelMixin, ConfigMixin):
|
||||
config_name = "config.json"
|
||||
|
||||
with self.assertRaises(ValueError, msg="Only 'AutoModel' is supported"):
|
||||
DummyModel.register_for_auto_class("AutoPipeline")
|
||||
|
||||
def test_auto_map_in_saved_config(self):
|
||||
class DummyModel(ModelMixin, ConfigMixin):
|
||||
config_name = "config.json"
|
||||
|
||||
DummyModel.register_for_auto_class("AutoModel")
|
||||
model = DummyModel()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model.save_config(tmpdir)
|
||||
config_path = os.path.join(tmpdir, "config.json")
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
self.assertIn("auto_map", config)
|
||||
self.assertIn("AutoModel", config["auto_map"])
|
||||
module_name = DummyModel.__module__.split(".")[-1]
|
||||
self.assertEqual(config["auto_map"]["AutoModel"], f"{module_name}.DummyModel")
|
||||
|
||||
def test_no_auto_map_without_register(self):
|
||||
class DummyModel(ModelMixin, ConfigMixin):
|
||||
config_name = "config.json"
|
||||
|
||||
model = DummyModel()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model.save_config(tmpdir)
|
||||
config_path = os.path.join(tmpdir, "config.json")
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
self.assertNotIn("auto_map", config)
|
||||
|
||||
@@ -60,12 +60,7 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
|
||||
model.eval()
|
||||
|
||||
# Move inputs to device
|
||||
inputs_on_device = {}
|
||||
for key, value in inputs_dict.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
inputs_on_device[key] = value.to(device)
|
||||
else:
|
||||
inputs_on_device[key] = value
|
||||
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
|
||||
|
||||
# Enable context parallelism
|
||||
cp_config = ContextParallelConfig(**cp_dict)
|
||||
@@ -89,6 +84,59 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def _custom_mesh_worker(
|
||||
rank,
|
||||
world_size,
|
||||
master_port,
|
||||
model_class,
|
||||
init_dict,
|
||||
cp_dict,
|
||||
mesh_shape,
|
||||
mesh_dim_names,
|
||||
inputs_dict,
|
||||
return_dict,
|
||||
):
|
||||
"""Worker function for context parallel testing with a user-provided custom DeviceMesh."""
|
||||
try:
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = str(master_port)
|
||||
os.environ["RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
|
||||
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
|
||||
|
||||
torch.cuda.set_device(rank)
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
|
||||
model = model_class(**init_dict)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
|
||||
|
||||
# DeviceMesh must be created after init_process_group, inside each worker process.
|
||||
mesh = torch.distributed.device_mesh.init_device_mesh(
|
||||
"cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
|
||||
)
|
||||
cp_config = ContextParallelConfig(**cp_dict, mesh=mesh)
|
||||
model.enable_parallelism(config=cp_config)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_on_device, return_dict=False)[0]
|
||||
|
||||
if rank == 0:
|
||||
return_dict["status"] = "success"
|
||||
return_dict["output_shape"] = list(output.shape)
|
||||
|
||||
except Exception as e:
|
||||
if rank == 0:
|
||||
return_dict["status"] = "error"
|
||||
return_dict["error"] = str(e)
|
||||
finally:
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
@is_context_parallel
|
||||
@require_torch_multi_accelerator
|
||||
class ContextParallelTesterMixin:
|
||||
@@ -126,3 +174,48 @@ class ContextParallelTesterMixin:
|
||||
assert return_dict.get("status") == "success", (
|
||||
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cp_type,mesh_shape,mesh_dim_names",
|
||||
[
|
||||
("ring_degree", (2, 1, 1), ("ring", "ulysses", "fsdp")),
|
||||
("ulysses_degree", (1, 2, 1), ("ring", "ulysses", "fsdp")),
|
||||
],
|
||||
ids=["ring-3d-fsdp", "ulysses-3d-fsdp"],
|
||||
)
|
||||
def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names):
|
||||
if not torch.distributed.is_available():
|
||||
pytest.skip("torch.distributed is not available.")
|
||||
|
||||
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
|
||||
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
|
||||
|
||||
world_size = 2
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}
|
||||
cp_dict = {cp_type: world_size}
|
||||
|
||||
master_port = _find_free_port()
|
||||
manager = mp.Manager()
|
||||
return_dict = manager.dict()
|
||||
|
||||
mp.spawn(
|
||||
_custom_mesh_worker,
|
||||
args=(
|
||||
world_size,
|
||||
master_port,
|
||||
self.model_class,
|
||||
init_dict,
|
||||
cp_dict,
|
||||
mesh_shape,
|
||||
mesh_dim_names,
|
||||
inputs_dict,
|
||||
return_dict,
|
||||
),
|
||||
nprocs=world_size,
|
||||
join=True,
|
||||
)
|
||||
|
||||
assert return_dict.get("status") == "success", (
|
||||
f"Custom mesh context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user