Compare commits

..

1 Commits

Author SHA1 Message Date
sayakpaul
971d0cd6e6 replace with actual fix. 2024-09-23 08:18:50 +05:30
5 changed files with 81 additions and 92 deletions

View File

@@ -153,9 +153,12 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
# every down weight has a corresponding up weight and potentially an alpha weight
lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
for key in lora_keys:
# Iterate over all LoRA weights.
all_lora_keys = list(state_dict.keys())
for key in all_lora_keys:
if not key.endswith("lora_down.weight"):
continue
# Extract LoRA name.
lora_name = key.split(".")[0]
@@ -174,12 +177,9 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
# Store DoRA scale if present.
if dora_present_in_unet:
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
new_key = diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
# dora_weight = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
dora_weight = state_dict.pop(lora_name + ".dora_scale")
if dora_weight.dim() <= 2:
dora_weight = dora_weight.squeeze()
unet_state_dict[new_key] = dora_weight
unet_state_dict[
diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
# Handle text encoder LoRAs.
elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
@@ -194,24 +194,18 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
# Store DoRA scale if present.
if (dora_present_in_te or dora_present_in_te2):
if dora_present_in_te or dora_present_in_te2:
dora_scale_key_to_replace_te = (
"_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
)
if lora_name.startswith(("lora_te_", "lora_te1_")):
new_key = diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
# dora_weight = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
dora_weight = state_dict.pop(lora_name + ".dora_scale")
if dora_weight.dim() <= 2:
dora_weight = dora_weight.squeeze()
te_state_dict[new_key] = dora_weight
te_state_dict[
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
elif lora_name.startswith("lora_te2_"):
new_key = diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
# dora_weight = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
dora_weight = state_dict.pop(lora_name + ".dora_scale")
if dora_weight.dim() <= 2:
dora_weight = dora_weight.squeeze()
te2_state_dict[new_key] = dora_weight
te2_state_dict[
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
# Store alpha if present.
if lora_name_alpha in state_dict:
@@ -220,8 +214,7 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
# Check if any keys remain.
if len(state_dict) > 0:
all_keys_remaining = sorted(list(state_dict.keys()))
raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(all_keys_remaining)}")
raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
logger.info("Non-diffusers checkpoint detected.")
@@ -292,7 +285,7 @@ def _convert_unet_lora_key(key):
pass
else:
pass
return diffusers_name

View File

@@ -188,9 +188,6 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
self.vae_scale_factor_temporal = (
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
)
self.vae_scaling_factor_image = (
self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@@ -320,12 +317,6 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
shape = (
batch_size,
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
@@ -333,6 +324,11 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
@@ -345,7 +341,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
latents = 1 / self.vae_scaling_factor_image * latents
latents = 1 / self.vae.config.scaling_factor * latents
frames = self.vae.decode(latents).sample
return frames
@@ -514,10 +510,10 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
The height in pixels of the generated image. This is set to 480 by default for the best results.
width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
The width in pixels of the generated image. This is set to 720 by default for the best results.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_frames (`int`, defaults to `48`):
Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
@@ -591,6 +587,8 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct

View File

@@ -207,9 +207,6 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
self.vae_scale_factor_temporal = (
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
)
self.vae_scaling_factor_image = (
self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@@ -351,12 +348,6 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
):
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
shape = (
batch_size,
@@ -366,6 +357,12 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
width // self.vae_scale_factor_spatial,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
image = image.unsqueeze(2) # [B, C, F, H, W]
if isinstance(generator, list):
@@ -376,7 +373,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
image_latents = self.vae_scaling_factor_image * image_latents
image_latents = self.vae.config.scaling_factor * image_latents
padding_shape = (
batch_size,
@@ -400,7 +397,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
latents = 1 / self.vae_scaling_factor_image * latents
latents = 1 / self.vae.config.scaling_factor * latents
frames = self.vae.decode(latents).sample
return frames
@@ -441,6 +438,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
width,
negative_prompt,
callback_on_step_end_tensor_inputs,
video=None,
latents=None,
prompt_embeds=None,
negative_prompt_embeds=None,
@@ -496,6 +494,9 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
f" {negative_prompt_embeds.shape}."
)
if video is not None and latents is not None:
raise ValueError("Only one of `video` or `latents` should be provided")
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections
def fuse_qkv_projections(self) -> None:
r"""Enables fused QKV projections."""
@@ -583,7 +584,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
Args:
image (`PipelineImageInput`):
The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
The input video to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
@@ -591,10 +592,10 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
The height in pixels of the generated image. This is set to 480 by default for the best results.
width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
The width in pixels of the generated image. This is set to 720 by default for the best results.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_frames (`int`, defaults to `48`):
Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
@@ -664,19 +665,20 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct
self.check_inputs(
image=image,
prompt=prompt,
height=height,
width=width,
negative_prompt=negative_prompt,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
latents=latents,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
image,
prompt,
height,
width,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds,
negative_prompt_embeds,
)
self._guidance_scale = guidance_scale
self._interrupt = False

View File

@@ -204,16 +204,12 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor_spatial = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.vae_scale_factor_temporal = (
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
)
self.vae_scaling_factor_image = (
self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@@ -355,12 +351,6 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
latents: Optional[torch.Tensor] = None,
timestep: Optional[torch.Tensor] = None,
):
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
num_frames = (video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1)
shape = (
@@ -371,6 +361,12 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
width // self.vae_scale_factor_spatial,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
if isinstance(generator, list):
if len(generator) != batch_size:
@@ -386,7 +382,7 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
init_latents = torch.cat(init_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
init_latents = self.vae_scaling_factor_image * init_latents
init_latents = self.vae.config.scaling_factor * init_latents
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self.scheduler.add_noise(init_latents, noise, timestep)
@@ -400,7 +396,7 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
latents = 1 / self.vae_scaling_factor_image * latents
latents = 1 / self.vae.config.scaling_factor * latents
frames = self.vae.decode(latents).sample
return frames
@@ -593,10 +589,10 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
The height in pixels of the generated image. This is set to 480 by default for the best results.
width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
The width in pixels of the generated image. This is set to 720 by default for the best results.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
@@ -662,20 +658,20 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt=prompt,
height=height,
width=width,
strength=strength,
negative_prompt=negative_prompt,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
video=video,
latents=latents,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt,
height,
width,
strength,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds,
negative_prompt_embeds,
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs

View File

@@ -30,11 +30,11 @@ enable_full_determinism()
@require_torch_gpu
class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionControlNetPipeline
ckpt_path = "https://huggingface.co/Jiali/stable-diffusion-1.5/blob/main/v1-5-pruned-emaonly.safetensors"
ckpt_path = "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
original_config = (
"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
)
repo_id = "Jiali/stable-diffusion-1.5"
repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
def setUp(self):
super().setUp()