mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-05 18:35:23 +08:00
Compare commits
14 Commits
wan-modula
...
modular-te
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c91835c943 | ||
|
|
98b3a31259 | ||
|
|
4c1a5bcfeb | ||
|
|
027394d392 | ||
|
|
5c378a9415 | ||
|
|
f34cc7b344 | ||
|
|
24c4b1c47d | ||
|
|
13c922972e | ||
|
|
f4d27b9a8a | ||
|
|
1a2e736166 | ||
|
|
c293ad7899 | ||
|
|
2c7f5d7421 | ||
|
|
fb6ec06a39 | ||
|
|
ea63cccb8c |
@@ -302,7 +302,7 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt"),
|
||||
InputParam("prompt", required=True),
|
||||
InputParam("prompt_2"),
|
||||
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
|
||||
InputParam("joint_attention_kwargs"),
|
||||
|
||||
@@ -80,7 +80,7 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt"),
|
||||
InputParam("prompt", required=True),
|
||||
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
|
||||
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(10, 20, 30), required=False),
|
||||
]
|
||||
@@ -99,7 +99,7 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
|
||||
@staticmethod
|
||||
def check_inputs(block_state):
|
||||
prompt = block_state.prompt
|
||||
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
@staticmethod
|
||||
@@ -193,7 +193,7 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt"),
|
||||
InputParam("prompt", required=True),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -210,7 +210,7 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
|
||||
@staticmethod
|
||||
def check_inputs(block_state):
|
||||
prompt = block_state.prompt
|
||||
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -270,7 +270,7 @@ class Flux2KleinTextEncoderStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt"),
|
||||
InputParam("prompt", required=True),
|
||||
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
|
||||
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False),
|
||||
]
|
||||
@@ -290,7 +290,7 @@ class Flux2KleinTextEncoderStep(ModularPipelineBlocks):
|
||||
def check_inputs(block_state):
|
||||
prompt = block_state.prompt
|
||||
|
||||
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
@staticmethod
|
||||
@@ -405,7 +405,7 @@ class Flux2KleinBaseTextEncoderStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt"),
|
||||
InputParam("prompt", required=True),
|
||||
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
|
||||
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False),
|
||||
]
|
||||
@@ -431,7 +431,7 @@ class Flux2KleinBaseTextEncoderStep(ModularPipelineBlocks):
|
||||
def check_inputs(block_state):
|
||||
prompt = block_state.prompt
|
||||
|
||||
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -56,52 +56,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# ====================
|
||||
# 1. TEXT ENCODER
|
||||
# ====================
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class QwenImageAutoTextEncoderStep(AutoPipelineBlocks):
|
||||
"""
|
||||
Text encoder step that encodes the text prompt into a text embedding. This is an auto pipeline block.
|
||||
|
||||
Components:
|
||||
text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use tokenizer (`Qwen2Tokenizer`):
|
||||
The tokenizer to use guider (`ClassifierFreeGuidance`)
|
||||
|
||||
Inputs:
|
||||
prompt (`str`, *optional*):
|
||||
The prompt or prompts to guide image generation.
|
||||
negative_prompt (`str`, *optional*):
|
||||
The prompt or prompts not to guide the image generation.
|
||||
max_sequence_length (`int`, *optional*, defaults to 1024):
|
||||
Maximum sequence length for prompt encoding.
|
||||
|
||||
Outputs:
|
||||
prompt_embeds (`Tensor`):
|
||||
The prompt embeddings.
|
||||
prompt_embeds_mask (`Tensor`):
|
||||
The encoder attention mask.
|
||||
negative_prompt_embeds (`Tensor`):
|
||||
The negative prompt embeddings.
|
||||
negative_prompt_embeds_mask (`Tensor`):
|
||||
The negative prompt embeddings mask.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage"
|
||||
block_classes = [QwenImageTextEncoderStep()]
|
||||
block_names = ["text_encoder"]
|
||||
block_trigger_inputs = ["prompt"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Text encoder step that encodes the text prompt into a text embedding. This is an auto pipeline block."
|
||||
" - `QwenImageTextEncoderStep` (text_encoder) is used when `prompt` is provided."
|
||||
" - if `prompt` is not provided, step will be skipped."
|
||||
|
||||
|
||||
# ====================
|
||||
# 2. VAE ENCODER
|
||||
# 1. VAE ENCODER
|
||||
# ====================
|
||||
|
||||
|
||||
@@ -249,7 +204,7 @@ class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks):
|
||||
|
||||
|
||||
# ====================
|
||||
# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
|
||||
# 2. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
|
||||
# ====================
|
||||
|
||||
|
||||
@@ -1011,7 +966,7 @@ class QwenImageAutoCoreDenoiseStep(ConditionalPipelineBlocks):
|
||||
|
||||
|
||||
# ====================
|
||||
# 4. DECODE
|
||||
# 3. DECODE
|
||||
# ====================
|
||||
|
||||
|
||||
@@ -1096,11 +1051,11 @@ class QwenImageAutoDecodeStep(AutoPipelineBlocks):
|
||||
|
||||
|
||||
# ====================
|
||||
# 5. AUTO BLOCKS & PRESETS
|
||||
# 4. AUTO BLOCKS & PRESETS
|
||||
# ====================
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageAutoTextEncoderStep()),
|
||||
("text_encoder", QwenImageTextEncoderStep()),
|
||||
("vae_encoder", QwenImageAutoVaeEncoderStep()),
|
||||
("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()),
|
||||
("denoise", QwenImageAutoCoreDenoiseStep()),
|
||||
|
||||
@@ -244,7 +244,7 @@ class StableDiffusionXLTextEncoderStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt"),
|
||||
InputParam("prompt", required=True),
|
||||
InputParam("prompt_2"),
|
||||
InputParam("negative_prompt"),
|
||||
InputParam("negative_prompt_2"),
|
||||
|
||||
@@ -179,7 +179,7 @@ class WanTextEncoderStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt"),
|
||||
InputParam("prompt", required=True),
|
||||
InputParam("negative_prompt"),
|
||||
InputParam("max_sequence_length", default=512),
|
||||
]
|
||||
|
||||
@@ -149,7 +149,7 @@ class ZImageTextEncoderStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt"),
|
||||
InputParam("prompt", required=True),
|
||||
InputParam("negative_prompt"),
|
||||
InputParam("max_sequence_length", default=512),
|
||||
]
|
||||
|
||||
@@ -37,6 +37,7 @@ class TestFluxModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = FluxModularPipeline
|
||||
pipeline_blocks_class = FluxAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-modular"
|
||||
default_repo_id = "hf-internal-testing/tiny-flux-pipe"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
@@ -63,6 +64,7 @@ class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = FluxModularPipeline
|
||||
pipeline_blocks_class = FluxAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-modular"
|
||||
default_repo_id = "hf-internal-testing/tiny-flux-pipe"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
|
||||
batch_params = frozenset(["prompt", "image"])
|
||||
@@ -129,6 +131,7 @@ class TestFluxKontextModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = FluxKontextModularPipeline
|
||||
pipeline_blocks_class = FluxKontextAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-kontext-pipe"
|
||||
default_repo_id = "hf-internal-testing/tiny-flux-kontext-pipe"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
|
||||
batch_params = frozenset(["prompt", "image"])
|
||||
|
||||
@@ -32,6 +32,8 @@ class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = Flux2ModularPipeline
|
||||
pipeline_blocks_class = Flux2AutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular"
|
||||
default_repo_id = "black-forest-labs/FLUX.2-dev"
|
||||
default_repo_id = "hf-internal-testing/tiny-flux2"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
@@ -60,6 +62,7 @@ class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = Flux2ModularPipeline
|
||||
pipeline_blocks_class = Flux2AutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular"
|
||||
default_repo_id = "hf-internal-testing/tiny-flux2"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
|
||||
batch_params = frozenset(["prompt", "image"])
|
||||
|
||||
@@ -32,6 +32,7 @@ class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = Flux2KleinModularPipeline
|
||||
pipeline_blocks_class = Flux2KleinAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-modular"
|
||||
default_repo_id = None # TODO
|
||||
|
||||
params = frozenset(["prompt", "height", "width"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
@@ -59,6 +60,7 @@ class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = Flux2KleinModularPipeline
|
||||
pipeline_blocks_class = Flux2KleinAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-modular"
|
||||
default_repo_id = None # TODO
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "image"])
|
||||
batch_params = frozenset(["prompt", "image"])
|
||||
|
||||
@@ -32,7 +32,7 @@ class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = Flux2KleinModularPipeline
|
||||
pipeline_blocks_class = Flux2KleinBaseAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-base-modular"
|
||||
|
||||
default_repo_id = "hf-internal-testing/tiny-flux2-klein"
|
||||
params = frozenset(["prompt", "height", "width"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
|
||||
@@ -59,6 +59,7 @@ class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = Flux2KleinModularPipeline
|
||||
pipeline_blocks_class = Flux2KleinBaseAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-base-modular"
|
||||
default_repo_id = "hf-internal-testing/tiny-flux2-klein"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "image"])
|
||||
batch_params = frozenset(["prompt", "image"])
|
||||
|
||||
@@ -34,6 +34,7 @@ class TestQwenImageModularPipelineFast(ModularPipelineTesterMixin, ModularGuider
|
||||
pipeline_class = QwenImageModularPipeline
|
||||
pipeline_blocks_class = QwenImageAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-modular"
|
||||
default_repo_id = "Qwen/Qwen-Image"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
|
||||
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
|
||||
@@ -60,6 +61,7 @@ class TestQwenImageEditModularPipelineFast(ModularPipelineTesterMixin, ModularGu
|
||||
pipeline_class = QwenImageEditModularPipeline
|
||||
pipeline_blocks_class = QwenImageEditAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-edit-modular"
|
||||
default_repo_id = "Qwen/Qwen-Image-Edit"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
|
||||
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
|
||||
@@ -86,6 +88,7 @@ class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, Modul
|
||||
pipeline_class = QwenImageEditPlusModularPipeline
|
||||
pipeline_blocks_class = QwenImageEditPlusAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-edit-plus-modular"
|
||||
default_repo_id = "Qwen/Qwen-Image-Edit-2509"
|
||||
|
||||
# No `mask_image` yet.
|
||||
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image"])
|
||||
|
||||
@@ -279,6 +279,8 @@ class TestSDXLModularPipelineFast(
|
||||
pipeline_class = StableDiffusionXLModularPipeline
|
||||
pipeline_blocks_class = StableDiffusionXLAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
|
||||
default_repo_id = "hf-internal-testing/tiny-sdxl-pipe"
|
||||
|
||||
params = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
@@ -326,6 +328,7 @@ class TestSDXLImg2ImgModularPipelineFast(
|
||||
pipeline_class = StableDiffusionXLModularPipeline
|
||||
pipeline_blocks_class = StableDiffusionXLAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
|
||||
default_repo_id = "hf-internal-testing/tiny-sdxl-pipe"
|
||||
params = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
@@ -379,6 +382,7 @@ class SDXLInpaintingModularPipelineFastTests(
|
||||
pipeline_class = StableDiffusionXLModularPipeline
|
||||
pipeline_blocks_class = StableDiffusionXLAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
|
||||
default_repo_id = "hf-internal-testing/tiny-sdxl-pipe"
|
||||
params = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
|
||||
@@ -37,6 +37,8 @@ class ModularPipelineTesterMixin:
|
||||
optional_params = frozenset(["num_inference_steps", "num_images_per_prompt", "latents", "output_type"])
|
||||
# this is modular specific: generator needs to be a intermediate input because it's mutable
|
||||
intermediate_params = frozenset(["generator"])
|
||||
# prompt is required for most pipeline, with exceptions like qwen-image layer
|
||||
required_params = frozenset(["prompt"])
|
||||
|
||||
def get_generator(self, seed=0):
|
||||
generator = torch.Generator("cpu").manual_seed(seed)
|
||||
@@ -55,6 +57,12 @@ class ModularPipelineTesterMixin:
|
||||
"You need to set the attribute `pretrained_model_name_or_path` in the child test class. See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
@property
|
||||
def default_repo_id(self) -> str:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `default_repo_id` in the child test class. See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
@property
|
||||
def pipeline_blocks_class(self) -> Union[Callable, ModularPipelineBlocks]:
|
||||
raise NotImplementedError(
|
||||
@@ -121,6 +129,7 @@ class ModularPipelineTesterMixin:
|
||||
pipe = self.get_pipeline()
|
||||
input_parameters = pipe.blocks.input_names
|
||||
optional_parameters = pipe.default_call_parameters
|
||||
required_parameters = pipe.blocks.required_inputs
|
||||
|
||||
def _check_for_parameters(parameters, expected_parameters, param_type):
|
||||
remaining_parameters = {param for param in parameters if param not in expected_parameters}
|
||||
@@ -130,6 +139,98 @@ class ModularPipelineTesterMixin:
|
||||
|
||||
_check_for_parameters(self.params, input_parameters, "input")
|
||||
_check_for_parameters(self.optional_params, optional_parameters, "optional")
|
||||
_check_for_parameters(self.required_params, required_parameters, "required")
|
||||
|
||||
def test_loading_from_default_repo(self):
|
||||
if self.default_repo_id is None:
|
||||
return
|
||||
|
||||
try:
|
||||
pipe = ModularPipeline.from_pretrained(self.default_repo_id)
|
||||
assert pipe.blocks.__class__ == self.pipeline_blocks_class
|
||||
except Exception as e:
|
||||
assert False, f"Failed to load pipeline from default repo: {e}"
|
||||
|
||||
def test_modular_inference(self):
|
||||
# run the pipeline to get the base output for comparison
|
||||
pipe = self.get_pipeline()
|
||||
pipe.to(torch_device, torch.float32)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
standard_output = pipe(**inputs, output="images")
|
||||
|
||||
# create text, denoise, decoder (and optional vae encoder) nodes
|
||||
blocks = self.pipeline_blocks_class()
|
||||
|
||||
assert "text_encoder" in blocks.sub_blocks, "`text_encoder` block is not present in the pipeline"
|
||||
assert "denoise" in blocks.sub_blocks, "`denoise` block is not present in the pipeline"
|
||||
assert "decode" in blocks.sub_blocks, "`decode` block is not present in the pipeline"
|
||||
|
||||
# manually set the components in the sub_pipe
|
||||
# a hack to workaround the fact the default pipeline properties are often incorrect for testing cases,
|
||||
# #e.g. vae_scale_factor is ususally not 8 because vae is configured to be smaller for testing
|
||||
def manually_set_all_components(pipe: ModularPipeline, sub_pipe: ModularPipeline):
|
||||
for n, comp in pipe.components.items():
|
||||
setattr(sub_pipe, n, comp)
|
||||
|
||||
# Initialize all nodes
|
||||
text_node = blocks.sub_blocks["text_encoder"].init_pipeline(self.pretrained_model_name_or_path)
|
||||
text_node.load_components(torch_dtype=torch.float32)
|
||||
text_node.to(torch_device)
|
||||
manually_set_all_components(pipe, text_node)
|
||||
|
||||
denoise_node = blocks.sub_blocks["denoise"].init_pipeline(self.pretrained_model_name_or_path)
|
||||
denoise_node.load_components(torch_dtype=torch.float32)
|
||||
denoise_node.to(torch_device)
|
||||
manually_set_all_components(pipe, denoise_node)
|
||||
|
||||
decoder_node = blocks.sub_blocks["decode"].init_pipeline(self.pretrained_model_name_or_path)
|
||||
decoder_node.load_components(torch_dtype=torch.float32)
|
||||
decoder_node.to(torch_device)
|
||||
manually_set_all_components(pipe, decoder_node)
|
||||
|
||||
if "vae_encoder" in blocks.sub_blocks:
|
||||
vae_encoder_node = blocks.sub_blocks["vae_encoder"].init_pipeline(self.pretrained_model_name_or_path)
|
||||
vae_encoder_node.load_components(torch_dtype=torch.float32)
|
||||
vae_encoder_node.to(torch_device)
|
||||
manually_set_all_components(pipe, vae_encoder_node)
|
||||
else:
|
||||
vae_encoder_node = None
|
||||
|
||||
def filter_inputs(available: dict, expected_keys) -> dict:
|
||||
return {k: v for k, v in available.items() if k in expected_keys}
|
||||
|
||||
# prepare inputs for each node
|
||||
inputs = self.get_dummy_inputs()
|
||||
|
||||
# 1. Text encoder: takes from inputs
|
||||
text_inputs = filter_inputs(inputs, text_node.blocks.input_names)
|
||||
text_output = text_node(**text_inputs)
|
||||
text_output_dict = text_output.get_by_kwargs("denoiser_input_fields")
|
||||
|
||||
# 2. VAE encoder (optional): takes from inputs + text_output
|
||||
if vae_encoder_node is not None:
|
||||
vae_available = {**inputs, **text_output_dict}
|
||||
vae_encoder_inputs = filter_inputs(vae_available, vae_encoder_node.blocks.input_names)
|
||||
vae_encoder_output = vae_encoder_node(**vae_encoder_inputs)
|
||||
vae_output_dict = vae_encoder_output.values
|
||||
else:
|
||||
vae_output_dict = {}
|
||||
|
||||
# 3. Denoise: takes from inputs + text_output + vae_output
|
||||
denoise_available = {**inputs, **text_output_dict, **vae_output_dict}
|
||||
denoise_inputs = filter_inputs(denoise_available, denoise_node.blocks.input_names)
|
||||
denoise_output = denoise_node(**denoise_inputs)
|
||||
latents = denoise_output.latents
|
||||
|
||||
# 4. Decoder: takes from inputs + denoise_output
|
||||
decode_available = {**inputs, "latents": latents}
|
||||
decode_inputs = filter_inputs(decode_available, decoder_node.blocks.input_names)
|
||||
modular_output = decoder_node(**decode_inputs).images
|
||||
|
||||
assert modular_output.shape == standard_output.shape, (
|
||||
f"Modular output should have same shape as standard output {standard_output.shape}, but got {modular_output.shape}"
|
||||
)
|
||||
|
||||
def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True):
|
||||
pipe = self.get_pipeline().to(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user