Compare commits

...

12 Commits

Author SHA1 Message Date
Sayak Paul
d287428460 Update tests/pipelines/qwenimage/test_qwenimage_edit_plus.py
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2025-12-02 19:08:05 +05:30
Sayak Paul
93a470be0a Apply suggestions from code review
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2025-12-02 19:06:48 +05:30
Sayak Paul
c541f9a49f Merge branch 'main' into prompt-isolation-tests-qwen 2025-12-02 18:56:19 +05:30
Sayak Paul
ba120269dc Merge branch 'main' into prompt-isolation-tests-qwen 2025-11-26 15:24:43 +05:30
Sayak Paul
d29bfb78eb Merge branch 'main' into prompt-isolation-tests-qwen 2025-09-28 16:48:36 +05:30
sayakpaul
290b28354d up 2025-09-28 16:33:56 +05:30
sayakpaul
16544bbec3 revert pipeline changes. 2025-09-28 16:23:23 +05:30
sayakpaul
b26f7fc82f revert lora utils tests. 2025-09-28 16:22:34 +05:30
sayakpaul
f84b0ab796 up 2025-09-28 16:19:17 +05:30
sayakpaul
1185f82450 up 2025-09-28 16:18:35 +05:30
sayakpaul
a9d50c8f2a up 2025-09-26 22:42:52 +05:30
sayakpaul
f82c1523e5 fix prompt isolation test. 2025-09-26 18:50:26 +05:30
3 changed files with 54 additions and 38 deletions

View File

@@ -15,7 +15,6 @@
 import unittest
 
 import numpy as np
-import pytest
 import torch
 from PIL import Image
 from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
@@ -134,15 +133,17 @@ class QwenImageEditPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         else:
             generator = torch.Generator(device=device).manual_seed(seed)
 
+        # Even if we specify smaller dimensions for the images, it won't work because of how
+        # the internal implementation enforces a minimal resolution of 1024x1024.
         inputs = {
             "prompt": "dance monkey",
-            "image": Image.new("RGB", (32, 32)),
+            "image": Image.new("RGB", (1024, 1024)),
             "negative_prompt": "bad quality",
             "generator": generator,
             "num_inference_steps": 2,
             "true_cfg_scale": 1.0,
-            "height": 32,
-            "width": 32,
+            "height": 1024,
+            "width": 1024,
             "max_sequence_length": 16,
             "output_type": "pt",
         }
@@ -238,6 +239,11 @@ class QwenImageEditPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             "VAE tiling should not affect the inference results",
         )
 
-    @pytest.mark.xfail(condition=True, reason="Preconfigured embeddings need to be revisited.", strict=True)
-    def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
-        super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, atol, rtol)
+    def test_encode_prompt_works_in_isolation(
+        self, extra_required_param_value_dict=None, keep_params=None, atol=1e-4, rtol=1e-4
+    ):
+        # We include `image` because it's needed in both `encode_prompt` and some other subsequent calculations.
+        # `max_sequence_length` to maintain parity between its value during all invocations of `encode_prompt`
+        # in the following test.
+        keep_params = ["image", "max_sequence_length"]
+        super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, keep_params, atol, rtol)
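For context, pinning `keep_params = ["image", "max_sequence_length"]` means those two inputs are still handed to `encode_prompt()` but are no longer popped out of `inputs`, so the later full-pipeline call sees the same values. A minimal sketch of that behavior (the helper name and toy values below are invented for illustration and are not code from this PR):

```python
# Illustrative sketch only: `filter_encode_prompt_inputs` is a made-up helper that mimics
# the keep_params-aware popping done by the shared test. Every key consumed by
# `encode_prompt()` is removed from `inputs`, except the names listed in `keep_params`.
def filter_encode_prompt_inputs(inputs, encode_prompt_param_names, keep_params=None):
    encode_prompt_inputs = {k: inputs[k] for k in encode_prompt_param_names if k in inputs}
    for name in encode_prompt_param_names:
        if name in inputs and (not keep_params or name not in keep_params):
            inputs.pop(name)
    return encode_prompt_inputs


inputs = {"prompt": "dance monkey", "image": "<PIL image>", "max_sequence_length": 16, "height": 1024}
encode_prompt_args = ["prompt", "image", "max_sequence_length"]
embed_inputs = filter_encode_prompt_inputs(inputs, encode_prompt_args, keep_params=["image", "max_sequence_length"])

print(sorted(embed_inputs))  # ['image', 'max_sequence_length', 'prompt']
print(sorted(inputs))        # ['height', 'image', 'max_sequence_length'] -> still available for the pipeline call
```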

View File

@@ -134,7 +134,9 @@ class QwenImageEditPlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
         else:
             generator = torch.Generator(device=device).manual_seed(seed)
 
-        image = Image.new("RGB", (32, 32))
+        # Even if we specify smaller dimensions for the images, it won't work because of how
+        # the internal implementation enforces a minimal resolution of 384*384.
+        image = Image.new("RGB", (384, 384))
         inputs = {
             "prompt": "dance monkey",
             "image": [image, image],
@@ -142,8 +144,8 @@ class QwenImageEditPlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
             "generator": generator,
             "num_inference_steps": 2,
             "true_cfg_scale": 1.0,
-            "height": 32,
-            "width": 32,
+            "height": 384,
+            "width": 384,
             "max_sequence_length": 16,
             "output_type": "pt",
         }
@@ -236,9 +238,14 @@ class QwenImageEditPlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
             "VAE tiling should not affect the inference results",
         )
 
-    @pytest.mark.xfail(condition=True, reason="Preconfigured embeddings need to be revisited.", strict=True)
-    def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
-        super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, atol, rtol)
+    def test_encode_prompt_works_in_isolation(
+        self, extra_required_param_value_dict=None, keep_params=None, atol=1e-4, rtol=1e-4
+    ):
+        # We include `image` because it's needed in both `encode_prompt` and some other subsequent calculations.
+        # `max_sequence_length` to maintain parity between its value during all invocations of `encode_prompt`
+        # in the following test.
+        keep_params = ["image", "max_sequence_length"]
+        super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, keep_params, atol, rtol)
 
     @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
     def test_num_images_per_prompt():

View File

@@ -5,7 +5,7 @@ import os
 import tempfile
 import unittest
 import uuid
-from typing import Any, Callable, Dict, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
 import PIL.Image
@@ -2082,20 +2082,26 @@ class PipelineTesterMixin:
             assert f"You are trying to load the model files of the `variant={variant}`" in str(error.exception)
 
-    def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
+    def test_encode_prompt_works_in_isolation(
+        self,
+        extra_required_param_value_dict: Optional[Dict] = None,
+        keep_params: Optional[List] = None,
+        atol=1e-4,
+        rtol=1e-4,
+    ):
         if not hasattr(self.pipeline_class, "encode_prompt"):
             return
 
         components = self.get_dummy_components()
 
+        def _contains_text_key(name):
+            return any(token in name for token in ("text", "tokenizer", "processor"))
+
         # We initialize the pipeline with only text encoders and tokenizers,
         # mimicking a real-world scenario.
-        components_with_text_encoders = {}
-        for k in components:
-            if "text" in k or "tokenizer" in k:
-                components_with_text_encoders[k] = components[k]
-            else:
-                components_with_text_encoders[k] = None
+        components_with_text_encoders = {
+            name: component if _contains_text_key(name) else None for name, component in components.items()
+        }
         pipe_with_just_text_encoder = self.pipeline_class(**components_with_text_encoders)
         pipe_with_just_text_encoder = pipe_with_just_text_encoder.to(torch_device)
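As a quick, self-contained sketch of the new component-splitting comprehension (the component names and values below are hypothetical stand-ins for a pipeline's dummy components, not taken from the PR):

```python
# Illustrative sketch: build a "text encoders only" variant of the components dict,
# mirroring the comprehension introduced above. Names and values are hypothetical.
def _contains_text_key(name):
    return any(token in name for token in ("text", "tokenizer", "processor"))


components = {"text_encoder": "te", "tokenizer": "tok", "processor": "proc", "transformer": "dit", "vae": "vae"}

# Keep only text-related components; everything else becomes None.
components_with_text_encoders = {
    name: component if _contains_text_key(name) else None for name, component in components.items()
}
print(components_with_text_encoders)
# {'text_encoder': 'te', 'tokenizer': 'tok', 'processor': 'proc', 'transformer': None, 'vae': None}
```

Note that the helper also matches "processor", so processor components are kept alongside text encoders and tokenizers, which the previous `"text" in k or "tokenizer" in k` check did not do.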
@@ -2105,17 +2111,18 @@ class PipelineTesterMixin:
         encode_prompt_parameters = list(encode_prompt_signature.parameters.values())
 
         # Required args in encode_prompt with those with no default.
-        required_params = []
-        for param in encode_prompt_parameters:
-            if param.name == "self" or param.name == "kwargs":
-                continue
-            if param.default is inspect.Parameter.empty:
-                required_params.append(param.name)
+        required_params = [
+            param.name
+            for param in encode_prompt_parameters
+            if param.name not in {"self", "kwargs"} and param.default is inspect.Parameter.empty
+        ]
 
         # Craft inputs for the `encode_prompt()` method to run in isolation.
         encode_prompt_param_names = [p.name for p in encode_prompt_parameters if p.name != "self"]
-        input_keys = list(inputs.keys())
-        encode_prompt_inputs = {k: inputs.pop(k) for k in input_keys if k in encode_prompt_param_names}
+        encode_prompt_inputs = {name: inputs[name] for name in encode_prompt_param_names if name in inputs}
+        for name in encode_prompt_param_names:
+            if name in inputs and (not keep_params or name not in keep_params):
+                inputs.pop(name)
 
         pipe_call_signature = inspect.signature(pipe_with_just_text_encoder.__call__)
         pipe_call_parameters = pipe_call_signature.parameters
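The rewritten `required_params` block relies on `inspect.Parameter.empty` to pick out arguments that have no default. A standalone sketch of the same idea, using a made-up `encode_prompt` signature rather than a real pipeline's:

```python
import inspect


# Hypothetical signature standing in for a pipeline's `encode_prompt`.
def encode_prompt(self, prompt, image, negative_prompt=None, max_sequence_length=16, **kwargs):
    ...


encode_prompt_parameters = list(inspect.signature(encode_prompt).parameters.values())

# Required args are those without a default, excluding `self` and `**kwargs`.
required_params = [
    param.name
    for param in encode_prompt_parameters
    if param.name not in {"self", "kwargs"} and param.default is inspect.Parameter.empty
]
print(required_params)  # ['prompt', 'image']
```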
@@ -2150,18 +2157,15 @@ class PipelineTesterMixin:
         # Pack the outputs of `encode_prompt`.
         adapted_prompt_embeds_kwargs = {
-            k: prompt_embeds_kwargs.pop(k) for k in list(prompt_embeds_kwargs.keys()) if k in pipe_call_parameters
+            name: prompt_embeds_kwargs[name] for name in prompt_embeds_kwargs if name in pipe_call_parameters
         }
 
         # now initialize a pipeline without text encoders and compute outputs with the
         # `encode_prompt()` outputs and other relevant inputs.
-        components_with_text_encoders = {}
-        for k in components:
-            if "text" in k or "tokenizer" in k:
-                components_with_text_encoders[k] = None
-            else:
-                components_with_text_encoders[k] = components[k]
-        pipe_without_text_encoders = self.pipeline_class(**components_with_text_encoders).to(torch_device)
+        components_without_text_encoders = {
+            name: None if _contains_text_key(name) else component for name, component in components.items()
+        }
+        pipe_without_text_encoders = self.pipeline_class(**components_without_text_encoders).to(torch_device)
 
         # Set `negative_prompt` to None as we have already calculated its embeds
         # if it was present in `inputs`. This is because otherwise we will interfere wrongly
@@ -2181,7 +2185,6 @@
             and pipe_call_parameters.get("prompt_embeds").default is None
         ):
             pipe_without_tes_inputs.update({"prompt": None})
-
         pipe_out = pipe_without_text_encoders(**pipe_without_tes_inputs)[0]
 
         # Compare against regular pipeline outputs.