Compare commits

...

8 Commits

Author SHA1 Message Date
Sayak Paul
541b89b3e4 Merge branch 'main' into fix-lora-device-test 2024-04-25 17:13:28 +05:30
Dhruv Nair
ff7a10dedc Merge branch 'main' into fix-lora-device-test 2024-04-24 10:58:50 +05:30
sayakpaul
c8b10a4656 empty 2024-04-23 20:40:11 +05:30
Sayak Paul
8058612d73 Merge branch 'main' into fix-lora-device-test 2024-04-23 15:30:26 +05:30
sayakpaul
c55f925f10 quality 2024-04-22 17:23:42 +05:30
sayakpaul
4faf220b68 fix more/ 2024-04-22 17:20:26 +05:30
sayakpaul
3874e8cc6e fix more. 2024-04-22 17:18:43 +05:30
sayakpaul
edb6cd74f7 fix lora device test 2024-04-22 17:17:25 +05:30

View File

@@ -1268,9 +1268,10 @@ class LoraLoaderMixin:
unet_module.lora_A[adapter_name].to(device)
unet_module.lora_B[adapter_name].to(device)
# this is a param, not a module, so device placement is not in-place -> re-assign
unet_module.lora_magnitude_vector[adapter_name] = unet_module.lora_magnitude_vector[
adapter_name
].to(device)
if hasattr(unet_module, "lora_magnitude_vector") and unet_module.lora_magnitude_vector is not None:
unet_module.lora_magnitude_vector[adapter_name] = unet_module.lora_magnitude_vector[
adapter_name
].to(device)
# Handle the text encoder
modules_to_process = []
@@ -1288,9 +1289,13 @@ class LoraLoaderMixin:
text_encoder_module.lora_A[adapter_name].to(device)
text_encoder_module.lora_B[adapter_name].to(device)
# this is a param, not a module, so device placement is not in-place -> re-assign
text_encoder_module.lora_magnitude_vector[
adapter_name
] = text_encoder_module.lora_magnitude_vector[adapter_name].to(device)
if (
hasattr(text_encoder, "lora_magnitude_vector")
and text_encoder_module.lora_magnitude_vector is not None
):
text_encoder_module.lora_magnitude_vector[
adapter_name
] = text_encoder_module.lora_magnitude_vector[adapter_name].to(device)
class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):