mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-22 19:45:47 +08:00
Compare commits
2 Commits
modular-te
...
modular-up
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8d42a97a40 | ||
|
|
39a6a0c171 |
@@ -1552,11 +1552,11 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
else:
|
||||
logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")
|
||||
|
||||
self.blocks = blocks
|
||||
self._blocks = blocks
|
||||
self._components_manager = components_manager
|
||||
self._collection = collection
|
||||
self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components}
|
||||
self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs}
|
||||
self._component_specs = {spec.name: deepcopy(spec) for spec in self._blocks.expected_components}
|
||||
self._config_specs = {spec.name: deepcopy(spec) for spec in self._blocks.expected_configs}
|
||||
|
||||
# update component_specs and config_specs based on modular_model_index.json
|
||||
if modular_config_dict is not None:
|
||||
@@ -1603,7 +1603,9 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
for name, config_spec in self._config_specs.items():
|
||||
default_configs[name] = config_spec.default
|
||||
self.register_to_config(**default_configs)
|
||||
self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None)
|
||||
self.register_to_config(
|
||||
_blocks_class_name=self._blocks.__class__.__name__ if self._blocks is not None else None
|
||||
)
|
||||
|
||||
@property
|
||||
def default_call_parameters(self) -> Dict[str, Any]:
|
||||
@@ -1612,7 +1614,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
- Dictionary mapping input names to their default values
|
||||
"""
|
||||
params = {}
|
||||
for input_param in self.blocks.inputs:
|
||||
for input_param in self._blocks.inputs:
|
||||
params[input_param.name] = input_param.default
|
||||
return params
|
||||
|
||||
@@ -1775,7 +1777,15 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
Returns:
|
||||
- The docstring of the pipeline blocks
|
||||
"""
|
||||
return self.blocks.doc
|
||||
return self._blocks.doc
|
||||
|
||||
@property
|
||||
def blocks(self) -> ModularPipelineBlocks:
|
||||
"""
|
||||
Returns:
|
||||
- A copy of the pipeline blocks
|
||||
"""
|
||||
return deepcopy(self._blocks)
|
||||
|
||||
def register_components(self, **kwargs):
|
||||
"""
|
||||
@@ -2509,7 +2519,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
)
|
||||
|
||||
def set_progress_bar_config(self, **kwargs):
|
||||
for sub_block_name, sub_block in self.blocks.sub_blocks.items():
|
||||
for sub_block_name, sub_block in self._blocks.sub_blocks.items():
|
||||
if hasattr(sub_block, "set_progress_bar_config"):
|
||||
sub_block.set_progress_bar_config(**kwargs)
|
||||
|
||||
@@ -2563,7 +2573,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
# Add inputs to state, using defaults if not provided in the kwargs or the state
|
||||
# if same input already in the state, will override it if provided in the kwargs
|
||||
for expected_input_param in self.blocks.inputs:
|
||||
for expected_input_param in self._blocks.inputs:
|
||||
name = expected_input_param.name
|
||||
default = expected_input_param.default
|
||||
kwargs_type = expected_input_param.kwargs_type
|
||||
@@ -2582,9 +2592,9 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
# Run the pipeline
|
||||
with torch.no_grad():
|
||||
try:
|
||||
_, state = self.blocks(self, state)
|
||||
_, state = self._blocks(self, state)
|
||||
except Exception:
|
||||
error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n"
|
||||
error_msg = f"Error in block: ({self._blocks.__class__.__name__}):\n"
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
|
||||
|
||||
@@ -37,14 +37,9 @@ class TestFluxModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = FluxModularPipeline
|
||||
pipeline_blocks_class = FluxAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-modular"
|
||||
default_repo_id = "black-forest-labs/FLUX.1-dev"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
# should choose from the dict returned by `get_dummy_inputs`
|
||||
text_encoder_block_params = frozenset(["prompt", "max_sequence_length"])
|
||||
decode_block_params = frozenset(["output_type"])
|
||||
vae_encoder_block_params = None # None if vae_encoder is not supported
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
@@ -68,21 +63,10 @@ class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = FluxModularPipeline
|
||||
pipeline_blocks_class = FluxAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-modular"
|
||||
default_repo_id = "black-forest-labs/FLUX.1-dev"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
|
||||
batch_params = frozenset(["prompt", "image"])
|
||||
|
||||
# should choose from the dict returned by `get_dummy_inputs`
|
||||
text_encoder_block_params = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
"max_sequence_length",
|
||||
]
|
||||
)
|
||||
decode_block_params = frozenset(["output_type"])
|
||||
vae_encoder_block_params = frozenset(["image", "height", "width"])
|
||||
|
||||
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
|
||||
pipeline = super().get_pipeline(components_manager, torch_dtype)
|
||||
|
||||
@@ -145,13 +129,9 @@ class TestFluxKontextModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = FluxKontextModularPipeline
|
||||
pipeline_blocks_class = FluxKontextAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-kontext-pipe"
|
||||
default_repo_id = "black-forest-labs/FLUX.1-kontext-dev"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
|
||||
batch_params = frozenset(["prompt", "image"])
|
||||
text_encoder_block_params = frozenset(["prompt", "max_sequence_length"])
|
||||
decode_block_params = frozenset(["latents"])
|
||||
vae_encoder_block_params = frozenset(["image", "height", "width"])
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
|
||||
@@ -32,14 +32,9 @@ class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = Flux2ModularPipeline
|
||||
pipeline_blocks_class = Flux2AutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular"
|
||||
default_repo_id = "black-forest-labs/FLUX.2-dev"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
# should choose from the dict returned by `get_dummy_inputs`
|
||||
text_encoder_block_params = frozenset(["prompt", "max_sequence_length", "text_encoder_out_layers"])
|
||||
decode_block_params = frozenset(["output_type"])
|
||||
vae_encoder_block_params = None
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
@@ -68,10 +63,6 @@ class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
|
||||
batch_params = frozenset(["prompt", "image"])
|
||||
# should choose from the dict returned by `get_dummy_inputs`
|
||||
text_encoder_block_params = frozenset(["prompt", "max_sequence_length", "text_encoder_out_layers"])
|
||||
decode_block_params = frozenset(["output_type"])
|
||||
vae_encoder_block_params = frozenset(["image", "height", "width"])
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
|
||||
@@ -34,16 +34,10 @@ class TestQwenImageModularPipelineFast(ModularPipelineTesterMixin, ModularGuider
|
||||
pipeline_class = QwenImageModularPipeline
|
||||
pipeline_blocks_class = QwenImageAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-modular"
|
||||
default_repo_id = "Qwen/Qwen-Image"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
|
||||
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
|
||||
|
||||
# should choose from the dict returned by `get_dummy_inputs`
|
||||
text_encoder_block_params = frozenset(["prompt", "negative_prompt", "max_sequence_length"])
|
||||
decode_block_params = frozenset(["output_type"])
|
||||
vae_encoder_block_params = None # None if vae_encoder is not supported
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
generator = self.get_generator()
|
||||
inputs = {
|
||||
@@ -66,16 +60,10 @@ class TestQwenImageEditModularPipelineFast(ModularPipelineTesterMixin, ModularGu
|
||||
pipeline_class = QwenImageEditModularPipeline
|
||||
pipeline_blocks_class = QwenImageEditAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-edit-modular"
|
||||
default_repo_id = "Qwen/Qwen-Image-Edit"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
|
||||
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
|
||||
|
||||
# should choose from the dict returned by `get_dummy_inputs`
|
||||
text_encoder_block_params = frozenset(["prompt", "negative_prompt", "max_sequence_length"])
|
||||
decode_block_params = frozenset(["output_type"])
|
||||
vae_encoder_block_params = frozenset(["image", "height", "width"])
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
generator = self.get_generator()
|
||||
inputs = {
|
||||
@@ -98,7 +86,6 @@ class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, Modul
|
||||
pipeline_class = QwenImageEditPlusModularPipeline
|
||||
pipeline_blocks_class = QwenImageEditPlusAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-edit-plus-modular"
|
||||
default_repo_id = "Qwen/Qwen-Image-Edit-2509"
|
||||
|
||||
# No `mask_image` yet.
|
||||
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image"])
|
||||
|
||||
@@ -279,8 +279,6 @@ class TestSDXLModularPipelineFast(
|
||||
pipeline_class = StableDiffusionXLModularPipeline
|
||||
pipeline_blocks_class = StableDiffusionXLAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
|
||||
default_repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
|
||||
params = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
@@ -293,11 +291,6 @@ class TestSDXLModularPipelineFast(
|
||||
batch_params = frozenset(["prompt", "negative_prompt"])
|
||||
expected_image_output_shape = (1, 3, 64, 64)
|
||||
|
||||
# should choose from the dict returned by `get_dummy_inputs`
|
||||
text_encoder_block_params = frozenset(["prompt"])
|
||||
decode_block_params = frozenset(["output_type"])
|
||||
vae_encoder_block_params = None # None if vae_encoder is not supported
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
inputs = {
|
||||
@@ -346,11 +339,6 @@ class TestSDXLImg2ImgModularPipelineFast(
|
||||
batch_params = frozenset(["prompt", "negative_prompt", "image"])
|
||||
expected_image_output_shape = (1, 3, 64, 64)
|
||||
|
||||
# should choose from the dict returned by `get_dummy_inputs`
|
||||
text_encoder_block_params = frozenset(["prompt"])
|
||||
decode_block_params = frozenset(["output_type"])
|
||||
vae_encoder_block_params = frozenset(["image"])
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
inputs = {
|
||||
|
||||
@@ -48,12 +48,6 @@ class ModularPipelineTesterMixin:
|
||||
"You need to set the attribute `pretrained_model_name_or_path` in the child test class. See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
@property
|
||||
def default_repo_id(self) -> str:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `default_repo_id` in the child test class. See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
@property
|
||||
def pipeline_blocks_class(self) -> Union[Callable, ModularPipelineBlocks]:
|
||||
raise NotImplementedError(
|
||||
@@ -96,30 +90,6 @@ class ModularPipelineTesterMixin:
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
def text_encoder_block_params(self) -> frozenset:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `text_encoder_block_params` in the child test class. "
|
||||
"`text_encoder_block_params` are the parameters required to be passed to the text encoder block. "
|
||||
" if should be a subset of the parameters returned by `get_dummy_inputs`"
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
def decode_block_params(self) -> frozenset:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `decode_block_params` in the child test class. "
|
||||
"`decode_block_params` are the parameters required to be passed to the decode block. "
|
||||
" if should be a subset of the parameters returned by `get_dummy_inputs`"
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
def vae_encoder_block_params(self) -> frozenset:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `vae_encoder_block_params` in the child test class. "
|
||||
"`vae_encoder_block_params` are the parameters required to be passed to the vae encoder block. "
|
||||
" if should be a subset of the parameters returned by `get_dummy_inputs`"
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
def setup_method(self):
|
||||
# clean up the VRAM before each test
|
||||
torch.compiler.reset()
|
||||
@@ -154,96 +124,6 @@ class ModularPipelineTesterMixin:
|
||||
_check_for_parameters(self.params, input_parameters, "input")
|
||||
_check_for_parameters(self.optional_params, optional_parameters, "optional")
|
||||
|
||||
def test_loading_from_default_repo(self):
|
||||
if self.default_repo_id is None:
|
||||
return
|
||||
|
||||
try:
|
||||
pipe = ModularPipeline.from_pretrained(self.default_repo_id)
|
||||
assert pipe.blocks.__class__ == self.pipeline_blocks_class
|
||||
except Exception as e:
|
||||
assert False, f"Failed to load pipeline from default repo: {e}"
|
||||
|
||||
def test_modular_inference(self):
|
||||
# run the pipeline to get the base output for comparison
|
||||
pipe = self.get_pipeline()
|
||||
pipe.to(torch_device, torch.float32)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
standard_output = pipe(**inputs, output="images")
|
||||
|
||||
# create text, denoise, decoder (and optional vae encoder) nodes
|
||||
blocks = self.pipeline_blocks_class()
|
||||
|
||||
assert "text_encoder" in blocks.sub_blocks, "`text_encoder` block is not present in the pipeline"
|
||||
assert "denoise" in blocks.sub_blocks, "`denoise` block is not present in the pipeline"
|
||||
assert "decode" in blocks.sub_blocks, "`decode` block is not present in the pipeline"
|
||||
if self.vae_encoder_block_params is not None:
|
||||
assert "vae_encoder" in blocks.sub_blocks, "`vae_encoder` block is not present in the pipeline"
|
||||
|
||||
# manually set the components in the sub_pipe
|
||||
# a hack to workaround the fact the default pipeline properties are often incorrect for testing cases,
|
||||
# e.g. vae_scale_factor is usually not 8 because the vae is configured to be smaller for testing
|
||||
def manually_set_all_components(pipe: ModularPipeline, sub_pipe: ModularPipeline):
|
||||
for n, comp in pipe.components.items():
|
||||
if not hasattr(sub_pipe, n):
|
||||
setattr(sub_pipe, n, comp)
|
||||
|
||||
text_node = blocks.sub_blocks["text_encoder"].init_pipeline(self.pretrained_model_name_or_path)
|
||||
text_node.load_components(torch_dtype=torch.float32)
|
||||
text_node.to(torch_device)
|
||||
manually_set_all_components(pipe, text_node)
|
||||
|
||||
denoise_node = blocks.sub_blocks["denoise"].init_pipeline(self.pretrained_model_name_or_path)
|
||||
denoise_node.load_components(torch_dtype=torch.float32)
|
||||
denoise_node.to(torch_device)
|
||||
manually_set_all_components(pipe, denoise_node)
|
||||
|
||||
decoder_node = blocks.sub_blocks["decode"].init_pipeline(self.pretrained_model_name_or_path)
|
||||
decoder_node.load_components(torch_dtype=torch.float32)
|
||||
decoder_node.to(torch_device)
|
||||
manually_set_all_components(pipe, decoder_node)
|
||||
|
||||
if self.vae_encoder_block_params is not None:
|
||||
vae_encoder_node = blocks.sub_blocks["vae_encoder"].init_pipeline(self.pretrained_model_name_or_path)
|
||||
vae_encoder_node.load_components(torch_dtype=torch.float32)
|
||||
vae_encoder_node.to(torch_device)
|
||||
manually_set_all_components(pipe, vae_encoder_node)
|
||||
else:
|
||||
vae_encoder_node = None
|
||||
|
||||
# prepare inputs for each node
|
||||
inputs = self.get_dummy_inputs()
|
||||
|
||||
def get_block_inputs(inputs: dict, block_params: frozenset) -> tuple[dict, dict]:
|
||||
block_inputs = {}
|
||||
for name in block_params:
|
||||
if name in inputs:
|
||||
block_inputs[name] = inputs.pop(name)
|
||||
return block_inputs, inputs
|
||||
|
||||
text_inputs, inputs = get_block_inputs(inputs, self.text_encoder_block_params)
|
||||
decoder_inputs, inputs = get_block_inputs(inputs, self.decode_block_params)
|
||||
if vae_encoder_node is not None:
|
||||
vae_encoder_inputs, inputs = get_block_inputs(inputs, self.vae_encoder_block_params)
|
||||
|
||||
# this is also to make sure pipelines mark text outputs as denoiser_input_fields
|
||||
text_output = text_node(**text_inputs).get_by_kwargs("denoiser_input_fields")
|
||||
if vae_encoder_node is not None:
|
||||
vae_encoder_output = vae_encoder_node(**vae_encoder_inputs).values
|
||||
denoise_inputs = {**text_output, **vae_encoder_output, **inputs}
|
||||
else:
|
||||
denoise_inputs = {**text_output, **inputs}
|
||||
|
||||
# denoise node output should be "latents"
|
||||
latents = denoise_node(**denoise_inputs).latents
|
||||
# decoder node input should be "latents" and output should be "images"
|
||||
modular_output = decoder_node(**decoder_inputs, latents=latents).images
|
||||
|
||||
assert modular_output.shape == standard_output.shape, (
|
||||
f"Modular output should have same shape as standard output {standard_output.shape}, but got {modular_output.shape}"
|
||||
)
|
||||
|
||||
def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True):
|
||||
pipe = self.get_pipeline().to(torch_device)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user