Compare commits

...

2 Commits

Author SHA1 Message Date
sayakpaul
2ac264c737 prolong the cycle 2024-02-08 11:50:23 +05:30
sayakpaul
4bd8525339 start the depcrecation cycle for torch_dtype in from_pretrained 2024-02-08 11:49:56 +05:30
2 changed files with 31 additions and 7 deletions

View File

@@ -517,6 +517,18 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
if torch_dtype is not None:
deprecate("torch_dtype", "0.30.0", "Using `torch_dtype` is depcrecated. Use `dtype`, instead.")
dtype_kwarg = kwargs.pop("dtype", None)
if torch_dtype is not None and dtype_kwarg is not None:
raise ValueError(
"You have passed both `torch_dtype` and `dtype` as a keyword argument. Please make sure to only pass `dtype`."
)
dtype = torch_dtype or dtype_kwarg
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
@@ -670,7 +682,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
model,
state_dict,
device=param_device,
dtype=torch_dtype,
dtype=dtype,
model_name_or_path=pretrained_model_name_or_path,
)
@@ -755,12 +767,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
"error_msgs": error_msgs,
}
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
if dtype is not None and not isinstance(dtype, torch.dtype):
raise ValueError(
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
f"{dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
)
elif torch_dtype is not None:
model = model.to(torch_dtype)
elif dtype is not None:
model = model.to(dtype)
model.register_to_config(_name_or_path=pretrained_model_name_or_path)

View File

@@ -1078,6 +1078,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
if torch_dtype is not None:
deprecate("torch_dtype", "0.30.0", "Using `torch_dtype` is depcrecated. Use `dtype`, instead.")
dtype_kwarg = kwargs.pop("dtype", None)
if torch_dtype is not None and dtype_kwarg is not None:
raise ValueError(
"You have passed both `torch_dtype` and `dtype` as a keyword argument. Please make sure to only pass `dtype`."
)
dtype = torch_dtype or dtype_kwarg
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
if not os.path.isdir(pretrained_model_name_or_path):
@@ -1268,7 +1280,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
pipelines=pipelines,
is_pipeline_module=is_pipeline_module,
pipeline_class=pipeline_class,
torch_dtype=torch_dtype,
torch_dtype=dtype,
provider=provider,
sess_options=sess_options,
device_map=device_map,
@@ -1300,7 +1312,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
"local_files_only": local_files_only,
"token": token,
"revision": revision,
"torch_dtype": torch_dtype,
"torch_dtype": dtype,
"custom_pipeline": custom_pipeline,
"custom_revision": custom_revision,
"provider": provider,