Compare commits

..

3 Commits

Author SHA1 Message Date
DN6
1ca9acc269 update 2024-03-07 14:08:39 +05:30
DN6
150d22821f update 2024-03-07 13:48:01 +05:30
DN6
6f3fd3bd51 update 2024-03-07 13:20:16 +05:30
6 changed files with 24 additions and 88 deletions

View File

@@ -877,8 +877,6 @@ def collate_fn(examples, with_prior_preservation=False):
if with_prior_preservation:
pixel_values += [example["class_images"] for example in examples]
prompts += [example["class_prompt"] for example in examples]
original_sizes += [example["original_size"] for example in examples]
crop_top_lefts += [example["crop_top_left"] for example in examples]
pixel_values = torch.stack(pixel_values)
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

View File

@@ -63,20 +63,13 @@ def build_sub_model_components(
num_in_channels=num_in_channels,
image_size=image_size,
torch_dtype=torch_dtype,
model_type=model_type,
)
return unet_components
if component_name == "vae":
scaling_factor = kwargs.get("scaling_factor", None)
vae_components = create_diffusers_vae_model_from_ldm(
pipeline_class_name,
original_config,
checkpoint,
image_size,
scaling_factor,
torch_dtype,
model_type=model_type,
pipeline_class_name, original_config, checkpoint, image_size, scaling_factor, torch_dtype
)
return vae_components
@@ -131,12 +124,11 @@ def build_sub_model_components(
def set_additional_components(
pipeline_class_name,
original_config,
checkpoint=None,
model_type=None,
):
components = {}
if pipeline_class_name in REFINER_PIPELINES:
model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)
model_type = infer_model_type(original_config, model_type=model_type)
is_refiner = model_type == "SDXL-Refiner"
components.update(
{

View File

@@ -28,7 +28,6 @@ from ..schedulers import (
DDIMScheduler,
DDPMScheduler,
DPMSolverMultistepScheduler,
EDMDPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
@@ -176,7 +175,6 @@ DIFFUSERS_TO_LDM_MAPPING = {
LDM_VAE_KEY = "first_stage_model."
LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
PLAYGROUND_VAE_SCALING_FACTOR = 0.5
LDM_UNET_KEY = "model.diffusion_model."
LDM_CONTROLNET_KEY = "control_model."
LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
@@ -307,7 +305,7 @@ def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=
return original_config
def infer_model_type(original_config, checkpoint=None, model_type=None):
def infer_model_type(original_config, model_type=None):
if model_type is not None:
return model_type
@@ -325,9 +323,7 @@ def infer_model_type(original_config, checkpoint=None, model_type=None):
elif has_network_config:
context_dim = original_config["model"]["params"]["network_config"]["params"]["context_dim"]
if "edm_mean" in checkpoint and "edm_std" in checkpoint:
model_type = "Playground"
elif context_dim == 2048:
if context_dim == 2048:
model_type = "SDXL"
else:
model_type = "SDXL-Refiner"
@@ -348,13 +344,13 @@ def set_image_size(pipeline_class_name, original_config, checkpoint, image_size=
return image_size
global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
model_type = infer_model_type(original_config, checkpoint, model_type)
model_type = infer_model_type(original_config, model_type)
if pipeline_class_name == "StableDiffusionUpscalePipeline":
image_size = original_config["model"]["params"]["unet_config"]["params"]["image_size"]
return image_size
elif model_type in ["SDXL", "SDXL-Refiner", "Playground"]:
elif model_type in ["SDXL", "SDXL-Refiner"]:
image_size = 1024
return image_size
@@ -510,14 +506,12 @@ def create_controlnet_diffusers_config(original_config, image_size: int):
return controlnet_config
def create_vae_diffusers_config(original_config, image_size, scaling_factor=None, latents_mean=None, latents_std=None):
def create_vae_diffusers_config(original_config, image_size, scaling_factor=None):
"""
Creates a config for the diffusers based on the config of the LDM model.
"""
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
if (scaling_factor is None) and (latents_mean is not None) and (latents_std is not None):
scaling_factor = PLAYGROUND_VAE_SCALING_FACTOR
elif (scaling_factor is None) and ("scale_factor" in original_config["model"]["params"]):
if scaling_factor is None and "scale_factor" in original_config["model"]["params"]:
scaling_factor = original_config["model"]["params"]["scale_factor"]
elif scaling_factor is None:
scaling_factor = LDM_VAE_DEFAULT_SCALING_FACTOR
@@ -537,8 +531,6 @@ def create_vae_diffusers_config(original_config, image_size, scaling_factor=None
"layers_per_block": vae_params["num_res_blocks"],
"scaling_factor": scaling_factor,
}
if latents_mean is not None and latents_std is not None:
config.update({"latents_mean": latents_mean, "latents_std": latents_std})
return config
@@ -1180,7 +1172,6 @@ def create_diffusers_unet_model_from_ldm(
extract_ema=False,
image_size=None,
torch_dtype=None,
model_type=None,
):
from ..models import UNet2DConditionModel
@@ -1199,9 +1190,7 @@ def create_diffusers_unet_model_from_ldm(
else:
num_in_channels = 4
image_size = set_image_size(
pipeline_class_name, original_config, checkpoint, image_size=image_size, model_type=model_type
)
image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size)
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
unet_config["in_channels"] = num_in_channels
unet_config["upcast_attention"] = upcast_attention
@@ -1234,40 +1223,14 @@ def create_diffusers_unet_model_from_ldm(
def create_diffusers_vae_model_from_ldm(
pipeline_class_name,
original_config,
checkpoint,
image_size=None,
scaling_factor=None,
torch_dtype=None,
model_type=None,
pipeline_class_name, original_config, checkpoint, image_size=None, scaling_factor=None, torch_dtype=None
):
# import here to avoid circular imports
from ..models import AutoencoderKL
image_size = set_image_size(
pipeline_class_name, original_config, checkpoint, image_size=image_size, model_type=model_type
)
model_type = infer_model_type(original_config, checkpoint, model_type)
image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size)
if model_type == "Playground":
edm_mean = (
checkpoint["edm_mean"].to(dtype=torch_dtype).tolist() if torch_dtype else checkpoint["edm_mean"].tolist()
)
edm_std = (
checkpoint["edm_std"].to(dtype=torch_dtype).tolist() if torch_dtype else checkpoint["edm_std"].tolist()
)
else:
edm_mean = None
edm_std = None
vae_config = create_vae_diffusers_config(
original_config,
image_size=image_size,
scaling_factor=scaling_factor,
latents_mean=edm_mean,
latents_std=edm_std,
)
vae_config = create_vae_diffusers_config(original_config, image_size=image_size, scaling_factor=scaling_factor)
diffusers_format_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
@@ -1302,7 +1265,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
local_files_only=False,
torch_dtype=None,
):
model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)
model_type = infer_model_type(original_config, model_type=model_type)
if model_type == "FrozenOpenCLIPEmbedder":
config_name = "stabilityai/stable-diffusion-2"
@@ -1369,7 +1332,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
"text_encoder_2": text_encoder_2,
}
elif model_type in ["SDXL", "Playground"]:
elif model_type == "SDXL":
try:
config_name = "openai/clip-vit-large-patch14"
tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
@@ -1420,7 +1383,7 @@ def create_scheduler_from_ldm(
model_type=None,
):
scheduler_config = get_default_scheduler_config()
model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)
model_type = infer_model_type(original_config, model_type=model_type)
global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
@@ -1443,8 +1406,7 @@ def create_scheduler_from_ldm(
if model_type in ["SDXL", "SDXL-Refiner"]:
scheduler_type = "euler"
elif model_type == "Playground":
scheduler_type = "edm_dpm_solver_multistep"
else:
beta_start = original_config["model"]["params"].get("linear_start", 0.02)
beta_end = original_config["model"]["params"].get("linear_end", 0.085)
@@ -1476,26 +1438,6 @@ def create_scheduler_from_ldm(
elif scheduler_type == "ddim":
scheduler = DDIMScheduler.from_config(scheduler_config)
elif scheduler_type == "edm_dpm_solver_multistep":
scheduler_config = {
"algorithm_type": "dpmsolver++",
"dynamic_thresholding_ratio": 0.995,
"euler_at_final": False,
"final_sigmas_type": "zero",
"lower_order_final": True,
"num_train_timesteps": 1000,
"prediction_type": "epsilon",
"rho": 7.0,
"sample_max_value": 1.0,
"sigma_data": 0.5,
"sigma_max": 80.0,
"sigma_min": 0.002,
"solver_order": 2,
"solver_type": "midpoint",
"thresholding": False,
}
scheduler = EDMDPMSolverMultistepScheduler(**scheduler_config)
else:
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")

View File

@@ -440,6 +440,7 @@ class TemporalBasicTransformerBlock(nn.Module):
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
self.norm_in = nn.LayerNorm(dim)
self.ff_in = FeedForward(
dim,
dim_out=time_mix_inner_dim,

View File

@@ -592,13 +592,15 @@ class StableCascadeUNet(ModelMixin, ConfigMixin):
# Model Blocks
x = self.embedding(sample)
# Interpolate operations are always run in fp32 in the original implementation
if hasattr(self, "effnet_mapper") and effnet is not None:
x = x + self.effnet_mapper(
nn.functional.interpolate(effnet, size=x.shape[-2:], mode="bilinear", align_corners=True)
nn.functional.interpolate(effnet.float(), size=x.shape[-2:], mode="bilinear", align_corners=True)
)
if hasattr(self, "pixels_mapper"):
x = x + nn.functional.interpolate(
self.pixels_mapper(pixels), size=x.shape[-2:], mode="bilinear", align_corners=True
self.pixels_mapper(pixels).float(), size=x.shape[-2:], mode="bilinear", align_corners=True
)
level_outputs = self._down_encode(x, timestep_ratio_embed, clip)
x = self._up_decode(level_outputs, timestep_ratio_embed, clip)

View File

@@ -99,13 +99,14 @@ class SDFunctionTesterMixin:
assert np.abs(output_2[0].flatten() - output_1[0].flatten()).max() < 1e-2
def test_vae_tiling(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
# make sure here that pndm scheduler skips prk
if "safety_checker" in components:
components["safety_checker"] = None
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
@@ -125,7 +126,7 @@ class SDFunctionTesterMixin:
# test that tiled decode works with various shapes
shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
for shape in shapes:
zeros = torch.zeros(shape).to(torch_device)
zeros = torch.zeros(shape).to(device)
pipe.vae.decode(zeros)
def test_freeu_enabled(self):