Compare commits

..

4 Commits

Author SHA1 Message Date
Dhruv Nair
e90eb9de70 update 2026-02-17 11:21:51 +01:00
Sayak Paul
35086ac06a [core] support device type device_maps to work with offloading. (#12811)
* support device type device_maps to work with offloading.

* add tests.

* fix tests

* skip tests where it's not supported.

* empty

* up

* up

* fix allegro.
2026-02-16 16:31:45 +05:30
Sayak Paul
e390646f25 [tests] accept recompile_limit from the user in tests (#13150)
accept recompile_limit from the user in tests
2026-02-16 14:48:21 +05:30
Dhruv Nair
59e7a46928 [Pipelines] Remove k-diffusion (#13152)
* remove k-diffusion

* fix copies
2026-02-16 13:54:24 +05:30
28 changed files with 172 additions and 68 deletions

View File

@@ -112,7 +112,7 @@ LIBRARIES = []
for library in LOADABLE_CLASSES:
LIBRARIES.append(library)
SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device()]
SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device(), "cpu"]
logger = logging.get_logger(__name__)
@@ -468,8 +468,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
pipeline_is_sequentially_offloaded = any(
module_is_sequentially_offloaded(module) for _, module in self.components.items()
)
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
is_pipeline_device_mapped = self._is_pipeline_device_mapped()
if is_pipeline_device_mapped:
raise ValueError(
"It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
@@ -1188,7 +1187,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
"""
self._maybe_raise_error_if_group_offload_active(raise_error=True)
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
is_pipeline_device_mapped = self._is_pipeline_device_mapped()
if is_pipeline_device_mapped:
raise ValueError(
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`."
@@ -1312,7 +1311,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
self.remove_all_hooks()
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
is_pipeline_device_mapped = self._is_pipeline_device_mapped()
if is_pipeline_device_mapped:
raise ValueError(
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
@@ -2228,6 +2227,21 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
return True
return False
def _is_pipeline_device_mapped(self):
# We support passing `device_map="cuda"`, for example. This is helpful, in case
# users want to pass `device_map="cpu"` when initializing a pipeline. This explicit declaration is desirable
# in limited VRAM environments because quantized models often initialize directly on the accelerator.
device_map = self.hf_device_map
is_device_type_map = False
if isinstance(device_map, str):
try:
torch.device(device_map)
is_device_type_map = True
except RuntimeError:
pass
return not is_device_type_map and isinstance(device_map, dict) and len(device_map) > 1
class StableDiffusionMixin:
r"""

View File

@@ -81,7 +81,7 @@ class TorchCompileTesterMixin:
_ = model(**inputs_dict)
@torch.no_grad()
def test_torch_compile_repeated_blocks(self):
def test_torch_compile_repeated_blocks(self, recompile_limit=1):
if self.model_class._repeated_blocks is None:
pytest.skip("Skipping test as the model class doesn't have `_repeated_blocks` set.")
@@ -92,7 +92,6 @@ class TorchCompileTesterMixin:
model.eval()
model.compile_repeated_blocks(fullgraph=True)
recompile_limit = 1
if self.model_class.__name__ == "UNet2DConditionModel":
recompile_limit = 2

View File

@@ -628,6 +628,21 @@ class BitsAndBytesTesterMixin(BitsAndBytesConfigMixin, QuantizationTesterMixin):
"""Test that quantized models can be used for training with adapters."""
self._test_quantization_training(BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"])
@pytest.mark.parametrize(
"config_name",
list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()),
ids=list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()),
)
def test_cpu_device_map(self, config_name):
config_kwargs = BitsAndBytesConfigMixin.BNB_CONFIGS[config_name]
model_quantized = self._create_quantized_model(config_kwargs, device_map="cpu")
assert hasattr(model_quantized, "hf_device_map"), "Model should have hf_device_map attribute"
assert model_quantized.hf_device_map is not None, "hf_device_map should not be None"
assert model_quantized.device == torch.device("cpu"), (
f"Model should be on CPU, but is on {model_quantized.device}"
)
@is_quantization
@is_quanto

View File

@@ -147,22 +147,7 @@ class TestWanVACETransformer3DCompile(WanVACETransformer3DTesterConfig, TorchCom
def test_torch_compile_repeated_blocks(self):
# WanVACE has two block types (WanTransformerBlock and WanVACETransformerBlock),
# so we need recompile_limit=2 instead of the default 1.
import torch._dynamo
import torch._inductor.utils
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model.compile_repeated_blocks(fullgraph=True)
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(recompile_limit=2),
):
_ = model(**inputs_dict)
_ = model(**inputs_dict)
super().test_torch_compile_repeated_blocks(recompile_limit=2)
class TestWanVACETransformer3DBitsAndBytes(WanVACETransformer3DTesterConfig, BitsAndBytesTesterMixin):

View File

@@ -158,6 +158,10 @@ class AllegroPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTes
def test_save_load_optional_components(self):
pass
@unittest.skip("Decoding without tiling is not yet implemented")
def test_pipeline_with_accelerator_device_map(self):
pass
def test_inference(self):
device = "cpu"

View File

@@ -18,7 +18,7 @@ import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, DDIMScheduler
@@ -117,7 +117,9 @@ class CogVideoXPipelineFastTests(
torch.manual_seed(0)
scheduler = DDIMScheduler()
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
config.tie_word_embeddings = False
text_encoder = T5EncoderModel(config)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
components = {

View File

@@ -19,7 +19,7 @@ import unittest
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from transformers import AutoConfig, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from diffusers import (
AutoencoderKL,
@@ -97,7 +97,9 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin, Fl
text_encoder = CLIPTextModel(clip_text_encoder_config)
torch.manual_seed(0)
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
config.tie_word_embeddings = False
text_encoder_2 = T5EncoderModel(config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = T5TokenizerFast.from_pretrained("hf-internal-testing/tiny-random-t5")

View File

@@ -18,7 +18,14 @@ import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
from transformers import (
AutoConfig,
AutoTokenizer,
CLIPTextConfig,
CLIPTextModelWithProjection,
CLIPTokenizer,
T5EncoderModel,
)
from diffusers import (
AutoencoderKL,
@@ -117,7 +124,9 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)
torch.manual_seed(0)
text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
config.tie_word_embeddings = False
text_encoder_3 = T5EncoderModel(config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

View File

@@ -3,7 +3,7 @@ import unittest
import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from transformers import AutoConfig, AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxTransformer2DModel
@@ -53,7 +53,9 @@ class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
text_encoder = CLIPTextModel(clip_text_encoder_config)
torch.manual_seed(0)
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
config.tie_word_embeddings = False
text_encoder_2 = T5EncoderModel(config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

View File

@@ -3,7 +3,7 @@ import unittest
import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from transformers import AutoConfig, AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import (
AutoencoderKL,
@@ -57,7 +57,9 @@ class FluxControlImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin
text_encoder = CLIPTextModel(clip_text_encoder_config)
torch.manual_seed(0)
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
config.tie_word_embeddings = False
text_encoder_2 = T5EncoderModel(config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

View File

@@ -3,7 +3,7 @@ import unittest
import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from transformers import AutoConfig, AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import (
AutoencoderKL,
@@ -58,7 +58,9 @@ class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin
text_encoder = CLIPTextModel(clip_text_encoder_config)
torch.manual_seed(0)
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
config.tie_word_embeddings = False
text_encoder_2 = T5EncoderModel(config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

View File

@@ -3,7 +3,7 @@ import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from transformers import AutoConfig, AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxFillPipeline, FluxTransformer2DModel
@@ -58,7 +58,9 @@ class FluxFillPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
text_encoder = CLIPTextModel(clip_text_encoder_config)
torch.manual_seed(0)
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
config.tie_word_embeddings = False
text_encoder_2 = T5EncoderModel(config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

View File

@@ -3,7 +3,7 @@ import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from transformers import AutoConfig, AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxImg2ImgPipeline, FluxTransformer2DModel
@@ -55,7 +55,9 @@ class FluxImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxI
text_encoder = CLIPTextModel(clip_text_encoder_config)
torch.manual_seed(0)
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
config.tie_word_embeddings = False
text_encoder_2 = T5EncoderModel(config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

View File

@@ -3,7 +3,7 @@ import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from transformers import AutoConfig, AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxInpaintPipeline, FluxTransformer2DModel
@@ -55,7 +55,9 @@ class FluxInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxI
text_encoder = CLIPTextModel(clip_text_encoder_config)
torch.manual_seed(0)
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
config.tie_word_embeddings = False
text_encoder_2 = T5EncoderModel(config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

View File

@@ -3,7 +3,7 @@ import unittest
import numpy as np
import PIL.Image
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from transformers import AutoConfig, AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import (
AutoencoderKL,
@@ -79,7 +79,9 @@ class FluxKontextPipelineFastTests(
text_encoder = CLIPTextModel(clip_text_encoder_config)
torch.manual_seed(0)
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
config.tie_word_embeddings = False
text_encoder_2 = T5EncoderModel(config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

View File

@@ -3,7 +3,7 @@ import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from transformers import AutoConfig, AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import (
AutoencoderKL,
@@ -79,7 +79,9 @@ class FluxKontextInpaintPipelineFastTests(
text_encoder = CLIPTextModel(clip_text_encoder_config)
torch.manual_seed(0)
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
config.tie_word_embeddings = False
text_encoder_2 = T5EncoderModel(config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

View File

@@ -18,6 +18,7 @@ import unittest
import numpy as np
import torch
from transformers import (
AutoConfig,
AutoTokenizer,
CLIPTextConfig,
CLIPTextModelWithProjection,
@@ -94,7 +95,9 @@ class HiDreamImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)
torch.manual_seed(0)
text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
config.tie_word_embeddings = False
text_encoder_3 = T5EncoderModel(config)
torch.manual_seed(0)
text_encoder_4 = LlamaForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")

View File

@@ -19,7 +19,7 @@ import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, BertModel, T5EncoderModel
from transformers import AutoConfig, AutoTokenizer, BertModel, T5EncoderModel
from diffusers import AutoencoderKL, DDPMScheduler, HunyuanDiT2DModel, HunyuanDiTPipeline
@@ -74,7 +74,10 @@ class HunyuanDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
scheduler = DDPMScheduler()
text_encoder = BertModel.from_pretrained("hf-internal-testing/tiny-random-BertModel")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BertModel")
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
config.tie_word_embeddings = False
text_encoder_2 = T5EncoderModel(config)
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
components = {

View File

@@ -34,9 +34,7 @@ enable_full_determinism()
class KandinskyPipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = KandinskyCombinedPipeline
params = [
"prompt",
]
params = ["prompt"]
batch_params = ["prompt", "negative_prompt"]
required_optional_params = [
"generator",
@@ -148,6 +146,10 @@ class KandinskyPipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCase)
def test_dict_tuple_outputs_equivalent(self):
super().test_dict_tuple_outputs_equivalent(expected_max_difference=5e-4)
@unittest.skip("Test not supported.")
def test_pipeline_with_accelerator_device_map(self):
pass
class KandinskyPipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = KandinskyImg2ImgCombinedPipeline
@@ -264,6 +266,10 @@ class KandinskyPipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.Te
def test_save_load_optional_components(self):
super().test_save_load_optional_components(expected_max_difference=5e-4)
@unittest.skip("Test not supported.")
def test_pipeline_with_accelerator_device_map(self):
pass
class KandinskyPipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = KandinskyInpaintCombinedPipeline
@@ -384,3 +390,7 @@ class KandinskyPipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.Te
def test_save_load_local(self):
super().test_save_load_local(expected_max_difference=5e-3)
@unittest.skip("Test not supported.")
def test_pipeline_with_accelerator_device_map(self):
pass

View File

@@ -36,9 +36,7 @@ enable_full_determinism()
class KandinskyV22PipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = KandinskyV22CombinedPipeline
params = [
"prompt",
]
params = ["prompt"]
batch_params = ["prompt", "negative_prompt"]
required_optional_params = [
"generator",
@@ -70,12 +68,7 @@ class KandinskyV22PipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCa
def get_dummy_inputs(self, device, seed=0):
prior_dummy = PriorDummies()
inputs = prior_dummy.get_dummy_inputs(device=device, seed=seed)
inputs.update(
{
"height": 64,
"width": 64,
}
)
inputs.update({"height": 64, "width": 64})
return inputs
def test_kandinsky(self):
@@ -155,12 +148,18 @@ class KandinskyV22PipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCa
def test_save_load_optional_components(self):
super().test_save_load_optional_components(expected_max_difference=5e-3)
@unittest.skip("Test not supported.")
def test_callback_inputs(self):
pass
@unittest.skip("Test not supported.")
def test_callback_cfg(self):
pass
@unittest.skip("Test not supported.")
def test_pipeline_with_accelerator_device_map(self):
pass
class KandinskyV22PipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = KandinskyV22Img2ImgCombinedPipeline
@@ -279,12 +278,18 @@ class KandinskyV22PipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest
def save_load_local(self):
super().test_save_load_local(expected_max_difference=5e-3)
@unittest.skip("Test not supported.")
def test_callback_inputs(self):
pass
@unittest.skip("Test not supported.")
def test_callback_cfg(self):
pass
@unittest.skip("Test not supported.")
def test_pipeline_with_accelerator_device_map(self):
pass
class KandinskyV22PipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = KandinskyV22InpaintCombinedPipeline
@@ -411,3 +416,7 @@ class KandinskyV22PipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest
def test_callback_cfg(self):
pass
@unittest.skip("`device_map` is not yet supported for connected pipelines.")
def test_pipeline_with_accelerator_device_map(self):
pass

View File

@@ -296,6 +296,9 @@ class KandinskyV22InpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCas
output = pipe(**inputs)[0]
assert output.abs().sum() == 0
def test_pipeline_with_accelerator_device_map(self):
super().test_pipeline_with_accelerator_device_map(expected_max_difference=5e-3)
@slow
@require_torch_accelerator

View File

@@ -194,6 +194,9 @@ class Kandinsky3Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
def test_save_load_dduf(self):
super().test_save_load_dduf(atol=1e-3, rtol=1e-3)
def test_pipeline_with_accelerator_device_map(self):
super().test_pipeline_with_accelerator_device_map(expected_max_difference=5e-3)
@slow
@require_torch_accelerator

View File

@@ -17,7 +17,7 @@ import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel
@@ -88,7 +88,9 @@ class LTXPipelineFastTests(PipelineTesterMixin, FirstBlockCacheTesterMixin, unit
torch.manual_seed(0)
scheduler = FlowMatchEulerDiscreteScheduler()
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
config.tie_word_embeddings = False
text_encoder = T5EncoderModel(config)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
components = {

View File

@@ -4,7 +4,14 @@ import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
from transformers import (
AutoConfig,
AutoTokenizer,
CLIPTextConfig,
CLIPTextModelWithProjection,
CLIPTokenizer,
T5EncoderModel,
)
from diffusers import (
AutoencoderKL,
@@ -73,7 +80,10 @@ class StableDiffusion3Img2ImgPipelineFastTests(PipelineLatentTesterMixin, unitte
torch.manual_seed(0)
text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)
text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
config.tie_word_embeddings = False
text_encoder_3 = T5EncoderModel(config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

View File

@@ -2355,7 +2355,6 @@ class PipelineTesterMixin:
f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
)
@require_torch_accelerator
def test_pipeline_with_accelerator_device_map(self, expected_max_difference=1e-4):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)

View File

@@ -342,3 +342,7 @@ class VisualClozePipelineFastTests(unittest.TestCase, PipelineTesterMixin):
self.assertLess(
max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
)
@unittest.skip("Test not supported.")
def test_pipeline_with_accelerator_device_map(self):
pass

View File

@@ -310,3 +310,7 @@ class VisualClozeGenerationPipelineFastTests(unittest.TestCase, PipelineTesterMi
@unittest.skip("Skipped due to missing layout_prompt. Needs further investigation.")
def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=0.0001, rtol=0.0001):
pass
@unittest.skip("Needs to be revisited later.")
def test_pipeline_with_accelerator_device_map(self, expected_max_difference=0.0001):
pass

View File

@@ -18,7 +18,7 @@ import unittest
import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLWan, UniPCMultistepScheduler, WanImageToVideoPipeline, WanTransformer3DModel
@@ -64,7 +64,11 @@ class Wan22ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
torch.manual_seed(0)
scheduler = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
config.tie_word_embeddings = False
text_encoder = T5EncoderModel(config)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
@@ -248,7 +252,11 @@ class Wan225BImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCas
torch.manual_seed(0)
scheduler = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
config.tie_word_embeddings = False
text_encoder = T5EncoderModel(config)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)