mirror of
https://github.com/huggingface/diffusers.git
synced 2026-04-25 01:11:34 +08:00
cleanup
This commit is contained in:
@@ -527,15 +527,11 @@ class AutoencoderRAE(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMode
|
||||
num_channels (`int`, *optional*, defaults to `3`):
|
||||
Number of input/output channels.
|
||||
latents_mean (`list` or `tuple`, *optional*):
|
||||
Optional mean for latent normalization. Tensor inputs are accepted for backward compatibility and converted
|
||||
to config-serializable lists.
|
||||
Optional mean for latent normalization. Tensor inputs are accepted and converted to config-serializable
|
||||
lists.
|
||||
latents_std (`list` or `tuple`, *optional*):
|
||||
Optional standard deviation for latent normalization. Tensor inputs are accepted for backward compatibility
|
||||
and converted to config-serializable lists.
|
||||
latent_mean (`list` or `tuple`, *optional*):
|
||||
Deprecated alias of `latents_mean`.
|
||||
latent_var (`list` or `tuple`, *optional*):
|
||||
Deprecated alias of latent variance. If provided, it is converted to `latents_std = sqrt(latent_var + 1e-5)`.
|
||||
Optional standard deviation for latent normalization. Tensor inputs are accepted and converted to
|
||||
config-serializable lists.
|
||||
noise_tau (`float`, *optional*, defaults to `0.0`):
|
||||
Noise level for training (adds noise to latents during training).
|
||||
reshape_to_2d (`bool`, *optional*, defaults to `True`):
|
||||
@@ -563,8 +559,6 @@ class AutoencoderRAE(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMode
|
||||
num_channels: int = 3,
|
||||
latents_mean: Optional[Union[list, tuple, torch.Tensor]] = None,
|
||||
latents_std: Optional[Union[list, tuple, torch.Tensor]] = None,
|
||||
latent_mean: Optional[Union[list, tuple, torch.Tensor]] = None,
|
||||
latent_var: Optional[Union[list, tuple, torch.Tensor]] = None,
|
||||
noise_tau: float = 0.0,
|
||||
reshape_to_2d: bool = True,
|
||||
use_encoder_loss: bool = False,
|
||||
@@ -586,14 +580,6 @@ class AutoencoderRAE(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMode
|
||||
return [_to_config_compatible(v) for v in value]
|
||||
return value
|
||||
|
||||
if latents_mean is not None and latent_mean is not None:
|
||||
raise ValueError("Please provide only one of `latents_mean` or deprecated `latent_mean`.")
|
||||
if latents_std is not None and latent_var is not None:
|
||||
raise ValueError("Please provide only one of `latents_std` or deprecated `latent_var`.")
|
||||
|
||||
if latents_mean is None:
|
||||
latents_mean = latent_mean
|
||||
|
||||
def _as_optional_tensor(value: Optional[Union[torch.Tensor, list, tuple]]) -> Optional[torch.Tensor]:
|
||||
if value is None:
|
||||
return None
|
||||
@@ -602,18 +588,12 @@ class AutoencoderRAE(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMode
|
||||
return torch.tensor(value, dtype=torch.float32)
|
||||
|
||||
latents_std_tensor = _as_optional_tensor(latents_std)
|
||||
latent_var_tensor = _as_optional_tensor(latent_var)
|
||||
if latents_std_tensor is None and latent_var_tensor is not None:
|
||||
latents_std_tensor = torch.sqrt(latent_var_tensor + 1e-5)
|
||||
latents_std = latents_std_tensor
|
||||
|
||||
# Ensure config values are JSON-serializable (list/None), even if caller passes torch.Tensors.
|
||||
self.register_to_config(
|
||||
encoder_name_or_path=encoder_name_or_path,
|
||||
latents_mean=_to_config_compatible(latents_mean),
|
||||
latents_std=_to_config_compatible(latents_std),
|
||||
latent_mean=None,
|
||||
latent_var=None,
|
||||
)
|
||||
|
||||
self.encoder_input_size = encoder_input_size
|
||||
@@ -680,7 +660,7 @@ class AutoencoderRAE(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMode
|
||||
|
||||
# Optional latent normalization (RAE-main uses mean/var)
|
||||
latents_mean_tensor = _as_optional_tensor(latents_mean)
|
||||
self.do_latent_normalization = latents_mean is not None or latents_std is not None or latent_var is not None
|
||||
self.do_latent_normalization = latents_mean is not None or latents_std is not None
|
||||
if latents_mean_tensor is not None:
|
||||
self.register_buffer("_latents_mean", latents_mean_tensor, persistent=True)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user