Compare commits

..

8 Commits

Author SHA1 Message Date
Dhruv Nair
d1fa0301bc Merge branch 'fast-gpu-tests' of https://github.com/huggingface/diffusers into fast-gpu-tests 2025-02-27 08:51:36 +01:00
Dhruv Nair
cca8e144b7 update 2025-02-27 08:51:25 +01:00
Sayak Paul
fac5514e90 Merge branch 'main' into fast-gpu-tests 2025-02-27 09:10:16 +05:30
Dhruv Nair
828dd32464 Merge branch 'main' into fast-gpu-test-fixes 2025-02-26 18:27:56 +01:00
Dhruv Nair
721501c754 update 2025-02-26 18:24:02 +01:00
Dhruv Nair
4756522e55 update 2025-02-26 18:23:11 +01:00
Dhruv Nair
d108c18f50 update 2025-02-26 04:34:56 +01:00
Dhruv Nair
e2d2650117 update 2025-02-25 13:50:21 +01:00
5 changed files with 47 additions and 63 deletions

View File

@@ -106,18 +106,11 @@ jobs:
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
if [ "${{ matrix.module }}" = "ip_adapters" ]; then
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
else
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx and $pattern" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
fi
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx and $pattern" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
- name: Failure short reports
if: ${{ failure() }}

View File

@@ -1,4 +1,3 @@
<!---
Copyright 2022 - The HuggingFace Team. All rights reserved.

View File

