mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-20 03:14:43 +08:00
Compare commits
5 Commits
pipeline-s
...
custom-cod
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
98954fc2e1 | ||
|
|
1262d19d16 | ||
|
|
201da97dd0 | ||
|
|
4423097b23 | ||
|
|
60d1b81023 |
@@ -324,9 +324,12 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
|||||||
</Tip>
|
</Tip>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
config_name = "config.json"
|
config_name = "modular_config.json"
|
||||||
model_name = None
|
model_name = None
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.sub_blocks = InsertableDict()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_signature_keys(cls, obj):
|
def _get_signature_keys(cls, obj):
|
||||||
parameters = inspect.signature(obj.__init__).parameters
|
parameters = inspect.signature(obj.__init__).parameters
|
||||||
@@ -344,6 +347,11 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
|||||||
def expected_configs(self) -> List[ConfigSpec]:
|
def expected_configs(self) -> List[ConfigSpec]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_outputs(self) -> List[OutputParam]:
|
||||||
|
"""List of intermediate output parameters. Must be implemented by subclasses."""
|
||||||
|
return []
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
cls,
|
cls,
|
||||||
@@ -425,6 +433,60 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
|||||||
)
|
)
|
||||||
return modular_pipeline
|
return modular_pipeline
|
||||||
|
|
||||||
|
def get_block_state(self, state: PipelineState) -> dict:
|
||||||
|
"""Get all inputs and intermediates in one dictionary"""
|
||||||
|
data = {}
|
||||||
|
state_inputs = self.inputs + self.intermediate_inputs
|
||||||
|
|
||||||
|
# Check inputs
|
||||||
|
for input_param in state_inputs:
|
||||||
|
if input_param.name:
|
||||||
|
value = state.get_input(input_param.name) or state.get_intermediate(input_param.name)
|
||||||
|
if input_param.required and value is None:
|
||||||
|
raise ValueError(f"Required input '{input_param.name}' is missing")
|
||||||
|
elif value is not None or (value is None and input_param.name not in data):
|
||||||
|
data[input_param.name] = value
|
||||||
|
|
||||||
|
elif input_param.kwargs_type:
|
||||||
|
# if kwargs_type is provided, get all inputs with matching kwargs_type
|
||||||
|
if input_param.kwargs_type not in data:
|
||||||
|
data[input_param.kwargs_type] = {}
|
||||||
|
inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) or state.get_intermediate_kwargs(
|
||||||
|
input_param.kwargs_type
|
||||||
|
)
|
||||||
|
if inputs_kwargs:
|
||||||
|
for k, v in inputs_kwargs.items():
|
||||||
|
if v is not None:
|
||||||
|
data[k] = v
|
||||||
|
data[input_param.kwargs_type][k] = v
|
||||||
|
|
||||||
|
return BlockState(**data)
|
||||||
|
|
||||||
|
def set_block_state(self, state: PipelineState, block_state: BlockState):
|
||||||
|
for output_param in self.intermediate_outputs:
|
||||||
|
if not hasattr(block_state, output_param.name):
|
||||||
|
raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
|
||||||
|
param = getattr(block_state, output_param.name)
|
||||||
|
state.set_intermediate(output_param.name, param, output_param.kwargs_type)
|
||||||
|
|
||||||
|
for input_param in self.intermediate_inputs:
|
||||||
|
if input_param.name and hasattr(block_state, input_param.name):
|
||||||
|
param = getattr(block_state, input_param.name)
|
||||||
|
# Only add if the value is different from what's in the state
|
||||||
|
current_value = state.get_intermediate(input_param.name)
|
||||||
|
if current_value is not param: # Using identity comparison to check if object was modified
|
||||||
|
state.set_intermediate(input_param.name, param, input_param.kwargs_type)
|
||||||
|
elif input_param.kwargs_type:
|
||||||
|
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
|
||||||
|
# we need to first find out which inputs are and loop through them.
|
||||||
|
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
|
||||||
|
for param_name, current_value in intermediate_kwargs.items():
|
||||||
|
if not hasattr(block_state, param_name):
|
||||||
|
continue
|
||||||
|
param = getattr(block_state, param_name)
|
||||||
|
if current_value is not param: # Using identity comparison to check if object was modified
|
||||||
|
state.set_intermediate(param_name, param, input_param.kwargs_type)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
|
def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
|
||||||
"""
|
"""
|
||||||
@@ -654,51 +716,6 @@ class PipelineBlock(ModularPipelineBlocks):
|
|||||||
expected_configs=self.expected_configs,
|
expected_configs=self.expected_configs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# YiYi TODO: input and inteermediate inputs with same name? should warn?
|
|
||||||
def get_block_state(self, state: PipelineState) -> dict:
|
|
||||||
"""Get all inputs and intermediates in one dictionary"""
|
|
||||||
data = {}
|
|
||||||
|
|
||||||
# Check inputs
|
|
||||||
for input_param in self.inputs:
|
|
||||||
if input_param.name:
|
|
||||||
value = state.get_input(input_param.name)
|
|
||||||
if input_param.required and value is None:
|
|
||||||
raise ValueError(f"Required input '{input_param.name}' is missing")
|
|
||||||
elif value is not None or (value is None and input_param.name not in data):
|
|
||||||
data[input_param.name] = value
|
|
||||||
elif input_param.kwargs_type:
|
|
||||||
# if kwargs_type is provided, get all inputs with matching kwargs_type
|
|
||||||
if input_param.kwargs_type not in data:
|
|
||||||
data[input_param.kwargs_type] = {}
|
|
||||||
inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type)
|
|
||||||
if inputs_kwargs:
|
|
||||||
for k, v in inputs_kwargs.items():
|
|
||||||
if v is not None:
|
|
||||||
data[k] = v
|
|
||||||
data[input_param.kwargs_type][k] = v
|
|
||||||
|
|
||||||
# Check intermediates
|
|
||||||
for input_param in self.intermediate_inputs:
|
|
||||||
if input_param.name:
|
|
||||||
value = state.get_intermediate(input_param.name)
|
|
||||||
if input_param.required and value is None:
|
|
||||||
raise ValueError(f"Required intermediate input '{input_param.name}' is missing")
|
|
||||||
elif value is not None or (value is None and input_param.name not in data):
|
|
||||||
data[input_param.name] = value
|
|
||||||
elif input_param.kwargs_type:
|
|
||||||
# if kwargs_type is provided, get all intermediates with matching kwargs_type
|
|
||||||
if input_param.kwargs_type not in data:
|
|
||||||
data[input_param.kwargs_type] = {}
|
|
||||||
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
|
|
||||||
if intermediate_kwargs:
|
|
||||||
for k, v in intermediate_kwargs.items():
|
|
||||||
if v is not None:
|
|
||||||
if k not in data:
|
|
||||||
data[k] = v
|
|
||||||
data[input_param.kwargs_type][k] = v
|
|
||||||
return BlockState(**data)
|
|
||||||
|
|
||||||
def set_block_state(self, state: PipelineState, block_state: BlockState):
|
def set_block_state(self, state: PipelineState, block_state: BlockState):
|
||||||
for output_param in self.intermediate_outputs:
|
for output_param in self.intermediate_outputs:
|
||||||
if not hasattr(block_state, output_param.name):
|
if not hasattr(block_state, output_param.name):
|
||||||
@@ -1439,11 +1456,6 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
|||||||
"""List of input parameters. Must be implemented by subclasses."""
|
"""List of input parameters. Must be implemented by subclasses."""
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@property
|
|
||||||
def loop_intermediate_inputs(self) -> List[InputParam]:
|
|
||||||
"""List of intermediate input parameters. Must be implemented by subclasses."""
|
|
||||||
return []
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def loop_intermediate_outputs(self) -> List[OutputParam]:
|
def loop_intermediate_outputs(self) -> List[OutputParam]:
|
||||||
"""List of intermediate output parameters. Must be implemented by subclasses."""
|
"""List of intermediate output parameters. Must be implemented by subclasses."""
|
||||||
@@ -1457,14 +1469,6 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
|||||||
input_names.append(input_param.name)
|
input_names.append(input_param.name)
|
||||||
return input_names
|
return input_names
|
||||||
|
|
||||||
@property
|
|
||||||
def loop_required_intermediate_inputs(self) -> List[str]:
|
|
||||||
input_names = []
|
|
||||||
for input_param in self.loop_intermediate_inputs:
|
|
||||||
if input_param.required:
|
|
||||||
input_names.append(input_param.name)
|
|
||||||
return input_names
|
|
||||||
|
|
||||||
# modified from SequentialPipelineBlocks to include loop_expected_components
|
# modified from SequentialPipelineBlocks to include loop_expected_components
|
||||||
@property
|
@property
|
||||||
def expected_components(self):
|
def expected_components(self):
|
||||||
@@ -1635,75 +1639,6 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
|||||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||||
raise NotImplementedError("`__call__` method needs to be implemented by the subclass")
|
raise NotImplementedError("`__call__` method needs to be implemented by the subclass")
|
||||||
|
|
||||||
def get_block_state(self, state: PipelineState) -> dict:
|
|
||||||
"""Get all inputs and intermediates in one dictionary"""
|
|
||||||
data = {}
|
|
||||||
|
|
||||||
# Check inputs
|
|
||||||
for input_param in self.inputs:
|
|
||||||
if input_param.name:
|
|
||||||
value = state.get_input(input_param.name)
|
|
||||||
if input_param.required and value is None:
|
|
||||||
raise ValueError(f"Required input '{input_param.name}' is missing")
|
|
||||||
elif value is not None or (value is None and input_param.name not in data):
|
|
||||||
data[input_param.name] = value
|
|
||||||
elif input_param.kwargs_type:
|
|
||||||
# if kwargs_type is provided, get all inputs with matching kwargs_type
|
|
||||||
if input_param.kwargs_type not in data:
|
|
||||||
data[input_param.kwargs_type] = {}
|
|
||||||
inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type)
|
|
||||||
if inputs_kwargs:
|
|
||||||
for k, v in inputs_kwargs.items():
|
|
||||||
if v is not None:
|
|
||||||
data[k] = v
|
|
||||||
data[input_param.kwargs_type][k] = v
|
|
||||||
|
|
||||||
# Check intermediates
|
|
||||||
for input_param in self.intermediate_inputs:
|
|
||||||
if input_param.name:
|
|
||||||
value = state.get_intermediate(input_param.name)
|
|
||||||
if input_param.required and value is None:
|
|
||||||
raise ValueError(f"Required intermediate input '{input_param.name}' is missing")
|
|
||||||
elif value is not None or (value is None and input_param.name not in data):
|
|
||||||
data[input_param.name] = value
|
|
||||||
elif input_param.kwargs_type:
|
|
||||||
# if kwargs_type is provided, get all intermediates with matching kwargs_type
|
|
||||||
if input_param.kwargs_type not in data:
|
|
||||||
data[input_param.kwargs_type] = {}
|
|
||||||
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
|
|
||||||
if intermediate_kwargs:
|
|
||||||
for k, v in intermediate_kwargs.items():
|
|
||||||
if v is not None:
|
|
||||||
if k not in data:
|
|
||||||
data[k] = v
|
|
||||||
data[input_param.kwargs_type][k] = v
|
|
||||||
return BlockState(**data)
|
|
||||||
|
|
||||||
def set_block_state(self, state: PipelineState, block_state: BlockState):
|
|
||||||
for output_param in self.intermediate_outputs:
|
|
||||||
if not hasattr(block_state, output_param.name):
|
|
||||||
raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
|
|
||||||
param = getattr(block_state, output_param.name)
|
|
||||||
state.set_intermediate(output_param.name, param, output_param.kwargs_type)
|
|
||||||
|
|
||||||
for input_param in self.intermediate_inputs:
|
|
||||||
if input_param.name and hasattr(block_state, input_param.name):
|
|
||||||
param = getattr(block_state, input_param.name)
|
|
||||||
# Only add if the value is different from what's in the state
|
|
||||||
current_value = state.get_intermediate(input_param.name)
|
|
||||||
if current_value is not param: # Using identity comparison to check if object was modified
|
|
||||||
state.set_intermediate(input_param.name, param, input_param.kwargs_type)
|
|
||||||
elif input_param.kwargs_type:
|
|
||||||
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
|
|
||||||
# we need to first find out which inputs are and loop through them.
|
|
||||||
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
|
|
||||||
for param_name, current_value in intermediate_kwargs.items():
|
|
||||||
if not hasattr(block_state, param_name):
|
|
||||||
continue
|
|
||||||
param = getattr(block_state, param_name)
|
|
||||||
if current_value is not param: # Using identity comparison to check if object was modified
|
|
||||||
state.set_intermediate(param_name, param, input_param.kwargs_type)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def doc(self):
|
def doc(self):
|
||||||
return make_doc_string(
|
return make_doc_string(
|
||||||
@@ -1976,7 +1911,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
|||||||
|
|
||||||
# Add inputs to state, using defaults if not provided in the kwargs or the state
|
# Add inputs to state, using defaults if not provided in the kwargs or the state
|
||||||
# if same input already in the state, will override it if provided in the kwargs
|
# if same input already in the state, will override it if provided in the kwargs
|
||||||
|
|
||||||
intermediate_inputs = [inp.name for inp in self.blocks.intermediate_inputs]
|
intermediate_inputs = [inp.name for inp in self.blocks.intermediate_inputs]
|
||||||
for expected_input_param in self.blocks.inputs:
|
for expected_input_param in self.blocks.inputs:
|
||||||
name = expected_input_param.name
|
name = expected_input_param.name
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ from ...schedulers import EulerDiscreteScheduler
|
|||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ...utils.torch_utils import randn_tensor, unwrap_module
|
from ...utils.torch_utils import randn_tensor, unwrap_module
|
||||||
from ..modular_pipeline import (
|
from ..modular_pipeline import (
|
||||||
PipelineBlock,
|
ModularPipelineBlocks,
|
||||||
PipelineState,
|
PipelineState,
|
||||||
)
|
)
|
||||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||||
@@ -195,7 +195,7 @@ def prepare_latents_img2img(
|
|||||||
return latents
|
return latents
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLInputStep(PipelineBlock):
|
class StableDiffusionXLInputStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -394,7 +394,7 @@ class StableDiffusionXLInputStep(PipelineBlock):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
|
class StableDiffusionXLImg2ImgSetTimestepsStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -543,7 +543,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLSetTimestepsStep(PipelineBlock):
|
class StableDiffusionXLSetTimestepsStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -611,7 +611,7 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
class StableDiffusionXLInpaintPrepareLatentsStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -900,7 +900,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
class StableDiffusionXLImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -981,7 +981,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
|
class StableDiffusionXLPrepareLatentsStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -1092,7 +1092,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
|
class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -1316,7 +1316,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
|
class StableDiffusionXLPrepareAdditionalConditioningStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -1499,7 +1499,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLControlNetInputStep(PipelineBlock):
|
class StableDiffusionXLControlNetInputStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -1718,7 +1718,7 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
|
class StableDiffusionXLControlNetUnionInputStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -23,17 +23,14 @@ from ...image_processor import VaeImageProcessor
|
|||||||
from ...models import AutoencoderKL
|
from ...models import AutoencoderKL
|
||||||
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
|
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ..modular_pipeline import (
|
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||||
PipelineBlock,
|
|
||||||
PipelineState,
|
|
||||||
)
|
|
||||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLDecodeStep(PipelineBlock):
|
class StableDiffusionXLDecodeStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -157,7 +154,7 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
|
class StableDiffusionXLInpaintOverlayMaskStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from ...utils import logging
|
|||||||
from ..modular_pipeline import (
|
from ..modular_pipeline import (
|
||||||
BlockState,
|
BlockState,
|
||||||
LoopSequentialPipelineBlocks,
|
LoopSequentialPipelineBlocks,
|
||||||
PipelineBlock,
|
ModularPipelineBlocks,
|
||||||
PipelineState,
|
PipelineState,
|
||||||
)
|
)
|
||||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||||
@@ -37,7 +37,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
|||||||
|
|
||||||
# YiYi experimenting composible denoise loop
|
# YiYi experimenting composible denoise loop
|
||||||
# loop step (1): prepare latent input for denoiser
|
# loop step (1): prepare latent input for denoiser
|
||||||
class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
|
class StableDiffusionXLLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -55,7 +55,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def intermediate_inputs(self) -> List[str]:
|
def inputs(self) -> List[str]:
|
||||||
return [
|
return [
|
||||||
InputParam(
|
InputParam(
|
||||||
"latents",
|
"latents",
|
||||||
@@ -73,7 +73,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
|
|||||||
|
|
||||||
|
|
||||||
# loop step (1): prepare latent input for denoiser (with inpainting)
|
# loop step (1): prepare latent input for denoiser (with inpainting)
|
||||||
class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
|
class StableDiffusionXLInpaintLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -144,7 +144,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
|
|||||||
|
|
||||||
|
|
||||||
# loop step (2): denoise the latents with guidance
|
# loop step (2): denoise the latents with guidance
|
||||||
class StableDiffusionXLLoopDenoiser(PipelineBlock):
|
class StableDiffusionXLLoopDenoiser(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -249,7 +249,7 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock):
|
|||||||
|
|
||||||
|
|
||||||
# loop step (2): denoise the latents with guidance (with controlnet)
|
# loop step (2): denoise the latents with guidance (with controlnet)
|
||||||
class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
|
class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -449,7 +449,7 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
|
|||||||
|
|
||||||
|
|
||||||
# loop step (3): scheduler step to update latents
|
# loop step (3): scheduler step to update latents
|
||||||
class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
|
class StableDiffusionXLLoopAfterDenoiser(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -520,7 +520,7 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
|
|||||||
|
|
||||||
|
|
||||||
# loop step (3): scheduler step to update latents (with inpainting)
|
# loop step (3): scheduler step to update latents (with inpainting)
|
||||||
class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
|
class StableDiffusionXLInpaintLoopAfterDenoiser(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -660,7 +660,7 @@ class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def loop_intermediate_inputs(self) -> List[InputParam]:
|
def loop_inputs(self) -> List[InputParam]:
|
||||||
return [
|
return [
|
||||||
InputParam(
|
InputParam(
|
||||||
"timesteps",
|
"timesteps",
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ from ...utils import (
|
|||||||
scale_lora_layers,
|
scale_lora_layers,
|
||||||
unscale_lora_layers,
|
unscale_lora_layers,
|
||||||
)
|
)
|
||||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||||
from .modular_pipeline import StableDiffusionXLModularPipeline
|
from .modular_pipeline import StableDiffusionXLModularPipeline
|
||||||
|
|
||||||
@@ -57,7 +57,7 @@ def retrieve_latents(
|
|||||||
raise AttributeError("Could not access latents of provided encoder_output")
|
raise AttributeError("Could not access latents of provided encoder_output")
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLIPAdapterStep(PipelineBlock):
|
class StableDiffusionXLIPAdapterStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -215,7 +215,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLTextEncoderStep(PipelineBlock):
|
class StableDiffusionXLTextEncoderStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -576,7 +576,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
class StableDiffusionXLVaeEncoderStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -691,7 +691,7 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
|
class StableDiffusionXLInpaintVaeEncoderStep(ModularPipelineBlocks):
|
||||||
model_name = "stable-diffusion-xl"
|
model_name = "stable-diffusion-xl"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import torch
|
|||||||
from ...schedulers import UniPCMultistepScheduler
|
from ...schedulers import UniPCMultistepScheduler
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ...utils.torch_utils import randn_tensor
|
from ...utils.torch_utils import randn_tensor
|
||||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||||
from .modular_pipeline import WanModularPipeline
|
from .modular_pipeline import WanModularPipeline
|
||||||
|
|
||||||
@@ -94,7 +94,7 @@ def retrieve_timesteps(
|
|||||||
return timesteps, num_inference_steps
|
return timesteps, num_inference_steps
|
||||||
|
|
||||||
|
|
||||||
class WanInputStep(PipelineBlock):
|
class WanInputStep(ModularPipelineBlocks):
|
||||||
model_name = "wan"
|
model_name = "wan"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -194,7 +194,7 @@ class WanInputStep(PipelineBlock):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class WanSetTimestepsStep(PipelineBlock):
|
class WanSetTimestepsStep(ModularPipelineBlocks):
|
||||||
model_name = "wan"
|
model_name = "wan"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -243,7 +243,7 @@ class WanSetTimestepsStep(PipelineBlock):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class WanPrepareLatentsStep(PipelineBlock):
|
class WanPrepareLatentsStep(ModularPipelineBlocks):
|
||||||
model_name = "wan"
|
model_name = "wan"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -22,14 +22,14 @@ from ...configuration_utils import FrozenDict
|
|||||||
from ...models import AutoencoderKLWan
|
from ...models import AutoencoderKLWan
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ...video_processor import VideoProcessor
|
from ...video_processor import VideoProcessor
|
||||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
class WanDecodeStep(PipelineBlock):
|
class WanDecodeStep(ModularPipelineBlocks):
|
||||||
model_name = "wan"
|
model_name = "wan"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from ...utils import logging
|
|||||||
from ..modular_pipeline import (
|
from ..modular_pipeline import (
|
||||||
BlockState,
|
BlockState,
|
||||||
LoopSequentialPipelineBlocks,
|
LoopSequentialPipelineBlocks,
|
||||||
PipelineBlock,
|
ModularPipelineBlocks,
|
||||||
PipelineState,
|
PipelineState,
|
||||||
)
|
)
|
||||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||||
@@ -34,7 +34,7 @@ from .modular_pipeline import WanModularPipeline
|
|||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
class WanLoopDenoiser(PipelineBlock):
|
class WanLoopDenoiser(ModularPipelineBlocks):
|
||||||
model_name = "wan"
|
model_name = "wan"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -132,7 +132,7 @@ class WanLoopDenoiser(PipelineBlock):
|
|||||||
return components, block_state
|
return components, block_state
|
||||||
|
|
||||||
|
|
||||||
class WanLoopAfterDenoiser(PipelineBlock):
|
class WanLoopAfterDenoiser(ModularPipelineBlocks):
|
||||||
model_name = "wan"
|
model_name = "wan"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from transformers import AutoTokenizer, UMT5EncoderModel
|
|||||||
from ...configuration_utils import FrozenDict
|
from ...configuration_utils import FrozenDict
|
||||||
from ...guiders import ClassifierFreeGuidance
|
from ...guiders import ClassifierFreeGuidance
|
||||||
from ...utils import is_ftfy_available, logging
|
from ...utils import is_ftfy_available, logging
|
||||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||||
from .modular_pipeline import WanModularPipeline
|
from .modular_pipeline import WanModularPipeline
|
||||||
|
|
||||||
@@ -51,7 +51,7 @@ def prompt_clean(text):
|
|||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
class WanTextEncoderStep(PipelineBlock):
|
class WanTextEncoderStep(ModularPipelineBlocks):
|
||||||
model_name = "wan"
|
model_name = "wan"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
Reference in New Issue
Block a user