Compare commits

...

3 Commits

Author SHA1 Message Date
DN6
4efd3de674 update 2026-02-16 23:04:41 +05:30
DN6
685ee01154 update 2026-02-16 21:34:16 +05:30
Sayak Paul
35086ac06a [core] support device type device_maps to work with offloading. (#12811)
* support device type device_maps to work with offloading.

* add tests.

* fix tests

* skip tests where it's not supported.

* empty

* up

* up

* fix allegro.
2026-02-16 16:31:45 +05:30
11 changed files with 87 additions and 21 deletions

View File

@@ -77,7 +77,8 @@ jobs:
uv pip install -e ".[quality]"
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip uninstall transformers huggingface_hub && uv pip install transformers
- name: Environment
run: |
python utils/print_env.py
@@ -135,7 +136,7 @@ jobs:
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip uninstall transformers huggingface_hub && uv pip install transformers
- name: Environment
run: |
@@ -188,7 +189,7 @@ jobs:
run: |
uv pip install -e ".[quality,training]"
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip uninstall transformers huggingface_hub && uv pip install transformers
- name: Environment
run: |
python utils/print_env.py

View File

@@ -112,7 +112,7 @@ LIBRARIES = []
for library in LOADABLE_CLASSES:
LIBRARIES.append(library)
SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device()]
SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device(), "cpu"]
logger = logging.get_logger(__name__)
@@ -468,8 +468,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
pipeline_is_sequentially_offloaded = any(
module_is_sequentially_offloaded(module) for _, module in self.components.items()
)
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
is_pipeline_device_mapped = self._is_pipeline_device_mapped()
if is_pipeline_device_mapped:
raise ValueError(
"It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
@@ -1188,7 +1187,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
"""
self._maybe_raise_error_if_group_offload_active(raise_error=True)
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
is_pipeline_device_mapped = self._is_pipeline_device_mapped()
if is_pipeline_device_mapped:
raise ValueError(
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`."
@@ -1312,7 +1311,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
self.remove_all_hooks()
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
is_pipeline_device_mapped = self._is_pipeline_device_mapped()
if is_pipeline_device_mapped:
raise ValueError(
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
@@ -2228,6 +2227,21 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
return True
return False
def _is_pipeline_device_mapped(self):
    """Return whether a multi-entry device map is active on this pipeline.

    Reads `self.hf_device_map`, which may be `None`, a plain string, or a dict.

    Returns:
        `True` only when `hf_device_map` is a dict holding more than one entry, `False` otherwise.
    """
    # We support passing `device_map="cuda"`, for example. This is helpful, in case
    # users want to pass `device_map="cpu"` when initializing a pipeline. This explicit declaration is desirable
    # in limited VRAM environments because quantized models often initialize directly on the accelerator.
    #
    # A string device map (device-type or otherwise) is never a dict, so it can never count as a
    # device mapping strategy here. Checking the type directly avoids measuring the *string* length
    # (`len("cuda") > 1`) and makes probing the string with `torch.device(...)` unnecessary.
    device_map = self.hf_device_map
    return isinstance(device_map, dict) and len(device_map) > 1
class StableDiffusionMixin:
r"""

View File

@@ -628,6 +628,21 @@ class BitsAndBytesTesterMixin(BitsAndBytesConfigMixin, QuantizationTesterMixin):
"""Test that quantized models can be used for training with adapters."""
self._test_quantization_training(BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"])
@pytest.mark.parametrize(
    "config_name",
    list(BitsAndBytesConfigMixin.BNB_CONFIGS),
    ids=list(BitsAndBytesConfigMixin.BNB_CONFIGS),
)
def test_cpu_device_map(self, config_name):
    """Check that a quantized model created with `device_map="cpu"` exposes a device map and sits on CPU."""
    model_quantized = self._create_quantized_model(
        BitsAndBytesConfigMixin.BNB_CONFIGS[config_name],
        device_map="cpu",
    )

    # The device map must be recorded on the model, not silently dropped.
    assert hasattr(model_quantized, "hf_device_map"), "Model should have hf_device_map attribute"
    assert model_quantized.hf_device_map is not None, "hf_device_map should not be None"

    # The model itself must report the CPU device.
    cpu_device = torch.device("cpu")
    assert model_quantized.device == cpu_device, (
        f"Model should be on CPU, but is on {model_quantized.device}"
    )
@is_quantization
@is_quanto

View File

@@ -158,6 +158,10 @@ class AllegroPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTes
def test_save_load_optional_components(self):
pass
@unittest.skip("Decoding without tiling is not yet implemented")
def test_pipeline_with_accelerator_device_map(self):
pass
def test_inference(self):
device = "cpu"

View File

@@ -34,9 +34,7 @@ enable_full_determinism()
class KandinskyPipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = KandinskyCombinedPipeline
params = [
"prompt",
]
params = ["prompt"]
batch_params = ["prompt", "negative_prompt"]
required_optional_params = [
"generator",
@@ -148,6 +146,10 @@ class KandinskyPipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCase)
def test_dict_tuple_outputs_equivalent(self):
super().test_dict_tuple_outputs_equivalent(expected_max_difference=5e-4)
@unittest.skip("Test not supported.")
def test_pipeline_with_accelerator_device_map(self):
pass
class KandinskyPipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = KandinskyImg2ImgCombinedPipeline
@@ -264,6 +266,10 @@ class KandinskyPipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.Te
def test_save_load_optional_components(self):
super().test_save_load_optional_components(expected_max_difference=5e-4)
@unittest.skip("Test not supported.")
def test_pipeline_with_accelerator_device_map(self):
pass
class KandinskyPipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = KandinskyInpaintCombinedPipeline
@@ -384,3 +390,7 @@ class KandinskyPipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.Te
def test_save_load_local(self):
super().test_save_load_local(expected_max_difference=5e-3)
@unittest.skip("Test not supported.")
def test_pipeline_with_accelerator_device_map(self):
pass

View File

@@ -36,9 +36,7 @@ enable_full_determinism()
class KandinskyV22PipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = KandinskyV22CombinedPipeline
params = [
"prompt",
]
params = ["prompt"]
batch_params = ["prompt", "negative_prompt"]
required_optional_params = [
"generator",
@@ -70,12 +68,7 @@ class KandinskyV22PipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCa
def get_dummy_inputs(self, device, seed=0):
prior_dummy = PriorDummies()
inputs = prior_dummy.get_dummy_inputs(device=device, seed=seed)
inputs.update(
{
"height": 64,
"width": 64,
}
)
inputs.update({"height": 64, "width": 64})
return inputs
def test_kandinsky(self):
@@ -155,12 +148,18 @@ class KandinskyV22PipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCa
def test_save_load_optional_components(self):
super().test_save_load_optional_components(expected_max_difference=5e-3)
@unittest.skip("Test not supported.")
def test_callback_inputs(self):
pass
@unittest.skip("Test not supported.")
def test_callback_cfg(self):
pass
@unittest.skip("Test not supported.")
def test_pipeline_with_accelerator_device_map(self):
pass
class KandinskyV22PipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = KandinskyV22Img2ImgCombinedPipeline
@@ -279,12 +278,18 @@ class KandinskyV22PipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest
def test_save_load_local(self):
    """Delegate to the parent class's `test_save_load_local` with `expected_max_difference=5e-3`.

    The method needs the `test_` prefix: without it the test runner never collects it and it does
    not override the parent's test, so the relaxed tolerance would never be applied.
    """
    super().test_save_load_local(expected_max_difference=5e-3)

# Backward-compatible alias for the previous (un-collected) method name.
save_load_local = test_save_load_local
@unittest.skip("Test not supported.")
def test_callback_inputs(self):
pass
@unittest.skip("Test not supported.")
def test_callback_cfg(self):
pass
@unittest.skip("Test not supported.")
def test_pipeline_with_accelerator_device_map(self):
pass
class KandinskyV22PipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = KandinskyV22InpaintCombinedPipeline
@@ -411,3 +416,7 @@ class KandinskyV22PipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest
def test_callback_cfg(self):
pass
@unittest.skip("`device_map` is not yet supported for connected pipelines.")
def test_pipeline_with_accelerator_device_map(self):
pass

View File

@@ -296,6 +296,9 @@ class KandinskyV22InpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCas
output = pipe(**inputs)[0]
assert output.abs().sum() == 0
def test_pipeline_with_accelerator_device_map(self):
super().test_pipeline_with_accelerator_device_map(expected_max_difference=5e-3)
@slow
@require_torch_accelerator

View File

@@ -194,6 +194,9 @@ class Kandinsky3Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
def test_save_load_dduf(self):
super().test_save_load_dduf(atol=1e-3, rtol=1e-3)
def test_pipeline_with_accelerator_device_map(self):
super().test_pipeline_with_accelerator_device_map(expected_max_difference=5e-3)
@slow
@require_torch_accelerator

View File

@@ -2355,7 +2355,6 @@ class PipelineTesterMixin:
f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
)
@require_torch_accelerator
def test_pipeline_with_accelerator_device_map(self, expected_max_difference=1e-4):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)

View File

@@ -342,3 +342,7 @@ class VisualClozePipelineFastTests(unittest.TestCase, PipelineTesterMixin):
self.assertLess(
max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
)
@unittest.skip("Test not supported.")
def test_pipeline_with_accelerator_device_map(self):
pass

View File

@@ -310,3 +310,7 @@ class VisualClozeGenerationPipelineFastTests(unittest.TestCase, PipelineTesterMi
@unittest.skip("Skipped due to missing layout_prompt. Needs further investigation.")
def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=0.0001, rtol=0.0001):
pass
@unittest.skip("Needs to be revisited later.")
def test_pipeline_with_accelerator_device_map(self, expected_max_difference=0.0001):
pass