Compare commits

...

6 Commits

Author SHA1 Message Date
Wauplin
7547a8afd3 Mention model_info.id instead of model_info.modelId 2024-07-20 07:26:54 +02:00
Pierre Chapuis
fe7948941d allow tensors in several schedulers step() call (#8905) 2024-07-19 18:58:06 -10:00
王奇勋
461efc57c5 [fix code annotation] Adjust the dimensions of the rotary positional embedding. (#8890)
* 2d rotary pos emb dim

* make style

---------

Co-authored-by: haofanwang <haofanwang.ai@gmail.com>
2024-07-19 18:57:36 -10:00
shinetzh
3b04cdc816 fix loop bug in SlicedAttnProcessor (#8836)
* fix loop bug in SlicedAttnProcessor


---------

Co-authored-by: neoshang <neoshang@tencent.com>
2024-07-19 18:14:29 -10:00
Álvaro Somoza
c009c203be [SDXL] Fix uncaught error with image to image (#8856)
* initial commit

* apply suggestion to sdxl pipelines

* apply fix to sd pipelines
2024-07-19 12:06:36 -10:00
Dhruv Nair
3f1411767b SSH into cpu runner additional fix (#8893)
* update

* update

* update
2024-07-18 16:18:45 +05:30
20 changed files with 96 additions and 28 deletions

View File

@@ -30,10 +30,6 @@ jobs:
         with:
           fetch-depth: 2
-      - name: NVIDIA-SMI
-        run: |
-          nvidia-smi
       - name: Tailscale # In order to be able to SSH when a test fails
         uses: huggingface/tailscale-action@main
         with:

View File

@@ -103,12 +103,12 @@ results["google_ddpm_ema_cat_256"] = torch.tensor([
 models = api.list_models(filter="diffusers")
 for mod in models:
-    if "google" in mod.author or mod.modelId == "CompVis/ldm-celebahq-256":
-        local_checkpoint = "/home/patrick/google_checkpoints/" + mod.modelId.split("/")[-1]
+    if "google" in mod.author or mod.id == "CompVis/ldm-celebahq-256":
+        local_checkpoint = "/home/patrick/google_checkpoints/" + mod.id.split("/")[-1]

-        print(f"Started running {mod.modelId}!!!")
+        print(f"Started running {mod.id}!!!")

-        if mod.modelId.startswith("CompVis"):
+        if mod.id.startswith("CompVis"):
             model = UNet2DModel.from_pretrained(local_checkpoint, subfolder="unet")
         else:
             model = UNet2DModel.from_pretrained(local_checkpoint)
@@ -122,6 +122,6 @@ for mod in models:
         logits = model(noise, time_step).sample
         assert torch.allclose(
-            logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3
+            logits[0, 0, 0, :30], results["_".join("_".join(mod.id.split("/")).split("-"))], atol=1e-3
         )
-        print(f"{mod.modelId} has passed successfully!!!")
+        print(f"{mod.id} has passed successfully!!!")

View File

@@ -2190,7 +2190,7 @@ class SlicedAttnProcessor:
             (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
         )

-        for i in range(batch_size_attention // self.slice_size):
+        for i in range((batch_size_attention - 1) // self.slice_size + 1):
             start_idx = i * self.slice_size
             end_idx = (i + 1) * self.slice_size
@@ -2287,7 +2287,7 @@ class SlicedAttnAddedKVProcessor:
             (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
         )

-        for i in range(batch_size_attention // self.slice_size):
+        for i in range((batch_size_attention - 1) // self.slice_size + 1):
             start_idx = i * self.slice_size
             end_idx = (i + 1) * self.slice_size
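The old loop count floors the division, so any trailing partial slice was silently skipped whenever batch_size_attention was not a multiple of slice_size; the fix switches to ceiling division. A standalone sketch with hypothetical sizes (not the full processor):

    batch_size_attention, slice_size = 10, 4

    old_slices = [(i * slice_size, (i + 1) * slice_size)
                  for i in range(batch_size_attention // slice_size)]
    new_slices = [(i * slice_size, (i + 1) * slice_size)
                  for i in range((batch_size_attention - 1) // slice_size + 1)]

    print(old_slices)  # [(0, 4), (4, 8)] -- rows 8 and 9 were never attended
    print(new_slices)  # [(0, 4), (4, 8), (8, 12)] -- slicing past the end of a tensor is safe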

View File

@@ -319,12 +319,16 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
     assert embed_dim % 4 == 0

     # use half of dimensions to encode grid_h
-    emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real)  # (H*W, D/4)
-    emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real)  # (H*W, D/4)
+    emb_h = get_1d_rotary_pos_embed(
+        embed_dim // 2, grid[0].reshape(-1), use_real=use_real
+    )  # (H*W, D/2) if use_real else (H*W, D/4)
+    emb_w = get_1d_rotary_pos_embed(
+        embed_dim // 2, grid[1].reshape(-1), use_real=use_real
+    )  # (H*W, D/2) if use_real else (H*W, D/4)

     if use_real:
-        cos = torch.cat([emb_h[0], emb_w[0]], dim=1)  # (H*W, D/2)
-        sin = torch.cat([emb_h[1], emb_w[1]], dim=1)  # (H*W, D/2)
+        cos = torch.cat([emb_h[0], emb_w[0]], dim=1)  # (H*W, D)
+        sin = torch.cat([emb_h[1], emb_w[1]], dim=1)  # (H*W, D)
         return cos, sin
     else:
         emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D/2)
@@ -371,6 +375,8 @@ def get_1d_rotary_pos_embed(
     Returns:
         `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
     """
+    assert dim % 2 == 0
+
     if isinstance(pos, int):
         pos = np.arange(pos)
     theta = theta * ntk_factor
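The corrected annotations follow from get_1d_rotary_pos_embed being called with embed_dim // 2: with use_real=True it returns separate cos and sin tensors of shape (H*W, D/2) per axis, so concatenating the height and width halves yields (H*W, D); with use_real=False it returns a complex tensor of shape (H*W, D/4) per axis. A hypothetical shape check, assuming the function accepts a numpy grid as in the call above:

    import numpy as np
    from diffusers.models.embeddings import get_2d_rotary_pos_embed_from_grid

    D, H, W = 64, 8, 8
    grid = np.stack(np.meshgrid(np.arange(W), np.arange(H)), axis=0)  # (2, H, W)

    cos, sin = get_2d_rotary_pos_embed_from_grid(D, grid, use_real=True)
    print(cos.shape, sin.shape)  # expected: (H*W, D) == (64, 64) each

    emb = get_2d_rotary_pos_embed_from_grid(D, grid, use_real=False)
    print(emb.shape)  # expected: complex (H*W, D/2) == (64, 32)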

View File

@@ -824,6 +824,13 @@ class StableDiffusionControlNetImg2ImgPipeline(
             )

         elif isinstance(generator, list):
+            if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
+                image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
+            elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
+                raise ValueError(
+                    f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
+                )
+
             init_latents = [
                 retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                 for i in range(batch_size)
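The same guard is applied verbatim in each of the img2img pipelines below. Without it, a generator list longer than the image batch makes image[i : i + 1] select empty slices for the extra generators, surfacing as a confusing error downstream instead of a clear one. A standalone sketch of the duplication logic with toy shapes (hypothetical helper, not a pipeline method):

    import torch

    def duplicate_image_to_batch(image: torch.Tensor, batch_size: int) -> torch.Tensor:
        if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
            # e.g. one image, batch_size 4 -> repeat it four times
            return torch.cat([image] * (batch_size // image.shape[0]), dim=0)
        elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
            # e.g. three images, batch_size 4 -> ambiguous, so raise early
            raise ValueError(
                f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
            )
        return image

    print(duplicate_image_to_batch(torch.zeros(1, 3, 8, 8), 4).shape)  # torch.Size([4, 3, 8, 8])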

View File

@@ -930,6 +930,13 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
             )

         elif isinstance(generator, list):
+            if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
+                image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
+            elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
+                raise ValueError(
+                    f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
+                )
+
             init_latents = [
                 retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                 for i in range(batch_size)

View File

@@ -528,6 +528,13 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
             )

         elif isinstance(generator, list):
+            if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
+                image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
+            elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
+                raise ValueError(
+                    f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
+                )
+
             init_latents = [
                 retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                 for i in range(batch_size)

View File

@@ -520,6 +520,13 @@ class LatentConsistencyModelImg2ImgPipeline(
             )

         elif isinstance(generator, list):
+            if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
+                image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
+            elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
+                raise ValueError(
+                    f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
+                )
+
             init_latents = [
                 retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                 for i in range(batch_size)

View File

@@ -719,6 +719,13 @@ class StableDiffusionXLPAGImg2ImgPipeline(
             )

         elif isinstance(generator, list):
+            if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
+                image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
+            elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
+                raise ValueError(
+                    f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
+                )
+
             init_latents = [
                 retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                 for i in range(batch_size)

View File

@@ -494,6 +494,13 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
             )

         elif isinstance(generator, list):
+            if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
+                image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
+            elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
+                raise ValueError(
+                    f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
+                )
+
             init_latents = [
                 retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                 for i in range(batch_size)

View File

@@ -740,6 +740,13 @@ class StableDiffusionImg2ImgPipeline(
             )

         elif isinstance(generator, list):
+            if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
+                image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
+            elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
+                raise ValueError(
+                    f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
+                )
+
             init_latents = [
                 retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                 for i in range(batch_size)

View File

@@ -710,6 +710,13 @@ class StableDiffusionXLImg2ImgPipeline(
             )

         elif isinstance(generator, list):
+            if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
+                image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
+            elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
+                raise ValueError(
+                    f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
+                )
+
             init_latents = [
                 retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                 for i in range(batch_size)

View File

@@ -674,7 +674,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
     def step(
         self,
         model_output: torch.Tensor,
-        timestep: int,
+        timestep: Union[int, torch.Tensor],
         sample: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
@@ -685,7 +685,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         Args:
             model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
-            timestep (`float`):
+            timestep (`int`):
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
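The widening of step() to accept tensor timesteps is applied uniformly to the schedulers below; pipelines iterate over scheduler.timesteps, whose elements are 0-d tensors rather than plain ints. A minimal sketch, assuming DPMSolverMultistepScheduler's default configuration:

    import torch
    from diffusers import DPMSolverMultistepScheduler

    scheduler = DPMSolverMultistepScheduler()
    scheduler.set_timesteps(num_inference_steps=10)

    sample = torch.randn(1, 4, 64, 64)         # stand-in latents
    model_output = torch.randn(1, 4, 64, 64)   # stand-in UNet prediction

    t = scheduler.timesteps[0]  # a 0-d torch.Tensor, now covered by the annotation
    prev_sample = scheduler.step(model_output, t, sample).prev_sample
    print(prev_sample.shape)  # torch.Size([1, 4, 64, 64])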

View File

@@ -920,7 +920,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
     def step(
         self,
         model_output: torch.Tensor,
-        timestep: int,
+        timestep: Union[int, torch.Tensor],
         sample: torch.Tensor,
         generator=None,
         variance_noise: Optional[torch.Tensor] = None,

View File

@@ -787,7 +787,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
     def step(
         self,
         model_output: torch.Tensor,
-        timestep: int,
+        timestep: Union[int, torch.Tensor],
         sample: torch.Tensor,
         generator=None,
         variance_noise: Optional[torch.Tensor] = None,

View File

@@ -927,7 +927,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
     def step(
         self,
         model_output: torch.Tensor,
-        timestep: int,
+        timestep: Union[int, torch.Tensor],
         sample: torch.Tensor,
         generator=None,
         return_dict: bool = True,

View File

@@ -594,7 +594,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
     def step(
         self,
         model_output: torch.Tensor,
-        timestep: int,
+        timestep: Union[int, torch.Tensor],
         sample: torch.Tensor,
         generator=None,
         return_dict: bool = True,

View File

@@ -138,7 +138,7 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
     def step(
         self,
         model_output: torch.Tensor,
-        timestep: int,
+        timestep: Union[int, torch.Tensor],
         sample: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:

View File

@@ -822,7 +822,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
     def step(
         self,
         model_output: torch.Tensor,
-        timestep: int,
+        timestep: Union[int, torch.Tensor],
         sample: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:

View File

@@ -1351,14 +1351,24 @@ class PipelineTesterMixin:
         pipe.enable_attention_slicing(slice_size=1)
         inputs = self.get_dummy_inputs(generator_device)
-        output_with_slicing = pipe(**inputs)[0]
+        output_with_slicing1 = pipe(**inputs)[0]
+
+        pipe.enable_attention_slicing(slice_size=2)
+        inputs = self.get_dummy_inputs(generator_device)
+        output_with_slicing2 = pipe(**inputs)[0]

         if test_max_difference:
-            max_diff = np.abs(to_np(output_with_slicing) - to_np(output_without_slicing)).max()
-            self.assertLess(max_diff, expected_max_diff, "Attention slicing should not affect the inference results")
+            max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+            max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+            self.assertLess(
+                max(max_diff1, max_diff2),
+                expected_max_diff,
+                "Attention slicing should not affect the inference results",
+            )

         if test_mean_pixel_difference:
-            assert_mean_pixel_difference(to_np(output_with_slicing[0]), to_np(output_without_slicing[0]))
+            assert_mean_pixel_difference(to_np(output_with_slicing1[0]), to_np(output_without_slicing[0]))
+            assert_mean_pixel_difference(to_np(output_with_slicing2[0]), to_np(output_without_slicing[0]))

     @unittest.skipIf(
         torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),