mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-10 06:24:19 +08:00
Compare commits
5 Commits
enable-all
...
custom-cod
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
98954fc2e1 | ||
|
|
1262d19d16 | ||
|
|
201da97dd0 | ||
|
|
4423097b23 | ||
|
|
60d1b81023 |
@@ -324,9 +324,12 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
</Tip>
|
||||
"""
|
||||
|
||||
config_name = "config.json"
|
||||
config_name = "modular_config.json"
|
||||
model_name = None
|
||||
|
||||
def __init__(self):
|
||||
self.sub_blocks = InsertableDict()
|
||||
|
||||
@classmethod
|
||||
def _get_signature_keys(cls, obj):
|
||||
parameters = inspect.signature(obj.__init__).parameters
|
||||
@@ -344,6 +347,11 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
def expected_configs(self) -> List[ConfigSpec]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
"""List of intermediate output parameters. Must be implemented by subclasses."""
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
@@ -425,6 +433,60 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
)
|
||||
return modular_pipeline
|
||||
|
||||
def get_block_state(self, state: PipelineState) -> dict:
|
||||
"""Get all inputs and intermediates in one dictionary"""
|
||||
data = {}
|
||||
state_inputs = self.inputs + self.intermediate_inputs
|
||||
|
||||
# Check inputs
|
||||
for input_param in state_inputs:
|
||||
if input_param.name:
|
||||
value = state.get_input(input_param.name) or state.get_intermediate(input_param.name)
|
||||
if input_param.required and value is None:
|
||||
raise ValueError(f"Required input '{input_param.name}' is missing")
|
||||
elif value is not None or (value is None and input_param.name not in data):
|
||||
data[input_param.name] = value
|
||||
|
||||
elif input_param.kwargs_type:
|
||||
# if kwargs_type is provided, get all inputs with matching kwargs_type
|
||||
if input_param.kwargs_type not in data:
|
||||
data[input_param.kwargs_type] = {}
|
||||
inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) or state.get_intermediate_kwargs(
|
||||
input_param.kwargs_type
|
||||
)
|
||||
if inputs_kwargs:
|
||||
for k, v in inputs_kwargs.items():
|
||||
if v is not None:
|
||||
data[k] = v
|
||||
data[input_param.kwargs_type][k] = v
|
||||
|
||||
return BlockState(**data)
|
||||
|
||||
def set_block_state(self, state: PipelineState, block_state: BlockState):
|
||||
for output_param in self.intermediate_outputs:
|
||||
if not hasattr(block_state, output_param.name):
|
||||
raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
|
||||
param = getattr(block_state, output_param.name)
|
||||
state.set_intermediate(output_param.name, param, output_param.kwargs_type)
|
||||
|
||||
for input_param in self.intermediate_inputs:
|
||||
if input_param.name and hasattr(block_state, input_param.name):
|
||||
param = getattr(block_state, input_param.name)
|
||||
# Only add if the value is different from what's in the state
|
||||
current_value = state.get_intermediate(input_param.name)
|
||||
if current_value is not param: # Using identity comparison to check if object was modified
|
||||
state.set_intermediate(input_param.name, param, input_param.kwargs_type)
|
||||
elif input_param.kwargs_type:
|
||||
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
|
||||
# we need to first find out which inputs are and loop through them.
|
||||
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
|
||||
for param_name, current_value in intermediate_kwargs.items():
|
||||
if not hasattr(block_state, param_name):
|
||||
continue
|
||||
param = getattr(block_state, param_name)
|
||||
if current_value is not param: # Using identity comparison to check if object was modified
|
||||
state.set_intermediate(param_name, param, input_param.kwargs_type)
|
||||
|
||||
@staticmethod
|
||||
def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
|
||||
"""
|
||||
@@ -654,51 +716,6 @@ class PipelineBlock(ModularPipelineBlocks):
|
||||
expected_configs=self.expected_configs,
|
||||
)
|
||||
|
||||
# YiYi TODO: input and inteermediate inputs with same name? should warn?
|
||||
def get_block_state(self, state: PipelineState) -> dict:
|
||||
"""Get all inputs and intermediates in one dictionary"""
|
||||
data = {}
|
||||
|
||||
# Check inputs
|
||||
for input_param in self.inputs:
|
||||
if input_param.name:
|
||||
value = state.get_input(input_param.name)
|
||||
if input_param.required and value is None:
|
||||
raise ValueError(f"Required input '{input_param.name}' is missing")
|
||||
elif value is not None or (value is None and input_param.name not in data):
|
||||
data[input_param.name] = value
|
||||
elif input_param.kwargs_type:
|
||||
# if kwargs_type is provided, get all inputs with matching kwargs_type
|
||||
if input_param.kwargs_type not in data:
|
||||
data[input_param.kwargs_type] = {}
|
||||
inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type)
|
||||
if inputs_kwargs:
|
||||
for k, v in inputs_kwargs.items():
|
||||
if v is not None:
|
||||
data[k] = v
|
||||
data[input_param.kwargs_type][k] = v
|
||||
|
||||
# Check intermediates
|
||||
for input_param in self.intermediate_inputs:
|
||||
if input_param.name:
|
||||
value = state.get_intermediate(input_param.name)
|
||||
if input_param.required and value is None:
|
||||
raise ValueError(f"Required intermediate input '{input_param.name}' is missing")
|
||||
elif value is not None or (value is None and input_param.name not in data):
|
||||
data[input_param.name] = value
|
||||
elif input_param.kwargs_type:
|
||||
# if kwargs_type is provided, get all intermediates with matching kwargs_type
|
||||
if input_param.kwargs_type not in data:
|
||||
data[input_param.kwargs_type] = {}
|
||||
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
|
||||
if intermediate_kwargs:
|
||||
for k, v in intermediate_kwargs.items():
|
||||
if v is not None:
|
||||
if k not in data:
|
||||
data[k] = v
|
||||
data[input_param.kwargs_type][k] = v
|
||||
return BlockState(**data)
|
||||
|
||||
def set_block_state(self, state: PipelineState, block_state: BlockState):
|
||||
for output_param in self.intermediate_outputs:
|
||||
if not hasattr(block_state, output_param.name):
|
||||
@@ -1439,11 +1456,6 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
"""List of input parameters. Must be implemented by subclasses."""
|
||||
return []
|
||||
|
||||
@property
|
||||
def loop_intermediate_inputs(self) -> List[InputParam]:
|
||||
"""List of intermediate input parameters. Must be implemented by subclasses."""
|
||||
return []
|
||||
|
||||
@property
|
||||
def loop_intermediate_outputs(self) -> List[OutputParam]:
|
||||
"""List of intermediate output parameters. Must be implemented by subclasses."""
|
||||
@@ -1457,14 +1469,6 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
input_names.append(input_param.name)
|
||||
return input_names
|
||||
|
||||
@property
|
||||
def loop_required_intermediate_inputs(self) -> List[str]:
|
||||
input_names = []
|
||||
for input_param in self.loop_intermediate_inputs:
|
||||
if input_param.required:
|
||||
input_names.append(input_param.name)
|
||||
return input_names
|
||||
|
||||
# modified from SequentialPipelineBlocks to include loop_expected_components
|
||||
@property
|
||||
def expected_components(self):
|
||||
@@ -1635,75 +1639,6 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
raise NotImplementedError("`__call__` method needs to be implemented by the subclass")
|
||||
|
||||
def get_block_state(self, state: PipelineState) -> dict:
|
||||
"""Get all inputs and intermediates in one dictionary"""
|
||||
data = {}
|
||||
|
||||
# Check inputs
|
||||
for input_param in self.inputs:
|
||||
if input_param.name:
|
||||
value = state.get_input(input_param.name)
|
||||
if input_param.required and value is None:
|
||||
raise ValueError(f"Required input '{input_param.name}' is missing")
|
||||
elif value is not None or (value is None and input_param.name not in data):
|
||||
data[input_param.name] = value
|
||||
elif input_param.kwargs_type:
|
||||
# if kwargs_type is provided, get all inputs with matching kwargs_type
|
||||
if input_param.kwargs_type not in data:
|
||||
data[input_param.kwargs_type] = {}
|
||||
inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type)
|
||||
if inputs_kwargs:
|
||||
for k, v in inputs_kwargs.items():
|
||||
if v is not None:
|
||||
data[k] = v
|
||||
data[input_param.kwargs_type][k] = v
|
||||
|
||||
# Check intermediates
|
||||
for input_param in self.intermediate_inputs:
|
||||
if input_param.name:
|
||||
value = state.get_intermediate(input_param.name)
|
||||
if input_param.required and value is None:
|
||||
raise ValueError(f"Required intermediate input '{input_param.name}' is missing")
|
||||
elif value is not None or (value is None and input_param.name not in data):
|
||||
data[input_param.name] = value
|
||||
elif input_param.kwargs_type:
|
||||
# if kwargs_type is provided, get all intermediates with matching kwargs_type
|
||||
if input_param.kwargs_type not in data:
|
||||
data[input_param.kwargs_type] = {}
|
||||
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
|
||||
if intermediate_kwargs:
|
||||
for k, v in intermediate_kwargs.items():
|
||||
if v is not None:
|
||||
if k not in data:
|
||||
data[k] = v
|
||||
data[input_param.kwargs_type][k] = v
|
||||
return BlockState(**data)
|
||||
|
||||
def set_block_state(self, state: PipelineState, block_state: BlockState):
|
||||
for output_param in self.intermediate_outputs:
|
||||
if not hasattr(block_state, output_param.name):
|
||||
raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
|
||||
param = getattr(block_state, output_param.name)
|
||||
state.set_intermediate(output_param.name, param, output_param.kwargs_type)
|
||||
|
||||
for input_param in self.intermediate_inputs:
|
||||
if input_param.name and hasattr(block_state, input_param.name):
|
||||
param = getattr(block_state, input_param.name)
|
||||
# Only add if the value is different from what's in the state
|
||||
current_value = state.get_intermediate(input_param.name)
|
||||
if current_value is not param: # Using identity comparison to check if object was modified
|
||||
state.set_intermediate(input_param.name, param, input_param.kwargs_type)
|
||||
elif input_param.kwargs_type:
|
||||
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
|
||||
# we need to first find out which inputs are and loop through them.
|
||||
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
|
||||
for param_name, current_value in intermediate_kwargs.items():
|
||||
if not hasattr(block_state, param_name):
|
||||
continue
|
||||
param = getattr(block_state, param_name)
|
||||
if current_value is not param: # Using identity comparison to check if object was modified
|
||||
state.set_intermediate(param_name, param, input_param.kwargs_type)
|
||||
|
||||
@property
|
||||
def doc(self):
|
||||
return make_doc_string(
|
||||
@@ -1976,7 +1911,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
# Add inputs to state, using defaults if not provided in the kwargs or the state
|
||||
# if same input already in the state, will override it if provided in the kwargs
|
||||
|
||||
intermediate_inputs = [inp.name for inp in self.blocks.intermediate_inputs]
|
||||
for expected_input_param in self.blocks.inputs:
|
||||
name = expected_input_param.name
|
||||
|
||||
@@ -27,7 +27,7 @@ from ...schedulers import EulerDiscreteScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor, unwrap_module
|
||||
from ..modular_pipeline import (
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
@@ -195,7 +195,7 @@ def prepare_latents_img2img(
|
||||
return latents
|
||||
|
||||
|
||||
class StableDiffusionXLInputStep(PipelineBlock):
|
||||
class StableDiffusionXLInputStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -394,7 +394,7 @@ class StableDiffusionXLInputStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
|
||||
class StableDiffusionXLImg2ImgSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -543,7 +543,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLSetTimestepsStep(PipelineBlock):
|
||||
class StableDiffusionXLSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -611,7 +611,7 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
class StableDiffusionXLInpaintPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -900,7 +900,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
||||
class StableDiffusionXLImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -981,7 +981,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
|
||||
class StableDiffusionXLPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1092,7 +1092,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1316,7 +1316,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
class StableDiffusionXLPrepareAdditionalConditioningStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1499,7 +1499,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetInputStep(PipelineBlock):
|
||||
class StableDiffusionXLControlNetInputStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1718,7 +1718,7 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
|
||||
class StableDiffusionXLControlNetUnionInputStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
|
||||
@@ -23,17 +23,14 @@ from ...image_processor import VaeImageProcessor
|
||||
from ...models import AutoencoderKL
|
||||
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import (
|
||||
PipelineBlock,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class StableDiffusionXLDecodeStep(PipelineBlock):
|
||||
class StableDiffusionXLDecodeStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -157,7 +154,7 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
|
||||
class StableDiffusionXLInpaintOverlayMaskStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
|
||||
@@ -25,7 +25,7 @@ from ...utils import logging
|
||||
from ..modular_pipeline import (
|
||||
BlockState,
|
||||
LoopSequentialPipelineBlocks,
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
@@ -37,7 +37,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# YiYi experimenting composible denoise loop
|
||||
# loop step (1): prepare latent input for denoiser
|
||||
class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -55,7 +55,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
|
||||
)
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
def inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
@@ -73,7 +73,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (1): prepare latent input for denoiser (with inpainting)
|
||||
class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLInpaintLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -144,7 +144,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (2): denoise the latents with guidance
|
||||
class StableDiffusionXLLoopDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -249,7 +249,7 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (2): denoise the latents with guidance (with controlnet)
|
||||
class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -449,7 +449,7 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (3): scheduler step to update latents
|
||||
class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -520,7 +520,7 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (3): scheduler step to update latents (with inpainting)
|
||||
class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLInpaintLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -660,7 +660,7 @@ class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||
]
|
||||
|
||||
@property
|
||||
def loop_intermediate_inputs(self) -> List[InputParam]:
|
||||
def loop_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"timesteps",
|
||||
|
||||
@@ -35,7 +35,7 @@ from ...utils import (
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import StableDiffusionXLModularPipeline
|
||||
|
||||
@@ -57,7 +57,7 @@ def retrieve_latents(
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class StableDiffusionXLIPAdapterStep(PipelineBlock):
|
||||
class StableDiffusionXLIPAdapterStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -215,7 +215,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLTextEncoderStep(PipelineBlock):
|
||||
class StableDiffusionXLTextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -576,7 +576,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
class StableDiffusionXLVaeEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -691,7 +691,7 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
|
||||
class StableDiffusionXLInpaintVaeEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch
|
||||
from ...schedulers import UniPCMultistepScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import WanModularPipeline
|
||||
|
||||
@@ -94,7 +94,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class WanInputStep(PipelineBlock):
|
||||
class WanInputStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
@@ -194,7 +194,7 @@ class WanInputStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class WanSetTimestepsStep(PipelineBlock):
|
||||
class WanSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
@@ -243,7 +243,7 @@ class WanSetTimestepsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class WanPrepareLatentsStep(PipelineBlock):
|
||||
class WanPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
|
||||
@@ -22,14 +22,14 @@ from ...configuration_utils import FrozenDict
|
||||
from ...models import AutoencoderKLWan
|
||||
from ...utils import logging
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class WanDecodeStep(PipelineBlock):
|
||||
class WanDecodeStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
|
||||
@@ -24,7 +24,7 @@ from ...utils import logging
|
||||
from ..modular_pipeline import (
|
||||
BlockState,
|
||||
LoopSequentialPipelineBlocks,
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
@@ -34,7 +34,7 @@ from .modular_pipeline import WanModularPipeline
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class WanLoopDenoiser(PipelineBlock):
|
||||
class WanLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
@@ -132,7 +132,7 @@ class WanLoopDenoiser(PipelineBlock):
|
||||
return components, block_state
|
||||
|
||||
|
||||
class WanLoopAfterDenoiser(PipelineBlock):
|
||||
class WanLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
|
||||
@@ -22,7 +22,7 @@ from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...guiders import ClassifierFreeGuidance
|
||||
from ...utils import is_ftfy_available, logging
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import WanModularPipeline
|
||||
|
||||
@@ -51,7 +51,7 @@ def prompt_clean(text):
|
||||
return text
|
||||
|
||||
|
||||
class WanTextEncoderStep(PipelineBlock):
|
||||
class WanTextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user