remove unnecessary print statements.

This commit is contained in:
Sayak Paul
2023-05-24 10:46:38 +05:30
parent ce4e6edefc
commit 1d813f6ebe
4 changed files with 67 additions and 67 deletions

View File

@@ -847,65 +847,65 @@ def main(args):
text_encoder = temp_pipeline.text_encoder
del temp_pipeline
# # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
# def save_model_hook(models, weights, output_dir):
# # there are only two options here. Either are just the unet attn processor layers
# # or there are the unet and text encoder atten layers
# unet_lora_layers_to_save = None
# text_encoder_lora_layers_to_save = None
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
# there are only two options here. Either are just the unet attn processor layers
# or there are the unet and text encoder atten layers
unet_lora_layers_to_save = None
text_encoder_lora_layers_to_save = None
# if args.train_text_encoder:
# text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys()
# unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys()
if args.train_text_encoder:
text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys()
unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys()
# for model in models:
# state_dict = model.state_dict()
for model in models:
state_dict = model.state_dict()
# if (
# text_encoder_lora_layers is not None
# and text_encoder_keys is not None
# and state_dict.keys() == text_encoder_keys
# ):
# # text encoder
# text_encoder_lora_layers_to_save = state_dict
# elif state_dict.keys() == unet_keys:
# # unet
# unet_lora_layers_to_save = state_dict
if (
text_encoder_lora_layers is not None
and text_encoder_keys is not None
and state_dict.keys() == text_encoder_keys
):
# text encoder
text_encoder_lora_layers_to_save = state_dict
elif state_dict.keys() == unet_keys:
# unet
unet_lora_layers_to_save = state_dict
# # make sure to pop weight so that corresponding model is not saved again
# weights.pop()
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
# LoraLoaderMixin.save_lora_weights(
# output_dir,
# unet_lora_layers=unet_lora_layers_to_save,
# text_encoder_lora_layers=text_encoder_lora_layers_to_save,
# )
LoraLoaderMixin.save_lora_weights(
output_dir,
unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_lora_layers_to_save,
)
# def load_model_hook(models, input_dir):
# # Note we DON'T pass the unet and text encoder here an purpose
# # so that the we don't accidentally override the LoRA layers of
# # unet_lora_layers and text_encoder_lora_layers which are stored in `models`
# # with new torch.nn.Modules / weights. We simply use the pipeline class as
# # an easy way to load the lora checkpoints
# temp_pipeline = DiffusionPipeline.from_pretrained(
# args.pretrained_model_name_or_path,
# revision=args.revision,
# torch_dtype=weight_dtype,
# )
# temp_pipeline.load_lora_weights(input_dir)
def load_model_hook(models, input_dir):
# Note we DON'T pass the unet and text encoder here an purpose
# so that the we don't accidentally override the LoRA layers of
# unet_lora_layers and text_encoder_lora_layers which are stored in `models`
# with new torch.nn.Modules / weights. We simply use the pipeline class as
# an easy way to load the lora checkpoints
temp_pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
revision=args.revision,
torch_dtype=weight_dtype,
)
temp_pipeline.load_lora_weights(input_dir)
# # load lora weights into models
# models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict())
# if len(models) > 1:
# models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict())
# load lora weights into models
models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict())
if len(models) > 1:
models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict())
# # delete temporary pipeline and pop models
# del temp_pipeline
# for _ in range(len(models)):
# models.pop()
# delete temporary pipeline and pop models
del temp_pipeline
for _ in range(len(models)):
models.pop()
# accelerator.register_save_state_pre_hook(save_model_hook)
# accelerator.register_load_state_pre_hook(load_model_hook)
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
@@ -1275,19 +1275,19 @@ def main(args):
text_encoder = text_encoder.to(torch.float32)
text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers)
# LoraLoaderMixin.save_lora_weights(
# save_directory=args.output_dir,
# unet_lora_layers=unet_lora_layers,
# text_encoder_lora_layers=text_encoder_lora_layers,
# )
unet.save_attn_procs(save_directory=args.output_dir,)
LoraLoaderMixin.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_layers,
text_encoder_lora_layers=text_encoder_lora_layers,
)
# unet.save_attn_procs(save_directory=args.output_dir,)
# Final inference
# Load previous pipeline
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
)
pipeline.unet = unet.to(weight_dtype)
# pipeline.unet = unet.to(weight_dtype)
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
scheduler_args = {}
@@ -1305,14 +1305,14 @@ def main(args):
pipeline = pipeline.to(accelerator.device)
# load attention processors
# pipeline.load_lora_weights(args.output_dir)
pipeline.load_lora_weights(args.output_dir)
# run inference
images = []
if args.validation_prompt and args.num_validation_images > 0:
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
images = [
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
pipeline(args.validation_prompt, generator=generator).images[0]
for _ in range(args.num_validation_images)
]

View File

@@ -281,7 +281,7 @@ class UNet2DConditionLoadersMixin:
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
attn_processor_class = LoRAAttnProcessor
print(f"attn_processor_class: {attn_processor_class}")
# print(f"attn_processor_class: {attn_processor_class}")
attn_processors[key] = attn_processor_class(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
)
@@ -893,7 +893,7 @@ class LoraLoaderMixin:
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
print("Inside the lora loader.")
# print("Inside the lora loader.")
if all(key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in keys):
# Load the layers corresponding to UNet.
unet_keys = [k for k in keys if k.startswith(self.unet_name)]
@@ -902,7 +902,7 @@ class LoraLoaderMixin:
k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
}
self.unet.load_attn_procs(unet_lora_state_dict)
print("UNet lora loaded.")
# print("UNet lora loaded.")
# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]

View File

@@ -290,7 +290,7 @@ class Attention(nn.Module):
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
self._modules.pop("processor")
print(f"Processor type: {type(processor)}")
# print(f"Processor type: {type(processor)}")
self.processor = processor
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
@@ -759,7 +759,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
print(f"{self.__class__.__name__} have been called.")
# print(f"{self.__class__.__name__} have been called.")
residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape

View File

@@ -518,9 +518,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
print("set_attn_processor() is called.")
for k in processor:
print(f"{k}: {type(processor[k])}")
# print("set_attn_processor() is called.")
# for k in processor:
# print(f"{k}: {type(processor[k])}")
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):