Mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-16 01:14:47 +08:00
Commit: remove unnecessary print statements.
@@ -847,65 +847,65 @@ def main(args):
         text_encoder = temp_pipeline.text_encoder
         del temp_pipeline

-    # # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
-    # def save_model_hook(models, weights, output_dir):
-    #     # there are only two options here. Either they are just the unet attn processor layers
-    #     # or there are the unet and text encoder attn layers
-    #     unet_lora_layers_to_save = None
-    #     text_encoder_lora_layers_to_save = None
+    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+    def save_model_hook(models, weights, output_dir):
+        # there are only two options here. Either they are just the unet attn processor layers
+        # or there are the unet and text encoder attn layers
+        unet_lora_layers_to_save = None
+        text_encoder_lora_layers_to_save = None

-    #     if args.train_text_encoder:
-    #         text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys()
-    #         unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys()
+        if args.train_text_encoder:
+            text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys()
+            unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys()

-    #     for model in models:
-    #         state_dict = model.state_dict()
+        for model in models:
+            state_dict = model.state_dict()

-    #         if (
-    #             text_encoder_lora_layers is not None
-    #             and text_encoder_keys is not None
-    #             and state_dict.keys() == text_encoder_keys
-    #         ):
-    #             # text encoder
-    #             text_encoder_lora_layers_to_save = state_dict
-    #         elif state_dict.keys() == unet_keys:
-    #             # unet
-    #             unet_lora_layers_to_save = state_dict
+            if (
+                text_encoder_lora_layers is not None
+                and text_encoder_keys is not None
+                and state_dict.keys() == text_encoder_keys
+            ):
+                # text encoder
+                text_encoder_lora_layers_to_save = state_dict
+            elif state_dict.keys() == unet_keys:
+                # unet
+                unet_lora_layers_to_save = state_dict

-    #         # make sure to pop weight so that corresponding model is not saved again
-    #         weights.pop()
+            # make sure to pop weight so that corresponding model is not saved again
+            weights.pop()

-    #     LoraLoaderMixin.save_lora_weights(
-    #         output_dir,
-    #         unet_lora_layers=unet_lora_layers_to_save,
-    #         text_encoder_lora_layers=text_encoder_lora_layers_to_save,
-    #     )
+        LoraLoaderMixin.save_lora_weights(
+            output_dir,
+            unet_lora_layers=unet_lora_layers_to_save,
+            text_encoder_lora_layers=text_encoder_lora_layers_to_save,
+        )

-    # def load_model_hook(models, input_dir):
-    #     # Note we DON'T pass the unet and text encoder here on purpose
-    #     # so that we don't accidentally override the LoRA layers of
-    #     # unet_lora_layers and text_encoder_lora_layers which are stored in `models`
-    #     # with new torch.nn.Modules / weights. We simply use the pipeline class as
-    #     # an easy way to load the lora checkpoints
-    #     temp_pipeline = DiffusionPipeline.from_pretrained(
-    #         args.pretrained_model_name_or_path,
-    #         revision=args.revision,
-    #         torch_dtype=weight_dtype,
-    #     )
-    #     temp_pipeline.load_lora_weights(input_dir)
+    def load_model_hook(models, input_dir):
+        # Note we DON'T pass the unet and text encoder here on purpose
+        # so that we don't accidentally override the LoRA layers of
+        # unet_lora_layers and text_encoder_lora_layers which are stored in `models`
+        # with new torch.nn.Modules / weights. We simply use the pipeline class as
+        # an easy way to load the lora checkpoints
+        temp_pipeline = DiffusionPipeline.from_pretrained(
+            args.pretrained_model_name_or_path,
+            revision=args.revision,
+            torch_dtype=weight_dtype,
+        )
+        temp_pipeline.load_lora_weights(input_dir)

-    #     # load lora weights into models
-    #     models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict())
-    #     if len(models) > 1:
-    #         models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict())
+        # load lora weights into models
+        models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict())
+        if len(models) > 1:
+            models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict())

-    #     # delete temporary pipeline and pop models
-    #     del temp_pipeline
-    #     for _ in range(len(models)):
-    #         models.pop()
+        # delete temporary pipeline and pop models
+        del temp_pipeline
+        for _ in range(len(models)):
+            models.pop()

-    # accelerator.register_save_state_pre_hook(save_model_hook)
-    # accelerator.register_load_state_pre_hook(load_model_hook)
+    accelerator.register_save_state_pre_hook(save_model_hook)
+    accelerator.register_load_state_pre_hook(load_model_hook)

     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
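For context on the hooks this hunk enables: Accelerate calls the registered save hook with the prepared models and their state dicts right before accelerator.save_state(...) writes a checkpoint, and the load hook right before accelerator.load_state(...) restores one. Below is a minimal, self-contained sketch of that mechanism, assuming a reasonably recent accelerate release; the toy Linear model and the "checkpoint-0" directory are illustrative only, not part of this commit.

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = accelerator.prepare(torch.nn.Linear(4, 4))


def save_model_hook(models, weights, output_dir):
    # `models` are the prepared models, `weights` their state dicts; popping a
    # weight tells Accelerate not to serialize that model itself.
    for _ in models:
        weights.pop()


def load_model_hook(models, input_dir):
    # Popping a model tells Accelerate not to restore a checkpoint into it.
    for _ in range(len(models)):
        models.pop()


accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

accelerator.save_state("checkpoint-0")  # save_model_hook runs first
accelerator.load_state("checkpoint-0")  # load_model_hook runs first

Popping entries from weights / models is how a hook opts a model out of Accelerate's default (de)serialization, which is exactly what the training-script hooks above rely on.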
@@ -1275,19 +1275,19 @@ def main(args):
         text_encoder = text_encoder.to(torch.float32)
         text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers)

-        # LoraLoaderMixin.save_lora_weights(
-        #     save_directory=args.output_dir,
-        #     unet_lora_layers=unet_lora_layers,
-        #     text_encoder_lora_layers=text_encoder_lora_layers,
-        # )
-        unet.save_attn_procs(save_directory=args.output_dir)
+        LoraLoaderMixin.save_lora_weights(
+            save_directory=args.output_dir,
+            unet_lora_layers=unet_lora_layers,
+            text_encoder_lora_layers=text_encoder_lora_layers,
+        )
+        # unet.save_attn_procs(save_directory=args.output_dir)

         # Final inference
         # Load previous pipeline
         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
         )
-        pipeline.unet = unet.to(weight_dtype)
+        # pipeline.unet = unet.to(weight_dtype)

         # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
         scheduler_args = {}
@@ -1305,14 +1305,14 @@ def main(args):
         pipeline = pipeline.to(accelerator.device)

         # load attention processors
-        # pipeline.load_lora_weights(args.output_dir)
+        pipeline.load_lora_weights(args.output_dir)

         # run inference
         images = []
         if args.validation_prompt and args.num_validation_images > 0:
             generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
             images = [
-                pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+                pipeline(args.validation_prompt, generator=generator).images[0]
                 for _ in range(args.num_validation_images)
             ]

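To show how the LoRA weights written by LoraLoaderMixin.save_lora_weights(...) in the hunks above would be consumed later, here is a rough standalone inference sketch. The model id, LoRA directory, prompt, and the CUDA device are placeholders, not anything prescribed by this commit.

import torch
from diffusers import DiffusionPipeline

# Placeholder model id, LoRA directory, and prompt.
pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipeline = pipeline.to("cuda")

# Loads the unet (and, if present, text encoder) LoRA layers written by
# LoraLoaderMixin.save_lora_weights(save_directory=..., ...).
pipeline.load_lora_weights("path/to/output_dir")

image = pipeline("a photo of sks dog", num_inference_steps=25).images[0]
image.save("lora_sample.png")

load_lora_weights picks up both the unet.- and text_encoder.-prefixed entries of the saved state dict, which is what the loader changes further down in this diff deal with.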
@@ -281,7 +281,7 @@ class UNet2DConditionLoadersMixin:
                 cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
                 attn_processor_class = LoRAAttnProcessor

-                print(f"attn_processor_class: {attn_processor_class}")
+                # print(f"attn_processor_class: {attn_processor_class}")
                 attn_processors[key] = attn_processor_class(
                     hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
                 )
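The shape arithmetic in this hunk can be checked with plain tensors: a LoRA down-projection weight has shape (rank, cross_attention_dim) and the matching up-projection weight has shape (hidden_size, rank), so all three numbers can be recovered from the state dict alone. A small sketch with made-up sizes:

import torch

# Made-up sizes for illustration.
hidden_size, cross_attention_dim, rank = 320, 768, 4

value_dict = {
    "to_k_lora.down.weight": torch.zeros(rank, cross_attention_dim),
    "to_k_lora.up.weight": torch.zeros(hidden_size, rank),
}

inferred_rank = value_dict["to_k_lora.down.weight"].shape[0]
inferred_cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
inferred_hidden_size = value_dict["to_k_lora.up.weight"].shape[0]

assert (inferred_rank, inferred_cross_attention_dim, inferred_hidden_size) == (
    rank,
    cross_attention_dim,
    hidden_size,
)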
@@ -893,7 +893,7 @@ class LoraLoaderMixin:
         # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
         # their prefixes.
         keys = list(state_dict.keys())
-        print("Inside the lora loader.")
+        # print("Inside the lora loader.")
         if all(key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in keys):
             # Load the layers corresponding to UNet.
             unet_keys = [k for k in keys if k.startswith(self.unet_name)]
@@ -902,7 +902,7 @@ class LoraLoaderMixin:
                 k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
             }
             self.unet.load_attn_procs(unet_lora_state_dict)
-            print("UNet lora loaded.")
+            # print("UNet lora loaded.")

             # Load the layers corresponding to text encoder and make necessary adjustments.
             text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
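These two loader hunks revolve around splitting one flat LoRA state dict by prefix and stripping the prefix before handing each part to the right component. A minimal sketch with dummy keys and placeholder tensors (the key names are illustrative, not taken from a real checkpoint):

import torch

unet_name, text_encoder_name = "unet", "text_encoder"

# Dummy combined LoRA state dict, in the spirit of save_lora_weights output.
state_dict = {
    "unet.down_blocks.0.attn1.processor.to_k_lora.down.weight": torch.zeros(4, 320),
    "text_encoder.text_model.encoder.layers.0.to_k_lora.down.weight": torch.zeros(4, 768),
}

keys = list(state_dict.keys())
unet_keys = [k for k in keys if k.startswith(unet_name)]
text_encoder_keys = [k for k in keys if k.startswith(text_encoder_name)]

# Strip the "unet." prefix so the UNet sees keys it recognizes.
unet_lora_state_dict = {
    k.replace(f"{unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
}

assert list(unet_lora_state_dict) == ["down_blocks.0.attn1.processor.to_k_lora.down.weight"]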
@@ -290,7 +290,7 @@ class Attention(nn.Module):
             logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
             self._modules.pop("processor")

-        print(f"Processor type: {type(processor)}")
+        # print(f"Processor type: {type(processor)}")
         self.processor = processor

     def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
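The surrounding set_processor logic (including the self._modules.pop("processor") above the removed print) exists because an nn.Module processor gets registered as a submodule on assignment, while a plain callable does not, and PyTorch refuses to overwrite a registered submodule with a non-module value. A stripped-down sketch of that behaviour, not the diffusers class itself:

import torch


class TinyAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.processor = None

    def set_processor(self, processor):
        # If the old processor was an nn.Module and the new one is not, drop the
        # stale submodule first; otherwise the assignment below would raise
        # "cannot assign ... as child module 'processor'".
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            self._modules.pop("processor")
        self.processor = processor


attn = TinyAttention()
attn.set_processor(torch.nn.Linear(4, 4))         # registered under attn._modules
attn.set_processor(lambda *args, **kwargs: None)  # back to a plain attribute
print(list(attn._modules))                        # no lingering "processor" entry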
@@ -759,7 +759,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
         self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

     def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
-        print(f"{self.__class__.__name__} has been called.")
+        # print(f"{self.__class__.__name__} has been called.")
         residual = hidden_states
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
         batch_size, sequence_length, _ = hidden_states.shape
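For reference, the to_*_lora layers this processor calls each add a low-rank correction, scaled by scale, on top of the frozen projection. A self-contained sketch of that arithmetic with made-up layer sizes (mirroring, not reproducing, diffusers' LoRALinearLayer):

import torch
import torch.nn as nn


class LoRALinearLayer(nn.Module):
    # Same idea as diffusers' LoRALinearLayer: a rank-r bottleneck.
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)  # starts out as a no-op

    def forward(self, hidden_states):
        return self.up(self.down(hidden_states))


hidden_size, rank, scale = 64, 4, 1.0
frozen_to_q = nn.Linear(hidden_size, hidden_size)  # stands in for the pretrained, frozen projection
to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

hidden_states = torch.randn(2, 10, hidden_size)  # (batch, seq_len, hidden)
# The pattern the processor applies to each projection:
query = frozen_to_q(hidden_states) + scale * to_q_lora(hidden_states)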
@@ -518,9 +518,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                 f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
             )
-        print("set_attn_processor() is called.")
-        for k in processor:
-            print(f"{k}: {type(processor[k])}")
+        # print("set_attn_processor() is called.")
+        # for k in processor:
+        #     print(f"{k}: {type(processor[k])}")


         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
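The fn_recursive_attn_processor helper that follows the removed prints walks the module tree and hands a processor to every attention block. Below is a generic sketch of that recursion over plain torch.nn modules; the real helper also accepts a dict of processors keyed by module path (which is what the deleted debug loop iterated over), a case this sketch omits.

import torch


def set_attn_processor_recursively(module: torch.nn.Module, processor) -> None:
    # Same shape as fn_recursive_attn_processor: if this module can take a
    # processor, hand it one, then recurse into all of its children.
    if hasattr(module, "set_processor"):
        module.set_processor(processor)
    for child in module.children():
        set_attn_processor_recursively(child, processor)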