mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 20:44:33 +08:00
Compare commits
2 Commits
v0.35.2
...
depcrecate
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2ac264c737 | ||
|
|
4bd8525339 |
@@ -517,6 +517,18 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
variant = kwargs.pop("variant", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
if torch_dtype is not None:
|
||||
deprecate("torch_dtype", "0.30.0", "Using `torch_dtype` is depcrecated. Use `dtype`, instead.")
|
||||
|
||||
dtype_kwarg = kwargs.pop("dtype", None)
|
||||
|
||||
if torch_dtype is not None and dtype_kwarg is not None:
|
||||
raise ValueError(
|
||||
"You have passed both `torch_dtype` and `dtype` as a keyword argument. Please make sure to only pass `dtype`."
|
||||
)
|
||||
|
||||
dtype = torch_dtype or dtype_kwarg
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = True
|
||||
@@ -670,7 +682,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
model,
|
||||
state_dict,
|
||||
device=param_device,
|
||||
dtype=torch_dtype,
|
||||
dtype=dtype,
|
||||
model_name_or_path=pretrained_model_name_or_path,
|
||||
)
|
||||
|
||||
@@ -755,12 +767,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
"error_msgs": error_msgs,
|
||||
}
|
||||
|
||||
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
||||
if dtype is not None and not isinstance(dtype, torch.dtype):
|
||||
raise ValueError(
|
||||
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
|
||||
f"{dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
|
||||
)
|
||||
elif torch_dtype is not None:
|
||||
model = model.to(torch_dtype)
|
||||
elif dtype is not None:
|
||||
model = model.to(dtype)
|
||||
|
||||
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
||||
|
||||
|
||||
@@ -1078,6 +1078,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
use_onnx = kwargs.pop("use_onnx", None)
|
||||
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
|
||||
|
||||
if torch_dtype is not None:
|
||||
deprecate("torch_dtype", "0.30.0", "Using `torch_dtype` is depcrecated. Use `dtype`, instead.")
|
||||
|
||||
dtype_kwarg = kwargs.pop("dtype", None)
|
||||
|
||||
if torch_dtype is not None and dtype_kwarg is not None:
|
||||
raise ValueError(
|
||||
"You have passed both `torch_dtype` and `dtype` as a keyword argument. Please make sure to only pass `dtype`."
|
||||
)
|
||||
|
||||
dtype = torch_dtype or dtype_kwarg
|
||||
|
||||
# 1. Download the checkpoints and configs
|
||||
# use snapshot download here to get it working from from_pretrained
|
||||
if not os.path.isdir(pretrained_model_name_or_path):
|
||||
@@ -1268,7 +1280,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
pipelines=pipelines,
|
||||
is_pipeline_module=is_pipeline_module,
|
||||
pipeline_class=pipeline_class,
|
||||
torch_dtype=torch_dtype,
|
||||
torch_dtype=dtype,
|
||||
provider=provider,
|
||||
sess_options=sess_options,
|
||||
device_map=device_map,
|
||||
@@ -1300,7 +1312,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
"local_files_only": local_files_only,
|
||||
"token": token,
|
||||
"revision": revision,
|
||||
"torch_dtype": torch_dtype,
|
||||
"torch_dtype": dtype,
|
||||
"custom_pipeline": custom_pipeline,
|
||||
"custom_revision": custom_revision,
|
||||
"provider": provider,
|
||||
|
||||
Reference in New Issue
Block a user