Compare commits

...

4 Commits

Author SHA1 Message Date
Pedro Cuenca
d72adb3ca8 Handle null modules and non-module params 2023-08-01 21:08:09 +02:00
Will Berman
160474ac61 train dreambooth fix pre encode class prompt (#4395) 2023-08-01 12:00:05 -07:00
YiYi Xu
c10861ee1b fix test_float16_inference (#4412)
fix

Co-authored-by: yiyixuxu <yixu310@gmail,com>
2023-08-01 07:49:50 -10:00
YiYi Xu
94b332c476 support from_single_file for SDXL inpainting (#4408)
fix

Co-authored-by: yiyixuxu <yixu310@gmail,com>
2023-08-01 07:47:22 -10:00
5 changed files with 41 additions and 22 deletions

View File

@@ -592,14 +592,14 @@ class DreamBoothDataset(Dataset):
size=512,
center_crop=False,
encoder_hidden_states=None,
instance_prompt_encoder_hidden_states=None,
class_prompt_encoder_hidden_states=None,
tokenizer_max_length=None,
):
self.size = size
self.center_crop = center_crop
self.tokenizer = tokenizer
self.encoder_hidden_states = encoder_hidden_states
self.instance_prompt_encoder_hidden_states = instance_prompt_encoder_hidden_states
self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
self.tokenizer_max_length = tokenizer_max_length
self.instance_data_root = Path(instance_data_root)
@@ -662,8 +662,8 @@ class DreamBoothDataset(Dataset):
class_image = class_image.convert("RGB")
example["class_images"] = self.image_transforms(class_image)
if self.instance_prompt_encoder_hidden_states is not None:
example["class_prompt_ids"] = self.instance_prompt_encoder_hidden_states
if self.class_prompt_encoder_hidden_states is not None:
example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
else:
class_text_inputs = tokenize_prompt(
self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
@@ -1027,10 +1027,10 @@ def main(args):
else:
validation_prompt_encoder_hidden_states = None
if args.instance_prompt is not None:
pre_computed_instance_prompt_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
if args.class_prompt is not None:
pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt)
else:
pre_computed_instance_prompt_encoder_hidden_states = None
pre_computed_class_prompt_encoder_hidden_states = None
text_encoder = None
tokenizer = None
@@ -1041,7 +1041,7 @@ def main(args):
pre_computed_encoder_hidden_states = None
validation_prompt_encoder_hidden_states = None
validation_prompt_negative_prompt_embeds = None
pre_computed_instance_prompt_encoder_hidden_states = None
pre_computed_class_prompt_encoder_hidden_states = None
# Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset(
@@ -1054,7 +1054,7 @@ def main(args):
size=args.resolution,
center_crop=args.center_crop,
encoder_hidden_states=pre_computed_encoder_hidden_states,
instance_prompt_encoder_hidden_states=pre_computed_instance_prompt_encoder_hidden_states,
class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
tokenizer_max_length=args.tokenizer_max_length,
)

View File

@@ -492,14 +492,14 @@ class DreamBoothDataset(Dataset):
size=512,
center_crop=False,
encoder_hidden_states=None,
instance_prompt_encoder_hidden_states=None,
class_prompt_encoder_hidden_states=None,
tokenizer_max_length=None,
):
self.size = size
self.center_crop = center_crop
self.tokenizer = tokenizer
self.encoder_hidden_states = encoder_hidden_states
self.instance_prompt_encoder_hidden_states = instance_prompt_encoder_hidden_states
self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
self.tokenizer_max_length = tokenizer_max_length
self.instance_data_root = Path(instance_data_root)
@@ -562,8 +562,8 @@ class DreamBoothDataset(Dataset):
class_image = class_image.convert("RGB")
example["class_images"] = self.image_transforms(class_image)
if self.instance_prompt_encoder_hidden_states is not None:
example["class_prompt_ids"] = self.instance_prompt_encoder_hidden_states
if self.class_prompt_encoder_hidden_states is not None:
example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
else:
class_text_inputs = tokenize_prompt(
self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
@@ -993,10 +993,10 @@ def main(args):
else:
validation_prompt_encoder_hidden_states = None
if args.instance_prompt is not None:
pre_computed_instance_prompt_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
if args.class_prompt is not None:
pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
else:
pre_computed_instance_prompt_encoder_hidden_states = None
pre_computed_class_prompt_encoder_hidden_states = None
text_encoder = None
tokenizer = None
@@ -1007,7 +1007,7 @@ def main(args):
pre_computed_encoder_hidden_states = None
validation_prompt_encoder_hidden_states = None
validation_prompt_negative_prompt_embeds = None
pre_computed_instance_prompt_encoder_hidden_states = None
pre_computed_class_prompt_encoder_hidden_states = None
# Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset(
@@ -1020,7 +1020,7 @@ def main(args):
size=args.resolution,
center_crop=args.center_crop,
encoder_hidden_states=pre_computed_encoder_hidden_states,
instance_prompt_encoder_hidden_states=pre_computed_instance_prompt_encoder_hidden_states,
class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
tokenizer_max_length=args.tokenizer_max_length,
)

View File

@@ -357,10 +357,29 @@ class FlaxDiffusionPipeline(ConfigMixin):
# extract them here
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_kwargs = {}
# define init kwargs
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
# remove `null` components
def load_module(name, value):
if value[0] is None:
return False
if name in passed_class_obj and passed_class_obj[name] is None:
return False
return True
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
# Throw nice warnings / errors for fast accelerate loading
if len(unused_kwargs) > 0:
logger.warning(
f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
)
# inference_params
params = {}

View File

@@ -1186,7 +1186,6 @@ def download_from_original_stable_diffusion_ckpt(
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
)
@@ -1543,7 +1542,7 @@ def download_from_original_stable_diffusion_ckpt(
checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs
)
pipe = StableDiffusionXLPipeline(
pipe = pipeline_class(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,

View File

@@ -528,6 +528,7 @@ class PipelineTesterMixin:
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
components = self.get_dummy_components()
pipe_fp16 = self.pipeline_class(**components)
pipe_fp16.to(torch_device, torch.float16)
pipe_fp16.set_progress_bar_config(disable=None)