Compare commits

...

1 Commits

Author SHA1 Message Date
Dhruv Nair
5aed8c633b update 2023-12-19 04:49:34 +00:00
2 changed files with 67 additions and 23 deletions

View File

@@ -23,9 +23,7 @@ from torch.nn.modules.normalization import GroupNorm
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging
from .attention_processor import (
AttentionProcessor,
)
from .attention_processor import USE_PEFT_BACKEND, AttentionProcessor
from .autoencoders import AutoencoderKL
from .lora import LoRACompatibleConv
from .modeling_utils import ModelMixin
@@ -817,11 +815,23 @@ def increase_block_input_in_encoder_resnet(unet: UNet2DConditionModel, block_no,
norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args}
norm_kwargs["num_channels"] += by # surgery done here
# conv1
conv1_args = (
"in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split(" ")
)
conv1_args = [
"in_channels",
"out_channels",
"kernel_size",
"stride",
"padding",
"dilation",
"groups",
"bias",
"padding_mode",
]
if not USE_PEFT_BACKEND:
conv1_args.append("lora_layer")
for a in conv1_args:
assert hasattr(old_conv1, a)
conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args}
conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor.
conv1_kwargs["in_channels"] += by # surgery done here
@@ -839,25 +849,42 @@ def increase_block_input_in_encoder_resnet(unet: UNet2DConditionModel, block_no,
}
# swap old with new modules
unet.down_blocks[block_no].resnets[resnet_idx].norm1 = GroupNorm(**norm_kwargs)
unet.down_blocks[block_no].resnets[resnet_idx].conv1 = LoRACompatibleConv(**conv1_kwargs)
unet.down_blocks[block_no].resnets[resnet_idx].conv_shortcut = LoRACompatibleConv(**conv_shortcut_args_kwargs)
unet.down_blocks[block_no].resnets[resnet_idx].conv1 = (
nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs)
)
unet.down_blocks[block_no].resnets[resnet_idx].conv_shortcut = (
nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs)
)
unet.down_blocks[block_no].resnets[resnet_idx].in_channels += by # surgery done here
def increase_block_input_in_encoder_downsampler(unet: UNet2DConditionModel, block_no, by):
"""Increase channels sizes to allow for additional concatted information from base model"""
old_down = unet.down_blocks[block_no].downsamplers[0].conv
# conv1
args = "in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split(
" "
)
args = [
"in_channels",
"out_channels",
"kernel_size",
"stride",
"padding",
"dilation",
"groups",
"bias",
"padding_mode",
]
if not USE_PEFT_BACKEND:
args.append("lora_layer")
for a in args:
assert hasattr(old_down, a)
kwargs = {a: getattr(old_down, a) for a in args}
kwargs["bias"] = "bias" in kwargs # as param, bias is a boolean, but as attr, it's a tensor.
kwargs["in_channels"] += by # surgery done here
# swap old with new modules
unet.down_blocks[block_no].downsamplers[0].conv = LoRACompatibleConv(**kwargs)
unet.down_blocks[block_no].downsamplers[0].conv = (
nn.Conv2d(**kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**kwargs)
)
unet.down_blocks[block_no].downsamplers[0].channels += by # surgery done here
@@ -871,12 +898,20 @@ def increase_block_input_in_mid_resnet(unet: UNet2DConditionModel, by):
assert hasattr(old_norm1, a)
norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args}
norm_kwargs["num_channels"] += by # surgery done here
# conv1
conv1_args = (
"in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split(" ")
)
for a in conv1_args:
assert hasattr(old_conv1, a)
conv1_args = [
"in_channels",
"out_channels",
"kernel_size",
"stride",
"padding",
"dilation",
"groups",
"bias",
"padding_mode",
]
if not USE_PEFT_BACKEND:
conv1_args.append("lora_layer")
conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args}
conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor.
conv1_kwargs["in_channels"] += by # surgery done here
@@ -894,8 +929,12 @@ def increase_block_input_in_mid_resnet(unet: UNet2DConditionModel, by):
}
# swap old with new modules
unet.mid_block.resnets[0].norm1 = GroupNorm(**norm_kwargs)
unet.mid_block.resnets[0].conv1 = LoRACompatibleConv(**conv1_kwargs)
unet.mid_block.resnets[0].conv_shortcut = LoRACompatibleConv(**conv_shortcut_args_kwargs)
unet.mid_block.resnets[0].conv1 = (
nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs)
)
unet.mid_block.resnets[0].conv_shortcut = (
nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs)
)
unet.mid_block.resnets[0].in_channels += by # surgery done here

View File

@@ -34,6 +34,7 @@ from diffusers.utils.testing_utils import (
enable_full_determinism,
load_image,
load_numpy,
numpy_cosine_similarity_distance,
require_python39_or_higher,
require_torch_2,
require_torch_gpu,
@@ -273,7 +274,9 @@ class ControlNetXSPipelineSlowTests(unittest.TestCase):
original_image = image[-3:, -3:, -1].flatten()
expected_image = np.array([0.1274, 0.1401, 0.147, 0.1185, 0.1555, 0.1492, 0.1565, 0.1474, 0.1701])
assert np.allclose(original_image, expected_image, atol=1e-04)
max_diff = numpy_cosine_similarity_distance(original_image, expected_image)
assert max_diff < 1e-4
def test_depth(self):
controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SD2.1-depth")
@@ -298,7 +301,9 @@ class ControlNetXSPipelineSlowTests(unittest.TestCase):
original_image = image[-3:, -3:, -1].flatten()
expected_image = np.array([0.1098, 0.1025, 0.1211, 0.1129, 0.1165, 0.1262, 0.1185, 0.1261, 0.1703])
assert np.allclose(original_image, expected_image, atol=1e-04)
max_diff = numpy_cosine_similarity_distance(original_image, expected_image)
assert max_diff < 1e-4
@require_python39_or_higher
@require_torch_2