mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 20:44:33 +08:00
Compare commits
27 Commits
v0.29.1
...
temp/debug
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fe6c903373 | ||
|
|
7ba7c65700 | ||
|
|
af7d5a6914 | ||
|
|
aa58d7a570 | ||
|
|
1a60865487 | ||
|
|
a1eb20c577 | ||
|
|
eada18a8c2 | ||
|
|
66d38f6eaa | ||
|
|
bc6b677a6a | ||
|
|
641e94da44 | ||
|
|
a86aa73aa1 | ||
|
|
893ef35bf1 | ||
|
|
1d813f6ebe | ||
|
|
ce4e6edefc | ||
|
|
a202bb1fca | ||
|
|
74483b9f14 | ||
|
|
dc42933feb | ||
|
|
eba1df08fb | ||
|
|
8e76e1269d | ||
|
|
a559b33eda | ||
|
|
3872e12d99 | ||
|
|
c83935a716 | ||
|
|
fe2501e540 | ||
|
|
5c3601b7a8 | ||
|
|
9658b24834 | ||
|
|
a1b6e29288 | ||
|
|
9bd4fda920 |
@@ -46,7 +46,6 @@ from diffusers import (
|
||||
DDPMScheduler,
|
||||
DiffusionPipeline,
|
||||
DPMSolverMultistepScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
|
||||
@@ -831,6 +830,7 @@ def main(args):
|
||||
|
||||
unet.set_attn_processor(unet_lora_attn_procs)
|
||||
unet_lora_layers = AttnProcsLayers(unet.attn_processors)
|
||||
unet_lora_layers.state_dict()
|
||||
|
||||
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
|
||||
# So, instead, we monkey-patch the forward calls of its attention-blocks. For this,
|
||||
@@ -844,7 +844,7 @@ def main(args):
|
||||
hidden_size=module.out_features, cross_attention_dim=None
|
||||
)
|
||||
text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
|
||||
temp_pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
temp_pipeline = DiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path, text_encoder=text_encoder
|
||||
)
|
||||
temp_pipeline._modify_text_encoder(text_lora_attn_procs)
|
||||
@@ -1271,8 +1271,10 @@ def main(args):
|
||||
text_encoder = text_encoder.to(torch.float32)
|
||||
text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers)
|
||||
|
||||
print(f"Text encoder layers: {text_encoder_lora_layers}")
|
||||
LoraLoaderMixin.save_lora_weights(
|
||||
save_directory=args.output_dir,
|
||||
# unet_lora_layers=AttnProcsLayers(unet.attn_processors),
|
||||
unet_lora_layers=unet_lora_layers,
|
||||
text_encoder_lora_layers=text_encoder_lora_layers,
|
||||
)
|
||||
@@ -1300,13 +1302,19 @@ def main(args):
|
||||
|
||||
# load attention processors
|
||||
pipeline.load_lora_weights(args.output_dir)
|
||||
trained_state_dict = unet_lora_layers.state_dict()
|
||||
unet_attn_proc_state_dict = AttnProcsLayers(pipeline.unet.attn_processors).state_dict()
|
||||
for k in unet_attn_proc_state_dict:
|
||||
from_unet = unet_attn_proc_state_dict[k]
|
||||
orig = trained_state_dict[k]
|
||||
print(f"Assertion: {torch.allclose(from_unet, orig.to(from_unet.dtype))}")
|
||||
|
||||
# run inference
|
||||
images = []
|
||||
if args.validation_prompt and args.num_validation_images > 0:
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
images = [
|
||||
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
|
||||
pipeline(args.validation_prompt, generator=generator).images[0]
|
||||
for _ in range(args.num_validation_images)
|
||||
]
|
||||
|
||||
|
||||
@@ -322,6 +322,10 @@ class UNet2DConditionLoadersMixin:
|
||||
attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
|
||||
|
||||
# set layers
|
||||
print(
|
||||
"All processors are of type: LoRAAttnAddedKVProcessor: ",
|
||||
all(isinstance(attn_processors[k], LoRAAttnAddedKVProcessor) for k in attn_processors),
|
||||
)
|
||||
self.set_attn_processor(attn_processors)
|
||||
|
||||
def save_attn_procs(
|
||||
@@ -902,11 +906,12 @@ class LoraLoaderMixin:
|
||||
|
||||
# Load the layers corresponding to text encoder and make necessary adjustments.
|
||||
text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
|
||||
logger.info(f"Loading {self.text_encoder_name}.")
|
||||
text_encoder_lora_state_dict = {
|
||||
k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
|
||||
}
|
||||
print(f"text_encoder_lora_state_dict: {text_encoder_lora_state_dict.keys()}")
|
||||
if len(text_encoder_lora_state_dict) > 0:
|
||||
logger.info(f"Loading {self.text_encoder_name}.")
|
||||
attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict)
|
||||
self._modify_text_encoder(attn_procs_text_encoder)
|
||||
|
||||
|
||||
@@ -297,6 +297,7 @@ class Attention(nn.Module):
|
||||
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
|
||||
self._modules.pop("processor")
|
||||
|
||||
# print(f"Processor type: {type(processor)}")
|
||||
self.processor = processor
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
|
||||
@@ -770,6 +771,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
|
||||
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
|
||||
|
||||
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
|
||||
# print(f"{self.__class__.__name__} have been called.")
|
||||
residual = hidden_states
|
||||
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
|
||||
@@ -518,6 +518,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
# print("set_attn_processor() is called.")
|
||||
# for k in processor:
|
||||
# print(f"{k}: {type(processor[k])}")
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
|
||||
Reference in New Issue
Block a user