Compare commits

...

27 Commits

Author SHA1 Message Date
Sayak Paul fe6c903373 removed print statements. 2023-05-24 17:25:57 +05:30
Sayak Paul 7ba7c65700 more debugging 2023-05-24 17:06:03 +05:30
Sayak Paul af7d5a6914 more debugging 2023-05-24 16:42:03 +05:30
Sayak Paul aa58d7a570 more debugging 2023-05-24 16:31:12 +05:30
Sayak Paul 1a60865487 more debugging 2023-05-24 16:12:44 +05:30
Sayak Paul a1eb20c577 more debugging . 2023-05-24 15:25:33 +05:30
Sayak Paul eada18a8c2 more debugging 2023-05-24 15:01:02 +05:30
Sayak Paul 66d38f6eaa more debugging 2023-05-24 14:48:33 +05:30
Sayak Paul bc6b677a6a wrap within attnprocslayers. 2023-05-24 14:30:58 +05:30
Sayak Paul 641e94da44 fix: state_dict() call. 2023-05-24 11:13:29 +05:30
Sayak Paul a86aa73aa1 more strategic debugging 2023-05-24 10:59:41 +05:30
Sayak Paul 893ef35bf1 Merge branch 'main' into temp/debug-load-lora 2023-05-24 10:47:04 +05:30
Sayak Paul 1d813f6ebe remove unnecessary print statements. 2023-05-24 10:46:38 +05:30
Sayak Paul ce4e6edefc proper casting 2023-05-23 18:17:23 +05:30
Sayak Paul a202bb1fca directly use the attention layers. 2023-05-23 17:59:04 +05:30
Sayak Paul 74483b9f14 disable hooks. 2023-05-23 16:05:10 +05:30
Sayak Paul dc42933feb debugging 2023-05-19 15:16:55 +05:30
Sayak Paul eba1df08fb debugging 2023-05-19 14:24:01 +05:30
Sayak Paul 8e76e1269d debugging statements. 2023-05-19 13:43:06 +05:30
Sayak Paul a559b33eda debugging statements. 2023-05-19 13:32:45 +05:30
Sayak Paul 3872e12d99 debugging statements. 2023-05-19 13:22:59 +05:30
Sayak Paul c83935a716 debugging statement to LoRAAttnAddedKVProcessor. 2023-05-19 13:18:31 +05:30
Sayak Paul fe2501e540 max difference between the params. 2023-05-19 11:42:29 +05:30
Sayak Paul 5c3601b7a8 device placement. 2023-05-19 11:32:43 +05:30
Sayak Paul 9658b24834 allclose() call. 2023-05-19 11:24:52 +05:30
Sayak Paul a1b6e29288 are trained params being saved at all? 2023-05-19 11:13:59 +05:30
Sayak Paul 9bd4fda920 add: debugging statements to lora loader unet. 2023-05-19 08:15:01 +05:30
4 changed files with 22 additions and 4 deletions

View File

@@ -46,7 +46,6 @@ from diffusers import (
DDPMScheduler,
DiffusionPipeline,
DPMSolverMultistepScheduler,
- StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
@@ -831,6 +830,7 @@ def main(args):
unet.set_attn_processor(unet_lora_attn_procs)
unet_lora_layers = AttnProcsLayers(unet.attn_processors)
+ unet_lora_layers.state_dict()
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
# So, instead, we monkey-patch the forward calls of its attention-blocks. For this,
@@ -844,7 +844,7 @@ def main(args):
hidden_size=module.out_features, cross_attention_dim=None
)
text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
- temp_pipeline = StableDiffusionPipeline.from_pretrained(
+ temp_pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, text_encoder=text_encoder
)
temp_pipeline._modify_text_encoder(text_lora_attn_procs)
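Note on the AttnProcsLayers wrapping above (and the bare unet_lora_layers.state_dict() call added earlier): a plain Python dict of attention processors is invisible to PyTorch, so wrapping it in a single nn.Module is what makes the LoRA parameters appear in state_dict() and hence in the saved checkpoint. A minimal sketch of that idea follows; TinyProcsWrapper is an invented stand-in, not the real diffusers AttnProcsLayers, whose state-dict hooks also remap keys back to the original processor names.

import torch.nn as nn

class TinyProcsWrapper(nn.Module):
    # Minimal stand-in for the wrapping idea; NOT diffusers' AttnProcsLayers.
    def __init__(self, procs: dict):
        super().__init__()
        self.keys = sorted(procs)  # stable ordering so indices are reproducible
        self.layers = nn.ModuleList([procs[k] for k in self.keys])

    def remapped_state_dict(self) -> dict:
        # Translate "layers.{i}.<param>" back to "<original processor key>.<param>",
        # roughly what the real wrapper's state_dict hook does.
        out = {}
        for name, tensor in self.state_dict().items():
            idx, suffix = name.split(".", 2)[1:]
            out[f"{self.keys[int(idx)]}.{suffix}"] = tensor
        return out

procs = {"mid_block.attn1": nn.Linear(4, 4), "up_blocks.attn2": nn.Linear(4, 4)}
wrapper = TinyProcsWrapper(procs)
print(list(wrapper.remapped_state_dict().keys()))  # non-empty => the params are tracked and savable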
@@ -1271,8 +1271,10 @@ def main(args):
text_encoder = text_encoder.to(torch.float32)
text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers)
print(f"Text encoder layers: {text_encoder_lora_layers}")
LoraLoaderMixin.save_lora_weights(
save_directory=args.output_dir,
+ # unet_lora_layers=AttnProcsLayers(unet.attn_processors),
unet_lora_layers=unet_lora_layers,
text_encoder_lora_layers=text_encoder_lora_layers,
)
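In the spirit of the "are trained params being saved at all?" commit, a small helper can inspect what is about to be handed to save_lora_weights. describe_lora_layers below is a hypothetical debugging utility written for this sketch, not part of diffusers.

import torch

def describe_lora_layers(module: torch.nn.Module, tag: str) -> None:
    # Report how many tensors/parameters the wrapped LoRA layers expose and
    # whether any of them are still all-zero (i.e. possibly never trained).
    sd = module.state_dict()
    n_params = sum(t.numel() for t in sd.values())
    n_zero = sum(int(torch.count_nonzero(t).item() == 0) for t in sd.values())
    print(f"[{tag}] tensors={len(sd)} params={n_params} all-zero tensors={n_zero}")

# Hypothetical usage right before LoraLoaderMixin.save_lora_weights(...):
# describe_lora_layers(unet_lora_layers, "unet_lora_layers")
# describe_lora_layers(text_encoder_lora_layers, "text_encoder_lora_layers")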
@@ -1300,13 +1302,19 @@ def main(args):
# load attention processors
pipeline.load_lora_weights(args.output_dir)
+ trained_state_dict = unet_lora_layers.state_dict()
+ unet_attn_proc_state_dict = AttnProcsLayers(pipeline.unet.attn_processors).state_dict()
+ for k in unet_attn_proc_state_dict:
+ from_unet = unet_attn_proc_state_dict[k]
+ orig = trained_state_dict[k]
+ print(f"Assertion: {torch.allclose(from_unet, orig.to(from_unet.dtype))}")
# run inference
images = []
if args.validation_prompt and args.num_validation_images > 0:
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
images = [
- pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+ pipeline(args.validation_prompt, generator=generator).images[0]
for _ in range(args.num_validation_images)
]
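The comparison loop above only prints a per-key True/False from torch.allclose. Following the "max difference between the params" commit, the sketch below also reports the largest absolute deviation per key, which helps distinguish a dtype/casting problem from weights that were never saved or loaded; compare_state_dicts is a hypothetical helper, not library code.

import torch

def compare_state_dicts(trained: dict, loaded: dict, atol: float = 1e-4) -> None:
    assert trained.keys() == loaded.keys(), "state dict key sets differ"
    for k in sorted(trained):
        a = trained[k].detach().float().cpu()
        b = loaded[k].detach().float().cpu()
        max_diff = (a - b).abs().max().item()
        print(f"{k}: allclose={torch.allclose(a, b, atol=atol)} max_abs_diff={max_diff:.3e}")

# Hypothetical usage, mirroring the loop above:
# compare_state_dicts(unet_lora_layers.state_dict(),
#                     AttnProcsLayers(pipeline.unet.attn_processors).state_dict())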

View File

@@ -322,6 +322,10 @@ class UNet2DConditionLoadersMixin:
attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
# set layers
+ print(
+ "All processors are of type: LoRAAttnAddedKVProcessor: ",
+ all(isinstance(attn_processors[k], LoRAAttnAddedKVProcessor) for k in attn_processors),
+ )
self.set_attn_processor(attn_processors)
def save_attn_procs(
@@ -902,11 +906,12 @@ class LoraLoaderMixin:
# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
logger.info(f"Loading {self.text_encoder_name}.")
text_encoder_lora_state_dict = {
k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}
print(f"text_encoder_lora_state_dict: {text_encoder_lora_state_dict.keys()}")
if len(text_encoder_lora_state_dict) > 0:
logger.info(f"Loading {self.text_encoder_name}.")
attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict)
self._modify_text_encoder(attn_procs_text_encoder)
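The debugging print above checks whether the prefix-based split of the combined LoRA state dict leaves anything for the text encoder. A minimal sketch of that partitioning idea, assuming the usual "unet." / "text_encoder." prefixes; split_lora_state_dict is a hypothetical helper, not the library function.

import torch

def split_lora_state_dict(state_dict: dict, unet_name: str = "unet", text_encoder_name: str = "text_encoder"):
    # Partition keys by prefix and strip the prefix, so each half can be
    # handed to the corresponding loader.
    unet_sd = {k[len(unet_name) + 1:]: v for k, v in state_dict.items() if k.startswith(f"{unet_name}.")}
    te_sd = {k[len(text_encoder_name) + 1:]: v for k, v in state_dict.items() if k.startswith(f"{text_encoder_name}.")}
    return unet_sd, te_sd

demo = {
    "unet.layers.0.to_q_lora.down.weight": torch.zeros(4, 4),
    "text_encoder.layers.0.to_q_lora.up.weight": torch.zeros(4, 4),
}
unet_sd, te_sd = split_lora_state_dict(demo)
print(sorted(unet_sd), sorted(te_sd))  # an empty second list would explain the symptom being debugged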

View File

@@ -297,6 +297,7 @@ class Attention(nn.Module):
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
self._modules.pop("processor")
# print(f"Processor type: {type(processor)}")
self.processor = processor
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
@@ -770,6 +771,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
# print(f"{self.__class__.__name__} have been called.")
residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
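For orientation, the to_*_lora modules above are low-rank adapters applied alongside the frozen attention projections. A generic sketch of that LoRA-linear pattern follows; TinyLoRALinear is an illustrative name, not the diffusers LoRALinearLayer, and the rank/init choices are the common defaults rather than anything taken from this diff.

import torch
import torch.nn as nn

class TinyLoRALinear(nn.Module):
    # Rank-r update that runs alongside a frozen base projection.
    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        # Common LoRA init: `up` starts at zero so the initial update is a no-op.
        nn.init.normal_(self.down.weight, std=1.0 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.up(self.down(hidden_states))

base = nn.Linear(8, 8)           # stands in for a frozen projection such as attn.to_q
lora = TinyLoRALinear(8, 8, rank=4)
x = torch.randn(2, 8)
scale = 1.0
out = base(x) + scale * lora(x)  # base path plus scaled low-rank path
print(out.shape)                 # torch.Size([2, 8])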

View File

@@ -518,6 +518,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
# print("set_attn_processor() is called.")
# for k in processor:
# print(f"{k}: {type(processor[k])}")
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
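The hunk above ends at the recursive helper that distributes processors across the model. A generic sketch of that traversal pattern follows; the toy modules and key names are invented for the example, and only the walk-the-tree-and-call-set_processor idea mirrors fn_recursive_attn_processor.

import torch.nn as nn

class ToyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.processor = None

    def set_processor(self, processor):
        self.processor = processor

def recursively_set_processors(root: nn.Module, processors: dict, prefix: str = "") -> None:
    # Walk the module tree; whenever a submodule exposes set_processor(),
    # hand it the entry keyed by "<qualified name>.processor".
    for name, child in root.named_children():
        qualified = f"{prefix}{name}"
        if hasattr(child, "set_processor"):
            child.set_processor(processors[f"{qualified}.processor"])
        recursively_set_processors(child, processors, prefix=f"{qualified}.")

model = nn.Sequential(ToyAttention(), nn.Sequential(ToyAttention()))
procs = {"0.processor": "proc_a", "1.0.processor": "proc_b"}
recursively_set_processors(model, procs)
print(model[0].processor, model[1][0].processor)  # proc_a proc_b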