@@ -605,13 +605,12 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
controlnet_cond: List[torch.Tensor],
control_type: torch.Tensor,
control_type_idx: List[int],
conditioning_scale: Union[float, List[float]] = 1.0,
conditioning_scale: float = 1.0,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
from_multi: bool = False,
guess_mode: bool = False,
return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
@@ -648,8 +647,6 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
Additional conditions for the Stable Diffusion XL UNet.
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
from_multi (`bool`, defaults to `False`):
Use standard scaling when called from `MultiControlNetUnionModel`.
guess_mode (`bool`, defaults to `False`):
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
@@ -661,9 +658,6 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
returned where the first element is the sample tensor.
"""
if isinstance(conditioning_scale, float):
conditioning_scale = [conditioning_scale] * len(controlnet_cond)
# check channel order
channel_order = self.config.controlnet_conditioning_channel_order
@@ -748,16 +742,12 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
inputs = []
condition_list = []
for cond, control_idx, scale in zip(controlnet_cond, control_type_idx, conditioning_scale):
for cond, control_idx in zip(controlnet_cond, control_type_idx):
condition = self.controlnet_cond_embedding(cond)
feat_seq = torch.mean(condition, dim=(2, 3))
feat_seq = feat_seq + self.task_embedding[control_idx]
if from_multi:
inputs.append(feat_seq.unsqueeze(1))
condition_list.append(condition)
else:
inputs.append(feat_seq.unsqueeze(1) * scale)
condition_list.append(condition * scale)
inputs.append(feat_seq.unsqueeze(1))
condition_list.append(condition)
condition = sample
feat_seq = torch.mean(condition, dim=(2, 3))
@@ -769,13 +759,10 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
x = layer(x)
controlnet_cond_fuser = sample * 0.0
for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
for idx, condition in enumerate(condition_list[:-1]):
alpha = self.spatial_ch_projs(x[:, idx])
alpha = alpha.unsqueeze(-1).unsqueeze(-1)
if from_multi:
controlnet_cond_fuser += condition + alpha
else:
controlnet_cond_fuser += condition + alpha * scale
controlnet_cond_fuser += condition + alpha
sample = sample + controlnet_cond_fuser
@@ -819,13 +806,12 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
# 6. scaling
if guess_mode and not self.config.global_pool_conditions:
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
if from_multi:
scales = scales * conditioning_scale[0]
scales = scales * conditioning_scale
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
elif from_multi:
down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]
else:
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
mid_block_res_sample = mid_block_res_sample * conditioning_scale
if self.config.global_pool_conditions:
down_block_res_samples = [

View File

@@ -47,12 +47,9 @@ class MultiControlNetUnionModel(ModelMixin):
guess_mode: bool = False,
return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple]:
down_block_res_samples, mid_block_res_sample = None, None
for i, (image, ctype, ctype_idx, scale, controlnet) in enumerate(
zip(controlnet_cond, control_type, control_type_idx, conditioning_scale, self.nets)
):
if scale == 0.0:
continue
down_samples, mid_sample = controlnet(
sample=sample,
timestep=timestep,
@@ -66,13 +63,12 @@ class MultiControlNetUnionModel(ModelMixin):
attention_mask=attention_mask,
added_cond_kwargs=added_cond_kwargs,
cross_attention_kwargs=cross_attention_kwargs,
from_multi=True,
guess_mode=guess_mode,
return_dict=return_dict,
)
# merge samples
if down_block_res_samples is None and mid_block_res_sample is None:
if i == 0:
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
else:
down_block_res_samples = [

View File

@@ -757,9 +757,15 @@ class StableDiffusionXLControlNetUnionPipeline(
for images_ in image:
for image_ in images_:
self.check_image(image_, prompt, prompt_embeds)
else:
assert False
# Check `controlnet_conditioning_scale`
if isinstance(controlnet, MultiControlNetUnionModel):
# TODO Update for https://github.com/huggingface/diffusers/pull/10723
if isinstance(controlnet, ControlNetUnionModel):
if not isinstance(controlnet_conditioning_scale, float):
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
elif isinstance(controlnet, MultiControlNetUnionModel):
if isinstance(controlnet_conditioning_scale, list):
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
@@ -770,6 +776,8 @@ class StableDiffusionXLControlNetUnionPipeline(
"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
" the same length as the number of controlnets"
)
else:
assert False
if len(control_guidance_start) != len(control_guidance_end):
raise ValueError(
@@ -800,6 +808,8 @@ class StableDiffusionXLControlNetUnionPipeline(
for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets):
if max(_control_mode) >= _controlnet.config.num_control_type:
raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.")
else:
assert False
# Equal number of `image` and `control_mode` elements
if isinstance(controlnet, ControlNetUnionModel):
@@ -813,6 +823,8 @@ class StableDiffusionXLControlNetUnionPipeline(
elif sum(len(x) for x in image) != sum(len(x) for x in control_mode):
raise ValueError("Expected len(control_image) == len(control_mode)")
else:
assert False
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
@@ -1189,6 +1201,18 @@ class StableDiffusionXLControlNetUnionPipeline(
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else 1
control_guidance_start, control_guidance_end = (
mult * [control_guidance_start],
mult * [control_guidance_end],
)
if not isinstance(control_image, list):
control_image = [control_image]
else:
@@ -1197,25 +1221,8 @@ class StableDiffusionXLControlNetUnionPipeline(
if not isinstance(control_mode, list):
control_mode = [control_mode]
if isinstance(controlnet, MultiControlNetUnionModel):
control_image = [[item] for item in control_image]
control_mode = [[item] for item in control_mode]
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
control_guidance_start, control_guidance_end = (
mult * [control_guidance_start],
mult * [control_guidance_end],
)
if isinstance(controlnet_conditioning_scale, float):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult
if isinstance(controlnet, MultiControlNetUnionModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
# 1. Check inputs
self.check_inputs(
@@ -1350,6 +1357,9 @@ class StableDiffusionXLControlNetUnionPipeline(
control_image = control_images
height, width = control_image[0][0].shape[-2:]
else:
assert False
# 5. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
@@ -1387,7 +1397,7 @@ class StableDiffusionXLControlNetUnionPipeline(
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
for s, e in zip(control_guidance_start, control_guidance_end)
]
controlnet_keep.append(keeps)
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetUnionModel) else keeps)
# 7.2 Prepare added time ids & embeddings
original_size = original_size or (height, width)