Compare commits


1 Commit

Author     SHA1        Message            Date
sayakpaul  e4fc2a138d  fix torchao typo.  2025-12-23 12:56:51 +05:30
2 changed files with 8 additions and 22 deletions

View File

@@ -33,7 +33,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
 )
 pipeline = DiffusionPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
-    quantzation_config=pipeline_quant_config,
+    quantization_config=pipeline_quant_config,
     torch_dtype=torch.bfloat16,
     device_map="cuda"
 )
@@ -50,7 +50,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
 )
 pipeline = DiffusionPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
-    quantzation_config=pipeline_quant_config,
+    quantization_config=pipeline_quant_config,
     torch_dtype=torch.bfloat16,
     device_map="cuda"
 )
@@ -70,7 +70,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
 )
 pipeline = DiffusionPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
-    quantzation_config=pipeline_quant_config,
+    quantization_config=pipeline_quant_config,
     torch_dtype=torch.bfloat16,
     device_map="cuda"
 )
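
For reference, the corrected call with a full quantization config spelled out reads as below. This is a minimal sketch: the backend settings (quant_backend="torchao", quant_kwargs, components_to_quantize) are assumptions standing in for the config defined above these hunks in the doc; only the fixed `quantization_config` kwarg comes from this diff.

import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

# Assumed torchao backend settings; the doc's actual values sit above the hunks shown here.
pipeline_quant_config = PipelineQuantizationConfig(
    quant_backend="torchao",
    quant_kwargs={"quant_type": "int8wo"},
    components_to_quantize=["transformer"],
)
pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,  # fixed kwarg name
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)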

View File

@@ -109,7 +109,7 @@ LIBRARIES = []
 for library in LOADABLE_CLASSES:
     LIBRARIES.append(library)

-SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device(), "cpu"]
+SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device()]

 logger = logging.get_logger(__name__)
@@ -462,7 +462,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         pipeline_is_sequentially_offloaded = any(
             module_is_sequentially_offloaded(module) for _, module in self.components.items()
         )
-        is_pipeline_device_mapped = self._is_pipeline_device_mapped()
+        is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
         if is_pipeline_device_mapped:
             raise ValueError(
                 "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
@@ -1163,7 +1164,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
"""
self._maybe_raise_error_if_group_offload_active(raise_error=True)
is_pipeline_device_mapped = self._is_pipeline_device_mapped()
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
if is_pipeline_device_mapped:
raise ValueError(
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`."
@@ -1285,7 +1286,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
         self.remove_all_hooks()
-        is_pipeline_device_mapped = self._is_pipeline_device_mapped()
+        is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
         if is_pipeline_device_mapped:
             raise ValueError(
                 "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload()` isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
@@ -2170,21 +2171,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             return True
         return False

-    def _is_pipeline_device_mapped(self):
-        # We support passing `device_map="cuda"`, for example. This is helpful, in case
-        # users want to pass `device_map="cpu"` when initializing a pipeline. This explicit declaration is desirable
-        # in limited VRAM environments because quantized models often initialize directly on the accelerator.
-        device_map = self.hf_device_map
-        is_device_type_map = False
-        if isinstance(device_map, str):
-            try:
-                torch.device(device_map)
-                is_device_type_map = True
-            except RuntimeError:
-                pass
-        return not is_device_type_map and isinstance(device_map, dict) and len(device_map) > 1

 class StableDiffusionMixin:
     r"""