This commit is contained in:
Kashif Rasul
2026-02-15 23:19:13 +00:00
parent d7cb12470b
commit 0d59b22732
3 changed files with 13 additions and 31 deletions

View File

@@ -527,15 +527,11 @@ class AutoencoderRAE(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMode
num_channels (`int`, *optional*, defaults to `3`):
Number of input/output channels.
latents_mean (`list` or `tuple`, *optional*):
Optional mean for latent normalization. Tensor inputs are accepted for backward compatibility and converted
to config-serializable lists.
Optional mean for latent normalization. Tensor inputs are accepted and converted to config-serializable
lists.
latents_std (`list` or `tuple`, *optional*):
Optional standard deviation for latent normalization. Tensor inputs are accepted for backward compatibility
and converted to config-serializable lists.
latent_mean (`list` or `tuple`, *optional*):
Deprecated alias of `latents_mean`.
latent_var (`list` or `tuple`, *optional*):
Deprecated alias of latent variance. If provided, it is converted to `latents_std = sqrt(latent_var + 1e-5)`.
Optional standard deviation for latent normalization. Tensor inputs are accepted and converted to
config-serializable lists.
noise_tau (`float`, *optional*, defaults to `0.0`):
Noise level for training (adds noise to latents during training).
reshape_to_2d (`bool`, *optional*, defaults to `True`):
@@ -563,8 +559,6 @@ class AutoencoderRAE(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMode
num_channels: int = 3,
latents_mean: Optional[Union[list, tuple, torch.Tensor]] = None,
latents_std: Optional[Union[list, tuple, torch.Tensor]] = None,
latent_mean: Optional[Union[list, tuple, torch.Tensor]] = None,
latent_var: Optional[Union[list, tuple, torch.Tensor]] = None,
noise_tau: float = 0.0,
reshape_to_2d: bool = True,
use_encoder_loss: bool = False,
@@ -586,14 +580,6 @@ class AutoencoderRAE(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMode
return [_to_config_compatible(v) for v in value]
return value
if latents_mean is not None and latent_mean is not None:
raise ValueError("Please provide only one of `latents_mean` or deprecated `latent_mean`.")
if latents_std is not None and latent_var is not None:
raise ValueError("Please provide only one of `latents_std` or deprecated `latent_var`.")
if latents_mean is None:
latents_mean = latent_mean
def _as_optional_tensor(value: Optional[Union[torch.Tensor, list, tuple]]) -> Optional[torch.Tensor]:
if value is None:
return None
@@ -602,18 +588,12 @@ class AutoencoderRAE(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMode
return torch.tensor(value, dtype=torch.float32)
latents_std_tensor = _as_optional_tensor(latents_std)
latent_var_tensor = _as_optional_tensor(latent_var)
if latents_std_tensor is None and latent_var_tensor is not None:
latents_std_tensor = torch.sqrt(latent_var_tensor + 1e-5)
latents_std = latents_std_tensor
# Ensure config values are JSON-serializable (list/None), even if caller passes torch.Tensors.
self.register_to_config(
encoder_name_or_path=encoder_name_or_path,
latents_mean=_to_config_compatible(latents_mean),
latents_std=_to_config_compatible(latents_std),
latent_mean=None,
latent_var=None,
)
self.encoder_input_size = encoder_input_size
@@ -680,7 +660,7 @@ class AutoencoderRAE(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMode
# Optional latent normalization (upstream RAE-main stores mean/var; this class takes mean/std)
latents_mean_tensor = _as_optional_tensor(latents_mean)
self.do_latent_normalization = latents_mean is not None or latents_std is not None or latent_var is not None
self.do_latent_normalization = latents_mean is not None or latents_std is not None
if latents_mean_tensor is not None:
self.register_buffer("_latents_mean", latents_mean_tensor, persistent=True)
else: