Compare commits

...

2 Commits

Author SHA1 Message Date
yiyixuxu
fb6ec06a39 style etc 2026-01-22 03:14:15 +01:00
yiyixuxu
ea63cccb8c add modular test and loading from standard repo 2026-01-22 03:13:32 +01:00
5 changed files with 174 additions and 0 deletions

View File

@@ -37,9 +37,14 @@ class TestFluxModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = FluxModularPipeline
pipeline_blocks_class = FluxAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-modular"
default_repo_id = "black-forest-labs/FLUX.1-dev"
params = frozenset(["prompt", "height", "width", "guidance_scale"])
batch_params = frozenset(["prompt"])
# should choose from the dict returned by `get_dummy_inputs`
text_encoder_block_params = frozenset(["prompt", "max_sequence_length"])
decode_block_params = frozenset(["output_type"])
vae_encoder_block_params = None # None if vae_encoder is not supported
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
@@ -63,10 +68,21 @@ class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = FluxModularPipeline
pipeline_blocks_class = FluxAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-modular"
default_repo_id = "black-forest-labs/FLUX.1-dev"
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
batch_params = frozenset(["prompt", "image"])
# should choose from the dict returned by `get_dummy_inputs`
text_encoder_block_params = frozenset(
[
"prompt",
"max_sequence_length",
]
)
decode_block_params = frozenset(["output_type"])
vae_encoder_block_params = frozenset(["image", "height", "width"])
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
pipeline = super().get_pipeline(components_manager, torch_dtype)
@@ -129,9 +145,13 @@ class TestFluxKontextModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = FluxKontextModularPipeline
pipeline_blocks_class = FluxKontextAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-kontext-pipe"
default_repo_id = "black-forest-labs/FLUX.1-kontext-dev"
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
batch_params = frozenset(["prompt", "image"])
text_encoder_block_params = frozenset(["prompt", "max_sequence_length"])
decode_block_params = frozenset(["latents"])
vae_encoder_block_params = frozenset(["image", "height", "width"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)

View File

@@ -32,9 +32,14 @@ class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2ModularPipeline
pipeline_blocks_class = Flux2AutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular"
default_repo_id = "black-forest-labs/FLUX.2-dev"
params = frozenset(["prompt", "height", "width", "guidance_scale"])
batch_params = frozenset(["prompt"])
# should choose from the dict returned by `get_dummy_inputs`
text_encoder_block_params = frozenset(["prompt", "max_sequence_length", "text_encoder_out_layers"])
decode_block_params = frozenset(["output_type"])
vae_encoder_block_params = None
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
@@ -63,6 +68,10 @@ class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
batch_params = frozenset(["prompt", "image"])
# should choose from the dict returned by `get_dummy_inputs`
text_encoder_block_params = frozenset(["prompt", "max_sequence_length", "text_encoder_out_layers"])
decode_block_params = frozenset(["output_type"])
vae_encoder_block_params = frozenset(["image", "height", "width"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)

View File

@@ -34,10 +34,16 @@ class TestQwenImageModularPipelineFast(ModularPipelineTesterMixin, ModularGuider
pipeline_class = QwenImageModularPipeline
pipeline_blocks_class = QwenImageAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-modular"
default_repo_id = "Qwen/Qwen-Image"
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
# should choose from the dict returned by `get_dummy_inputs`
text_encoder_block_params = frozenset(["prompt", "negative_prompt", "max_sequence_length"])
decode_block_params = frozenset(["output_type"])
vae_encoder_block_params = None # None if vae_encoder is not supported
def get_dummy_inputs(self):
generator = self.get_generator()
inputs = {
@@ -60,10 +66,16 @@ class TestQwenImageEditModularPipelineFast(ModularPipelineTesterMixin, ModularGu
pipeline_class = QwenImageEditModularPipeline
pipeline_blocks_class = QwenImageEditAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-edit-modular"
default_repo_id = "Qwen/Qwen-Image-Edit"
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
# should choose from the dict returned by `get_dummy_inputs`
text_encoder_block_params = frozenset(["prompt", "negative_prompt", "max_sequence_length"])
decode_block_params = frozenset(["output_type"])
vae_encoder_block_params = frozenset(["image", "height", "width"])
def get_dummy_inputs(self):
generator = self.get_generator()
inputs = {
@@ -86,6 +98,7 @@ class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, Modul
pipeline_class = QwenImageEditPlusModularPipeline
pipeline_blocks_class = QwenImageEditPlusAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-edit-plus-modular"
default_repo_id = "Qwen/Qwen-Image-Edit-2509"
# No `mask_image` yet.
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image"])

View File

@@ -279,6 +279,8 @@ class TestSDXLModularPipelineFast(
pipeline_class = StableDiffusionXLModularPipeline
pipeline_blocks_class = StableDiffusionXLAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
default_repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
params = frozenset(
[
"prompt",
@@ -291,6 +293,11 @@ class TestSDXLModularPipelineFast(
batch_params = frozenset(["prompt", "negative_prompt"])
expected_image_output_shape = (1, 3, 64, 64)
# should choose from the dict returned by `get_dummy_inputs`
text_encoder_block_params = frozenset(["prompt"])
decode_block_params = frozenset(["output_type"])
vae_encoder_block_params = None # None if vae_encoder is not supported
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
@@ -339,6 +346,11 @@ class TestSDXLImg2ImgModularPipelineFast(
batch_params = frozenset(["prompt", "negative_prompt", "image"])
expected_image_output_shape = (1, 3, 64, 64)
# should choose from the dict returned by `get_dummy_inputs`
text_encoder_block_params = frozenset(["prompt"])
decode_block_params = frozenset(["output_type"])
vae_encoder_block_params = frozenset(["image"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {

View File

@@ -48,6 +48,12 @@ class ModularPipelineTesterMixin:
"You need to set the attribute `pretrained_model_name_or_path` in the child test class. See existing pipeline tests for reference."
)
@property
def default_repo_id(self) -> str:
raise NotImplementedError(
"You need to set the attribute `default_repo_id` in the child test class. See existing pipeline tests for reference."
)
@property
def pipeline_blocks_class(self) -> Union[Callable, ModularPipelineBlocks]:
raise NotImplementedError(
@@ -90,6 +96,30 @@ class ModularPipelineTesterMixin:
"See existing pipeline tests for reference."
)
def text_encoder_block_params(self) -> frozenset:
raise NotImplementedError(
"You need to set the attribute `text_encoder_block_params` in the child test class. "
"`text_encoder_block_params` are the parameters required to be passed to the text encoder block. "
" if should be a subset of the parameters returned by `get_dummy_inputs`"
"See existing pipeline tests for reference."
)
def decode_block_params(self) -> frozenset:
raise NotImplementedError(
"You need to set the attribute `decode_block_params` in the child test class. "
"`decode_block_params` are the parameters required to be passed to the decode block. "
" if should be a subset of the parameters returned by `get_dummy_inputs`"
"See existing pipeline tests for reference."
)
def vae_encoder_block_params(self) -> frozenset:
raise NotImplementedError(
"You need to set the attribute `vae_encoder_block_params` in the child test class. "
"`vae_encoder_block_params` are the parameters required to be passed to the vae encoder block. "
" if should be a subset of the parameters returned by `get_dummy_inputs`"
"See existing pipeline tests for reference."
)
def setup_method(self):
# clean up the VRAM before each test
torch.compiler.reset()
@@ -124,6 +154,96 @@ class ModularPipelineTesterMixin:
_check_for_parameters(self.params, input_parameters, "input")
_check_for_parameters(self.optional_params, optional_parameters, "optional")
def test_loading_from_default_repo(self):
if self.default_repo_id is None:
return
try:
pipe = ModularPipeline.from_pretrained(self.default_repo_id)
assert pipe.blocks.__class__ == self.pipeline_blocks_class
except Exception as e:
assert False, f"Failed to load pipeline from default repo: {e}"
def test_modular_inference(self):
# run the pipeline to get the base output for comparison
pipe = self.get_pipeline()
pipe.to(torch_device, torch.float32)
inputs = self.get_dummy_inputs()
standard_output = pipe(**inputs, output="images")
# create text, denoise, decoder (and optional vae encoder) nodes
blocks = self.pipeline_blocks_class()
assert "text_encoder" in blocks.sub_blocks, "`text_encoder` block is not present in the pipeline"
assert "denoise" in blocks.sub_blocks, "`denoise` block is not present in the pipeline"
assert "decode" in blocks.sub_blocks, "`decode` block is not present in the pipeline"
if self.vae_encoder_block_params is not None:
assert "vae_encoder" in blocks.sub_blocks, "`vae_encoder` block is not present in the pipeline"
# manually set the components in the sub_pipe
# a hack to workaround the fact the default pipeline properties are often incorrect for testing cases,
# #e.g. vae_scale_factor is ususally not 8 because vae is configured to be smaller for testing
def manually_set_all_components(pipe: ModularPipeline, sub_pipe: ModularPipeline):
for n, comp in pipe.components.items():
if not hasattr(sub_pipe, n):
setattr(sub_pipe, n, comp)
text_node = blocks.sub_blocks["text_encoder"].init_pipeline(self.pretrained_model_name_or_path)
text_node.load_components(torch_dtype=torch.float32)
text_node.to(torch_device)
manually_set_all_components(pipe, text_node)
denoise_node = blocks.sub_blocks["denoise"].init_pipeline(self.pretrained_model_name_or_path)
denoise_node.load_components(torch_dtype=torch.float32)
denoise_node.to(torch_device)
manually_set_all_components(pipe, denoise_node)
decoder_node = blocks.sub_blocks["decode"].init_pipeline(self.pretrained_model_name_or_path)
decoder_node.load_components(torch_dtype=torch.float32)
decoder_node.to(torch_device)
manually_set_all_components(pipe, decoder_node)
if self.vae_encoder_block_params is not None:
vae_encoder_node = blocks.sub_blocks["vae_encoder"].init_pipeline(self.pretrained_model_name_or_path)
vae_encoder_node.load_components(torch_dtype=torch.float32)
vae_encoder_node.to(torch_device)
manually_set_all_components(pipe, vae_encoder_node)
else:
vae_encoder_node = None
# prepare inputs for each node
inputs = self.get_dummy_inputs()
def get_block_inputs(inputs: dict, block_params: frozenset) -> tuple[dict, dict]:
block_inputs = {}
for name in block_params:
if name in inputs:
block_inputs[name] = inputs.pop(name)
return block_inputs, inputs
text_inputs, inputs = get_block_inputs(inputs, self.text_encoder_block_params)
decoder_inputs, inputs = get_block_inputs(inputs, self.decode_block_params)
if vae_encoder_node is not None:
vae_encoder_inputs, inputs = get_block_inputs(inputs, self.vae_encoder_block_params)
# this is also to make sure pipelines mark text outputs as denoiser_input_fields
text_output = text_node(**text_inputs).get_by_kwargs("denoiser_input_fields")
if vae_encoder_node is not None:
vae_encoder_output = vae_encoder_node(**vae_encoder_inputs).values
denoise_inputs = {**text_output, **vae_encoder_output, **inputs}
else:
denoise_inputs = {**text_output, **inputs}
# denoise node output should be "latents"
latents = denoise_node(**denoise_inputs).latents
# denoder node input should be "latents" and output should be "images"
modular_output = decoder_node(**decoder_inputs, latents=latents).images
assert modular_output.shape == standard_output.shape, (
f"Modular output should have same shape as standard output {standard_output.shape}, but got {modular_output.shape}"
)
def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True):
pipe = self.get_pipeline().to(torch_device)