mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-31 21:06:45 +08:00
Compare commits
3 Commits
fix-torcha
...
autoencode
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
119daffdf1 | ||
|
|
c346ad5eb8 | ||
|
|
e1e7d58a4a |
@@ -862,23 +862,23 @@ def _native_attention_backward_op(
|
||||
key.requires_grad_(True)
|
||||
value.requires_grad_(True)
|
||||
|
||||
query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query_t,
|
||||
key=key_t,
|
||||
value=value_t,
|
||||
attn_mask=ctx.attn_mask,
|
||||
dropout_p=ctx.dropout_p,
|
||||
is_causal=ctx.is_causal,
|
||||
scale=ctx.scale,
|
||||
enable_gqa=ctx.enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
with torch.enable_grad():
|
||||
query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query_t,
|
||||
key=key_t,
|
||||
value=value_t,
|
||||
attn_mask=ctx.attn_mask,
|
||||
dropout_p=ctx.dropout_p,
|
||||
is_causal=ctx.is_causal,
|
||||
scale=ctx.scale,
|
||||
enable_gqa=ctx.enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
|
||||
grad_out_t = grad_out.permute(0, 2, 1, 3)
|
||||
grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
|
||||
outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
|
||||
)
|
||||
grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
|
||||
outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out, retain_graph=False
|
||||
)
|
||||
|
||||
grad_query = grad_query_t.permute(0, 2, 1, 3)
|
||||
grad_key = grad_key_t.permute(0, 2, 1, 3)
|
||||
|
||||
@@ -13,24 +13,29 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import AutoencoderDC
|
||||
|
||||
from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, floats_tensor, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, torch_device
|
||||
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
|
||||
from .testing_utils import AutoencoderTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderDCTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderDC
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
class AutoencoderDCTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return AutoencoderDC
|
||||
|
||||
def get_autoencoder_dc_config(self):
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
"in_channels": 3,
|
||||
"latent_channels": 4,
|
||||
@@ -56,33 +61,34 @@ class AutoencoderDCTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.Test
|
||||
"scaling_factor": 0.41407,
|
||||
}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
torch.manual_seed(seed)
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
|
||||
image = torch.randn(batch_size, num_channels, *sizes).to(torch_device)
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
# Bridge for AutoencoderTesterMixin which still uses the old interface
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_dc_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
return self.get_init_dict(), self.get_dummy_inputs()
|
||||
|
||||
@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
|
||||
def test_layerwise_casting_inference(self):
|
||||
super().test_layerwise_casting_inference()
|
||||
|
||||
@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
|
||||
class TestAutoencoderDC(AutoencoderDCTesterConfig, ModelTesterMixin):
|
||||
base_precision = 1e-2
|
||||
|
||||
|
||||
class TestAutoencoderDCTraining(AutoencoderDCTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for AutoencoderDC."""
|
||||
|
||||
|
||||
class TestAutoencoderDCMemory(AutoencoderDCTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for AutoencoderDC."""
|
||||
|
||||
@pytest.mark.skipif(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
|
||||
def test_layerwise_casting_memory(self):
|
||||
super().test_layerwise_casting_memory()
|
||||
|
||||
|
||||
class TestAutoencoderDCSlicingTiling(AutoencoderDCTesterConfig, AutoencoderTesterMixin):
|
||||
"""Slicing and tiling tests for AutoencoderDC."""
|
||||
|
||||
@@ -98,6 +98,64 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def _context_parallel_backward_worker(
|
||||
rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict
|
||||
):
|
||||
"""Worker function for context parallel backward pass testing."""
|
||||
try:
|
||||
# Set up distributed environment
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = str(master_port)
|
||||
os.environ["RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
|
||||
# Get device configuration
|
||||
device_config = DEVICE_CONFIG.get(torch_device, DEVICE_CONFIG["cuda"])
|
||||
backend = device_config["backend"]
|
||||
device_module = device_config["module"]
|
||||
|
||||
# Initialize process group
|
||||
dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
|
||||
|
||||
# Set device for this process
|
||||
device_module.set_device(rank)
|
||||
device = torch.device(f"{torch_device}:{rank}")
|
||||
|
||||
# Create model in training mode
|
||||
model = model_class(**init_dict)
|
||||
model.to(device)
|
||||
model.train()
|
||||
|
||||
# Move inputs to device
|
||||
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
|
||||
|
||||
# Enable context parallelism
|
||||
cp_config = ContextParallelConfig(**cp_dict)
|
||||
model.enable_parallelism(config=cp_config)
|
||||
|
||||
# Run forward and backward pass
|
||||
output = model(**inputs_on_device, return_dict=False)[0]
|
||||
loss = output.sum()
|
||||
loss.backward()
|
||||
|
||||
# Check that backward actually produced at least one valid gradient
|
||||
grads = [p.grad for p in model.parameters() if p.requires_grad and p.grad is not None]
|
||||
has_valid_grads = len(grads) > 0 and all(torch.isfinite(g).all() for g in grads)
|
||||
|
||||
# Only rank 0 reports results
|
||||
if rank == 0:
|
||||
return_dict["status"] = "success"
|
||||
return_dict["has_valid_grads"] = bool(has_valid_grads)
|
||||
|
||||
except Exception as e:
|
||||
if rank == 0:
|
||||
return_dict["status"] = "error"
|
||||
return_dict["error"] = str(e)
|
||||
finally:
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def _custom_mesh_worker(
|
||||
rank,
|
||||
world_size,
|
||||
@@ -204,6 +262,51 @@ class ContextParallelTesterMixin:
|
||||
def test_context_parallel_batch_inputs(self, cp_type):
|
||||
self.test_context_parallel_inference(cp_type, batch_size=2)
|
||||
|
||||
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
|
||||
def test_context_parallel_backward(self, cp_type, batch_size: int = 1):
|
||||
if not torch.distributed.is_available():
|
||||
pytest.skip("torch.distributed is not available.")
|
||||
|
||||
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
|
||||
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
|
||||
|
||||
if cp_type == "ring_degree":
|
||||
active_backend, _ = _AttentionBackendRegistry.get_active_backend()
|
||||
if active_backend == AttentionBackendName.NATIVE:
|
||||
pytest.skip("Ring attention is not supported with the native attention backend.")
|
||||
|
||||
world_size = 2
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs(batch_size=batch_size)
|
||||
|
||||
# Move all tensors to CPU for multiprocessing
|
||||
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
|
||||
cp_dict = {cp_type: world_size}
|
||||
|
||||
# Find a free port for distributed communication
|
||||
master_port = _find_free_port()
|
||||
|
||||
# Use multiprocessing manager for cross-process communication
|
||||
manager = mp.Manager()
|
||||
return_dict = manager.dict()
|
||||
|
||||
# Spawn worker processes
|
||||
mp.spawn(
|
||||
_context_parallel_backward_worker,
|
||||
args=(world_size, master_port, self.model_class, init_dict, cp_dict, inputs_dict, return_dict),
|
||||
nprocs=world_size,
|
||||
join=True,
|
||||
)
|
||||
|
||||
assert return_dict.get("status") == "success", (
|
||||
f"Context parallel backward pass failed: {return_dict.get('error', 'Unknown error')}"
|
||||
)
|
||||
assert return_dict.get("has_valid_grads"), "Context parallel backward pass did not produce valid gradients."
|
||||
|
||||
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
|
||||
def test_context_parallel_backward_batch_inputs(self, cp_type):
|
||||
self.test_context_parallel_backward(cp_type, batch_size=2)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cp_type,mesh_shape,mesh_dim_names",
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user