Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-10 06:24:19 +08:00)

Compare commits: mps-video...temp/debug (27 commits)
| SHA1 |
|---|
| fe6c903373 |
| 7ba7c65700 |
| af7d5a6914 |
| aa58d7a570 |
| 1a60865487 |
| a1eb20c577 |
| eada18a8c2 |
| 66d38f6eaa |
| bc6b677a6a |
| 641e94da44 |
| a86aa73aa1 |
| 893ef35bf1 |
| 1d813f6ebe |
| ce4e6edefc |
| a202bb1fca |
| 74483b9f14 |
| dc42933feb |
| eba1df08fb |
| 8e76e1269d |
| a559b33eda |
| 3872e12d99 |
| c83935a716 |
| fe2501e540 |
| 5c3601b7a8 |
| 9658b24834 |
| a1b6e29288 |
| 9bd4fda920 |
```diff
@@ -46,7 +46,6 @@ from diffusers import (
     DDPMScheduler,
     DiffusionPipeline,
     DPMSolverMultistepScheduler,
-    StableDiffusionPipeline,
     UNet2DConditionModel,
 )
 from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
```
```diff
@@ -831,6 +830,7 @@ def main(args):
 
     unet.set_attn_processor(unet_lora_attn_procs)
     unet_lora_layers = AttnProcsLayers(unet.attn_processors)
+    unet_lora_layers.state_dict()
 
     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
     # So, instead, we monkey-patch the forward calls of its attention-blocks. For this,
```
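The `AttnProcsLayers(unet.attn_processors)` wrapper above exists so the many per-layer LoRA processors can be saved and optimized as a single module; the bare `unet_lora_layers.state_dict()` line added here is just a debug probe that materialises those keys. A minimal, generic sketch of the same wrapping idea (the `ProcsLayers` class below is illustrative, not the diffusers implementation):

```python
import torch.nn as nn


class ProcsLayers(nn.Module):
    """Illustrative stand-in for the role AttnProcsLayers plays here: gather a
    dict of per-layer LoRA processors into one nn.Module so that a single
    state_dict() (and a single optimizer param group) covers all of them."""

    def __init__(self, procs: dict):
        super().__init__()
        # Registered module names may not contain ".", so store the modules in a
        # list and keep the original attention-layer names alongside them.
        self.keys = list(procs.keys())
        self.layers = nn.ModuleList(procs.values())


# Usage: two dummy "processors" become one module with a combined state_dict.
procs = {"down.attn1": nn.Linear(4, 4, bias=False), "up.attn2": nn.Linear(4, 4, bias=False)}
wrapped = ProcsLayers(procs)
print(list(wrapped.state_dict().keys()))  # ['layers.0.weight', 'layers.1.weight']
```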
```diff
@@ -844,7 +844,7 @@ def main(args):
                     hidden_size=module.out_features, cross_attention_dim=None
                 )
         text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
-        temp_pipeline = StableDiffusionPipeline.from_pretrained(
+        temp_pipeline = DiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path, text_encoder=text_encoder
         )
         temp_pipeline._modify_text_encoder(text_lora_attn_procs)
```
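The comments in the earlier hunk explain why the text encoder is handled this way: it lives in 🤗 transformers, so its attention projections are swapped for LoRA-wrapped versions at runtime via a temporary pipeline's `_modify_text_encoder`. A generic sketch of that monkey-patching idea, assuming nothing about the diffusers internals (`LoRALinear`, `patch_linear`, and `ToyAttention` are illustrative names):

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Wrap an existing nn.Linear and add a low-rank residual: y = Wx + scale * up(down(x))."""

    def __init__(self, base: nn.Linear, rank: int = 4, scale: float = 1.0):
        super().__init__()
        self.base = base
        self.down = nn.Linear(base.in_features, rank, bias=False)
        self.up = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.up.weight)  # adapter starts as a no-op
        self.scale = scale

    def forward(self, x):
        return self.base(x) + self.scale * self.up(self.down(x))


def patch_linear(parent: nn.Module, attr: str, rank: int = 4) -> LoRALinear:
    """Swap parent.<attr> (an nn.Linear) for a LoRA-wrapped copy in place."""
    wrapped = LoRALinear(getattr(parent, attr), rank=rank)
    setattr(parent, attr, wrapped)
    return wrapped


# Usage on a toy attention-like module: only the adapter's parameters are new,
# and the zero-initialised adapter leaves the outputs unchanged.
class ToyAttention(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)


x = torch.randn(2, 8)
attn = ToyAttention()
before = attn.q_proj(x)
patch_linear(attn, "q_proj")
after = attn.q_proj(x)
print(torch.allclose(before, after))  # True
```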
```diff
@@ -1271,8 +1271,10 @@ def main(args):
             text_encoder = text_encoder.to(torch.float32)
             text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers)
 
+        print(f"Text encoder layers: {text_encoder_lora_layers}")
         LoraLoaderMixin.save_lora_weights(
             save_directory=args.output_dir,
+            # unet_lora_layers=AttnProcsLayers(unet.attn_processors),
             unet_lora_layers=unet_lora_layers,
             text_encoder_lora_layers=text_encoder_lora_layers,
         )
```
```diff
@@ -1300,13 +1302,19 @@ def main(args):
 
         # load attention processors
         pipeline.load_lora_weights(args.output_dir)
+        trained_state_dict = unet_lora_layers.state_dict()
+        unet_attn_proc_state_dict = AttnProcsLayers(pipeline.unet.attn_processors).state_dict()
+        for k in unet_attn_proc_state_dict:
+            from_unet = unet_attn_proc_state_dict[k]
+            orig = trained_state_dict[k]
+            print(f"Assertion: {torch.allclose(from_unet, orig.to(from_unet.dtype))}")
 
         # run inference
         images = []
         if args.validation_prompt and args.num_validation_images > 0:
             generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
             images = [
-                pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+                pipeline(args.validation_prompt, generator=generator).images[0]
                 for _ in range(args.num_validation_images)
             ]
 
```
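The per-key `torch.allclose` probe added in this hunk checks that the LoRA weights reloaded through `pipeline.load_lora_weights` match the ones that were trained. That round-trip check can be factored into a reusable helper; a minimal sketch, assuming only that both sides expose a plain dict of tensors (`compare_state_dicts` is a hypothetical helper, not a diffusers API):

```python
import torch


def compare_state_dicts(trained: dict, reloaded: dict, atol: float = 1e-5) -> bool:
    """Return True if every trained tensor matches its reloaded counterpart; report differences otherwise."""
    ok = True
    mismatched_keys = set(trained) ^ set(reloaded)
    if mismatched_keys:
        print(f"Key mismatch: {sorted(mismatched_keys)}")
        ok = False
    for k in set(trained) & set(reloaded):
        a, b = trained[k], reloaded[k]
        if not torch.allclose(a.to(b.dtype), b, atol=atol):
            print(f"Value mismatch at {k}: max abs diff {(a.to(b.dtype) - b).abs().max().item():.3e}")
            ok = False
    return ok
```

In the context of the hunk above it would be called roughly as `compare_state_dicts(unet_lora_layers.state_dict(), AttnProcsLayers(pipeline.unet.attn_processors).state_dict())`.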
```diff
@@ -322,6 +322,10 @@ class UNet2DConditionLoadersMixin:
         attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
 
         # set layers
+        print(
+            "All processors are of type: LoRAAttnAddedKVProcessor: ",
+            all(isinstance(attn_processors[k], LoRAAttnAddedKVProcessor) for k in attn_processors),
+        )
         self.set_attn_processor(attn_processors)
 
     def save_attn_procs(
```
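The `print` added above is effectively a type assertion over the loaded processors. The same probe can be made reusable and more informative; a small sketch (hypothetical `check_processor_types` helper, not part of diffusers):

```python
def check_processor_types(attn_processors: dict, expected_cls: type) -> bool:
    """Return True if every loaded attention processor is an instance of expected_cls;
    otherwise report the offending keys and their actual types."""
    bad = {k: type(v).__name__ for k, v in attn_processors.items() if not isinstance(v, expected_cls)}
    if bad:
        print(f"Unexpected processor types: {bad}")
    return not bad
```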
```diff
@@ -902,11 +906,12 @@ class LoraLoaderMixin:
 
         # Load the layers corresponding to text encoder and make necessary adjustments.
         text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
-        logger.info(f"Loading {self.text_encoder_name}.")
         text_encoder_lora_state_dict = {
             k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
         }
+        print(f"text_encoder_lora_state_dict: {text_encoder_lora_state_dict.keys()}")
         if len(text_encoder_lora_state_dict) > 0:
+            logger.info(f"Loading {self.text_encoder_name}.")
             attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict)
             self._modify_text_encoder(attn_procs_text_encoder)
 
```
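This hunk separates the text-encoder LoRA entries from the combined state dict by key prefix before handing them to `_load_text_encoder_attn_procs`. The prefix split itself is a small, generic operation; a sketch (hypothetical `split_lora_state_dict` helper, mirroring the `text_encoder.` / `unet.` naming visible above):

```python
def split_lora_state_dict(state_dict: dict, prefix: str) -> dict:
    """Keep only the entries under `prefix` and strip the prefix,
    so the corresponding sub-module can load them directly."""
    prefix = prefix.rstrip(".") + "."
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}


# Usage with the naming convention visible in the hunk:
full = {"unet.attn1.to_q_lora.down.weight": 0, "text_encoder.layer0.q_lora.up.weight": 1}
print(split_lora_state_dict(full, "text_encoder"))  # {'layer0.q_lora.up.weight': 1}
```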
```diff
@@ -297,6 +297,7 @@ class Attention(nn.Module):
             logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
             self._modules.pop("processor")
 
+        # print(f"Processor type: {type(processor)}")
         self.processor = processor
 
     def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
```
```diff
@@ -770,6 +771,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
         self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
 
     def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
+        # print(f"{self.__class__.__name__} have been called.")
         residual = hidden_states
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
         batch_size, sequence_length, _ = hidden_states.shape
```
```diff
@@ -518,6 +518,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                 f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
             )
+        # print("set_attn_processor() is called.")
+        # for k in processor:
+        #     print(f"{k}: {type(processor[k])}")
 
         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
             if hasattr(module, "set_processor"):
```