mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-30 20:37:55 +08:00
Compare commits
5 Commits
fix-torcha
...
autoencode
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
771174ac68 | ||
|
|
d37402866c | ||
|
|
e1e7d58a4a | ||
|
|
d10190b1b7 | ||
|
|
a1c3e6ccbb |
@@ -22,7 +22,7 @@ from typing import Set
|
||||
import safetensors.torch
|
||||
import torch
|
||||
|
||||
from ..utils import get_logger, is_accelerate_available, is_torchao_available
|
||||
from ..utils import get_logger, is_accelerate_available
|
||||
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
@@ -35,54 +35,6 @@ if is_accelerate_available():
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def _is_torchao_tensor(tensor: torch.Tensor) -> bool:
|
||||
if not is_torchao_available():
|
||||
return False
|
||||
from torchao.utils import TorchAOBaseTensor
|
||||
|
||||
return isinstance(tensor, TorchAOBaseTensor)
|
||||
|
||||
|
||||
def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]:
|
||||
"""Get names of all internal tensor data attributes from a TorchAO tensor."""
|
||||
cls = type(tensor)
|
||||
names = list(getattr(cls, "tensor_data_names", []))
|
||||
for attr_name in getattr(cls, "optional_tensor_data_names", []):
|
||||
if getattr(tensor, attr_name, None) is not None:
|
||||
names.append(attr_name)
|
||||
return names
|
||||
|
||||
|
||||
def _swap_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
|
||||
"""Move a TorchAO parameter to the device of `source` via `swap_tensors`.
|
||||
|
||||
`param.data = source` does not work for `_make_wrapper_subclass` tensors because the `.data` setter only replaces
|
||||
the outer wrapper storage while leaving the subclass's internal attributes (e.g. `.qdata`, `.scale`) on the
|
||||
original device. `swap_tensors` swaps the full tensor contents in-place, preserving the parameter's identity so
|
||||
that any dict keyed by `id(param)` remains valid.
|
||||
|
||||
Refer to https://github.com/huggingface/diffusers/pull/13276#discussion_r2944471548 for the full discussion.
|
||||
"""
|
||||
torch.utils.swap_tensors(param, source)
|
||||
|
||||
|
||||
def _restore_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
|
||||
"""Restore internal tensor data of a TorchAO parameter from `source` without mutating `source`.
|
||||
|
||||
Unlike `_swap_torchao_tensor` this copies attribute references one-by-one via `setattr` so that `source` is **not**
|
||||
modified. Use this when `source` is a cached tensor that must remain unchanged (e.g. a pinned CPU copy in
|
||||
`cpu_param_dict`).
|
||||
"""
|
||||
for attr_name in _get_torchao_inner_tensor_names(source):
|
||||
setattr(param, attr_name, getattr(source, attr_name))
|
||||
|
||||
|
||||
def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None:
|
||||
"""Record stream for all internal tensors of a TorchAO parameter."""
|
||||
for attr_name in _get_torchao_inner_tensor_names(param):
|
||||
getattr(param, attr_name).record_stream(stream)
|
||||
|
||||
|
||||
# fmt: off
|
||||
_GROUP_OFFLOADING = "group_offloading"
|
||||
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
|
||||
@@ -172,13 +124,6 @@ class ModuleGroup:
|
||||
else torch.cuda
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _to_cpu(tensor, low_cpu_mem_usage):
|
||||
# For TorchAO tensors, `.data` returns an incomplete wrapper without internal attributes
|
||||
# (e.g. `.qdata`, `.scale`), so we must call `.cpu()` on the tensor directly.
|
||||
t = tensor.cpu() if _is_torchao_tensor(tensor) else tensor.data.cpu()
|
||||
return t if low_cpu_mem_usage else t.pin_memory()
|
||||
|
||||
def _init_cpu_param_dict(self):
|
||||
cpu_param_dict = {}
|
||||
if self.stream is None:
|
||||
@@ -186,15 +131,17 @@ class ModuleGroup:
|
||||
|
||||
for module in self.modules:
|
||||
for param in module.parameters():
|
||||
cpu_param_dict[param] = self._to_cpu(param, self.low_cpu_mem_usage)
|
||||
cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
|
||||
for buffer in module.buffers():
|
||||
cpu_param_dict[buffer] = self._to_cpu(buffer, self.low_cpu_mem_usage)
|
||||
cpu_param_dict[buffer] = (
|
||||
buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
|
||||
)
|
||||
|
||||
for param in self.parameters:
|
||||
cpu_param_dict[param] = self._to_cpu(param, self.low_cpu_mem_usage)
|
||||
cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
|
||||
|
||||
for buffer in self.buffers:
|
||||
cpu_param_dict[buffer] = self._to_cpu(buffer, self.low_cpu_mem_usage)
|
||||
cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
|
||||
|
||||
return cpu_param_dict
|
||||
|
||||
@@ -210,16 +157,9 @@ class ModuleGroup:
|
||||
pinned_dict = None
|
||||
|
||||
def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
|
||||
moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if _is_torchao_tensor(tensor):
|
||||
_swap_torchao_tensor(tensor, moved)
|
||||
else:
|
||||
tensor.data = moved
|
||||
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.record_stream:
|
||||
if _is_torchao_tensor(tensor):
|
||||
_record_stream_torchao_tensor(tensor, default_stream)
|
||||
else:
|
||||
tensor.data.record_stream(default_stream)
|
||||
tensor.data.record_stream(default_stream)
|
||||
|
||||
def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
|
||||
for group_module in self.modules:
|
||||
@@ -238,19 +178,7 @@ class ModuleGroup:
|
||||
source = pinned_memory[buffer] if pinned_memory else buffer.data
|
||||
self._transfer_tensor_to_device(buffer, source, default_stream)
|
||||
|
||||
def _check_disk_offload_torchao(self):
|
||||
all_tensors = list(self.tensor_to_key.keys())
|
||||
has_torchao = any(_is_torchao_tensor(t) for t in all_tensors)
|
||||
if has_torchao:
|
||||
raise ValueError(
|
||||
"Disk offloading is not supported for TorchAO quantized tensors because safetensors "
|
||||
"cannot serialize TorchAO subclass tensors. Use memory offloading instead by not "
|
||||
"setting `offload_to_disk_path`."
|
||||
)
|
||||
|
||||
def _onload_from_disk(self):
|
||||
self._check_disk_offload_torchao()
|
||||
|
||||
if self.stream is not None:
|
||||
# Wait for previous Host->Device transfer to complete
|
||||
self.stream.synchronize()
|
||||
@@ -293,8 +221,6 @@ class ModuleGroup:
|
||||
self._process_tensors_from_modules(None)
|
||||
|
||||
def _offload_to_disk(self):
|
||||
self._check_disk_offload_torchao()
|
||||
|
||||
# TODO: we can potentially optimize this code path by checking if the _all_ the desired
|
||||
# safetensor files exist on the disk and if so, skip this step entirely, reducing IO
|
||||
# overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
|
||||
@@ -319,35 +245,18 @@ class ModuleGroup:
|
||||
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
if _is_torchao_tensor(param):
|
||||
_restore_torchao_tensor(param, self.cpu_param_dict[param])
|
||||
else:
|
||||
param.data = self.cpu_param_dict[param]
|
||||
for param in self.parameters:
|
||||
if _is_torchao_tensor(param):
|
||||
_restore_torchao_tensor(param, self.cpu_param_dict[param])
|
||||
else:
|
||||
param.data = self.cpu_param_dict[param]
|
||||
for param in self.parameters:
|
||||
param.data = self.cpu_param_dict[param]
|
||||
for buffer in self.buffers:
|
||||
if _is_torchao_tensor(buffer):
|
||||
_restore_torchao_tensor(buffer, self.cpu_param_dict[buffer])
|
||||
else:
|
||||
buffer.data = self.cpu_param_dict[buffer]
|
||||
buffer.data = self.cpu_param_dict[buffer]
|
||||
else:
|
||||
for group_module in self.modules:
|
||||
group_module.to(self.offload_device, non_blocking=False)
|
||||
for param in self.parameters:
|
||||
if _is_torchao_tensor(param):
|
||||
moved = param.to(self.offload_device, non_blocking=False)
|
||||
_swap_torchao_tensor(param, moved)
|
||||
else:
|
||||
param.data = param.data.to(self.offload_device, non_blocking=False)
|
||||
param.data = param.data.to(self.offload_device, non_blocking=False)
|
||||
for buffer in self.buffers:
|
||||
if _is_torchao_tensor(buffer):
|
||||
moved = buffer.to(self.offload_device, non_blocking=False)
|
||||
_swap_torchao_tensor(buffer, moved)
|
||||
else:
|
||||
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
|
||||
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
|
||||
|
||||
@torch.compiler.disable()
|
||||
def onload_(self):
|
||||
|
||||
@@ -862,23 +862,23 @@ def _native_attention_backward_op(
|
||||
key.requires_grad_(True)
|
||||
value.requires_grad_(True)
|
||||
|
||||
query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query_t,
|
||||
key=key_t,
|
||||
value=value_t,
|
||||
attn_mask=ctx.attn_mask,
|
||||
dropout_p=ctx.dropout_p,
|
||||
is_causal=ctx.is_causal,
|
||||
scale=ctx.scale,
|
||||
enable_gqa=ctx.enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
with torch.enable_grad():
|
||||
query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query_t,
|
||||
key=key_t,
|
||||
value=value_t,
|
||||
attn_mask=ctx.attn_mask,
|
||||
dropout_p=ctx.dropout_p,
|
||||
is_causal=ctx.is_causal,
|
||||
scale=ctx.scale,
|
||||
enable_gqa=ctx.enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
|
||||
grad_out_t = grad_out.permute(0, 2, 1, 3)
|
||||
grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
|
||||
outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
|
||||
)
|
||||
grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
|
||||
outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out, retain_graph=False
|
||||
)
|
||||
|
||||
grad_query = grad_query_t.permute(0, 2, 1, 3)
|
||||
grad_key = grad_key_t.permute(0, 2, 1, 3)
|
||||
|
||||
@@ -14,8 +14,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from parameterized import parameterized
|
||||
|
||||
@@ -25,7 +25,6 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
load_hf_numpy,
|
||||
require_torch_accelerator,
|
||||
require_torch_accelerator_with_fp16,
|
||||
@@ -35,22 +34,26 @@ from ...testing_utils import (
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
|
||||
from .testing_utils import AutoencoderTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKL
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
class AutoencoderKLTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return AutoencoderKL
|
||||
|
||||
def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
def get_init_dict(self, block_out_channels=None, norm_num_groups=None):
|
||||
block_out_channels = block_out_channels or [2, 4]
|
||||
norm_num_groups = norm_num_groups or 2
|
||||
init_dict = {
|
||||
return {
|
||||
"block_out_channels": block_out_channels,
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
@@ -59,42 +62,32 @@ class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.Test
|
||||
"latent_channels": 4,
|
||||
"norm_num_groups": norm_num_groups,
|
||||
}
|
||||
return init_dict
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
torch.manual_seed(seed)
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
|
||||
image = torch.randn(batch_size, num_channels, *sizes).to(torch_device)
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
# Bridge for AutoencoderTesterMixin which still uses the old interface
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
return self.get_init_dict(), self.get_dummy_inputs()
|
||||
|
||||
|
||||
class TestAutoencoderKL(AutoencoderKLTesterConfig, ModelTesterMixin, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||
assert model is not None
|
||||
assert len(loading_info["missing_keys"]) == 0
|
||||
|
||||
model.to(torch_device)
|
||||
image = model(**self.dummy_input)
|
||||
image = model(**self.get_dummy_inputs())
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
@@ -168,17 +161,24 @@ class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.Test
|
||||
]
|
||||
)
|
||||
|
||||
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
|
||||
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2)
|
||||
|
||||
|
||||
class TestAutoencoderKLMemory(AutoencoderKLTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for AutoencoderKL."""
|
||||
|
||||
|
||||
class TestAutoencoderKLSlicingTiling(AutoencoderKLTesterConfig, AutoencoderTesterMixin):
|
||||
"""Slicing and tiling tests for AutoencoderKL."""
|
||||
|
||||
|
||||
@slow
|
||||
class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
class AutoencoderKLIntegrationTests:
|
||||
def get_file_format(self, seed, shape):
|
||||
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
|
||||
|
||||
def tearDown(self):
|
||||
def teardown_method(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@@ -341,10 +341,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
|
||||
@parameterized.expand([(13,), (16,), (27,)])
|
||||
@require_torch_gpu
|
||||
@unittest.skipIf(
|
||||
not is_xformers_available(),
|
||||
reason="xformers is not required when using PyTorch 2.0.",
|
||||
)
|
||||
@pytest.mark.skipif(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
|
||||
def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed):
|
||||
model = self.get_sd_vae_model(fp16=True)
|
||||
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
|
||||
@@ -362,10 +359,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
|
||||
@parameterized.expand([(13,), (16,), (37,)])
|
||||
@require_torch_gpu
|
||||
@unittest.skipIf(
|
||||
not is_xformers_available(),
|
||||
reason="xformers is not required when using PyTorch 2.0.",
|
||||
)
|
||||
@pytest.mark.skipif(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
|
||||
def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
|
||||
model = self.get_sd_vae_model()
|
||||
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
|
||||
|
||||
@@ -98,6 +98,64 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def _context_parallel_backward_worker(
|
||||
rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict
|
||||
):
|
||||
"""Worker function for context parallel backward pass testing."""
|
||||
try:
|
||||
# Set up distributed environment
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = str(master_port)
|
||||
os.environ["RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
|
||||
# Get device configuration
|
||||
device_config = DEVICE_CONFIG.get(torch_device, DEVICE_CONFIG["cuda"])
|
||||
backend = device_config["backend"]
|
||||
device_module = device_config["module"]
|
||||
|
||||
# Initialize process group
|
||||
dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
|
||||
|
||||
# Set device for this process
|
||||
device_module.set_device(rank)
|
||||
device = torch.device(f"{torch_device}:{rank}")
|
||||
|
||||
# Create model in training mode
|
||||
model = model_class(**init_dict)
|
||||
model.to(device)
|
||||
model.train()
|
||||
|
||||
# Move inputs to device
|
||||
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
|
||||
|
||||
# Enable context parallelism
|
||||
cp_config = ContextParallelConfig(**cp_dict)
|
||||
model.enable_parallelism(config=cp_config)
|
||||
|
||||
# Run forward and backward pass
|
||||
output = model(**inputs_on_device, return_dict=False)[0]
|
||||
loss = output.sum()
|
||||
loss.backward()
|
||||
|
||||
# Check that backward actually produced at least one valid gradient
|
||||
grads = [p.grad for p in model.parameters() if p.requires_grad and p.grad is not None]
|
||||
has_valid_grads = len(grads) > 0 and all(torch.isfinite(g).all() for g in grads)
|
||||
|
||||
# Only rank 0 reports results
|
||||
if rank == 0:
|
||||
return_dict["status"] = "success"
|
||||
return_dict["has_valid_grads"] = bool(has_valid_grads)
|
||||
|
||||
except Exception as e:
|
||||
if rank == 0:
|
||||
return_dict["status"] = "error"
|
||||
return_dict["error"] = str(e)
|
||||
finally:
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def _custom_mesh_worker(
|
||||
rank,
|
||||
world_size,
|
||||
@@ -204,6 +262,51 @@ class ContextParallelTesterMixin:
|
||||
def test_context_parallel_batch_inputs(self, cp_type):
|
||||
self.test_context_parallel_inference(cp_type, batch_size=2)
|
||||
|
||||
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
|
||||
def test_context_parallel_backward(self, cp_type, batch_size: int = 1):
|
||||
if not torch.distributed.is_available():
|
||||
pytest.skip("torch.distributed is not available.")
|
||||
|
||||
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
|
||||
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
|
||||
|
||||
if cp_type == "ring_degree":
|
||||
active_backend, _ = _AttentionBackendRegistry.get_active_backend()
|
||||
if active_backend == AttentionBackendName.NATIVE:
|
||||
pytest.skip("Ring attention is not supported with the native attention backend.")
|
||||
|
||||
world_size = 2
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs(batch_size=batch_size)
|
||||
|
||||
# Move all tensors to CPU for multiprocessing
|
||||
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
|
||||
cp_dict = {cp_type: world_size}
|
||||
|
||||
# Find a free port for distributed communication
|
||||
master_port = _find_free_port()
|
||||
|
||||
# Use multiprocessing manager for cross-process communication
|
||||
manager = mp.Manager()
|
||||
return_dict = manager.dict()
|
||||
|
||||
# Spawn worker processes
|
||||
mp.spawn(
|
||||
_context_parallel_backward_worker,
|
||||
args=(world_size, master_port, self.model_class, init_dict, cp_dict, inputs_dict, return_dict),
|
||||
nprocs=world_size,
|
||||
join=True,
|
||||
)
|
||||
|
||||
assert return_dict.get("status") == "success", (
|
||||
f"Context parallel backward pass failed: {return_dict.get('error', 'Unknown error')}"
|
||||
)
|
||||
assert return_dict.get("has_valid_grads"), "Context parallel backward pass did not produce valid gradients."
|
||||
|
||||
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
|
||||
def test_context_parallel_backward_batch_inputs(self, cp_type):
|
||||
self.test_context_parallel_backward(cp_type, batch_size=2)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cp_type,mesh_shape,mesh_dim_names",
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user