mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-18 18:34:37 +08:00
Compare commits
4 Commits
cleanup-te
...
flax-utils
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d72adb3ca8 | ||
|
|
160474ac61 | ||
|
|
c10861ee1b | ||
|
|
94b332c476 |
@@ -592,14 +592,14 @@ class DreamBoothDataset(Dataset):
|
||||
size=512,
|
||||
center_crop=False,
|
||||
encoder_hidden_states=None,
|
||||
instance_prompt_encoder_hidden_states=None,
|
||||
class_prompt_encoder_hidden_states=None,
|
||||
tokenizer_max_length=None,
|
||||
):
|
||||
self.size = size
|
||||
self.center_crop = center_crop
|
||||
self.tokenizer = tokenizer
|
||||
self.encoder_hidden_states = encoder_hidden_states
|
||||
self.instance_prompt_encoder_hidden_states = instance_prompt_encoder_hidden_states
|
||||
self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
|
||||
self.tokenizer_max_length = tokenizer_max_length
|
||||
|
||||
self.instance_data_root = Path(instance_data_root)
|
||||
@@ -662,8 +662,8 @@ class DreamBoothDataset(Dataset):
|
||||
class_image = class_image.convert("RGB")
|
||||
example["class_images"] = self.image_transforms(class_image)
|
||||
|
||||
if self.instance_prompt_encoder_hidden_states is not None:
|
||||
example["class_prompt_ids"] = self.instance_prompt_encoder_hidden_states
|
||||
if self.class_prompt_encoder_hidden_states is not None:
|
||||
example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
|
||||
else:
|
||||
class_text_inputs = tokenize_prompt(
|
||||
self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
|
||||
@@ -1027,10 +1027,10 @@ def main(args):
|
||||
else:
|
||||
validation_prompt_encoder_hidden_states = None
|
||||
|
||||
if args.instance_prompt is not None:
|
||||
pre_computed_instance_prompt_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
|
||||
if args.class_prompt is not None:
|
||||
pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt)
|
||||
else:
|
||||
pre_computed_instance_prompt_encoder_hidden_states = None
|
||||
pre_computed_class_prompt_encoder_hidden_states = None
|
||||
|
||||
text_encoder = None
|
||||
tokenizer = None
|
||||
@@ -1041,7 +1041,7 @@ def main(args):
|
||||
pre_computed_encoder_hidden_states = None
|
||||
validation_prompt_encoder_hidden_states = None
|
||||
validation_prompt_negative_prompt_embeds = None
|
||||
pre_computed_instance_prompt_encoder_hidden_states = None
|
||||
pre_computed_class_prompt_encoder_hidden_states = None
|
||||
|
||||
# Dataset and DataLoaders creation:
|
||||
train_dataset = DreamBoothDataset(
|
||||
@@ -1054,7 +1054,7 @@ def main(args):
|
||||
size=args.resolution,
|
||||
center_crop=args.center_crop,
|
||||
encoder_hidden_states=pre_computed_encoder_hidden_states,
|
||||
instance_prompt_encoder_hidden_states=pre_computed_instance_prompt_encoder_hidden_states,
|
||||
class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
|
||||
tokenizer_max_length=args.tokenizer_max_length,
|
||||
)
|
||||
|
||||
|
||||
@@ -492,14 +492,14 @@ class DreamBoothDataset(Dataset):
|
||||
size=512,
|
||||
center_crop=False,
|
||||
encoder_hidden_states=None,
|
||||
instance_prompt_encoder_hidden_states=None,
|
||||
class_prompt_encoder_hidden_states=None,
|
||||
tokenizer_max_length=None,
|
||||
):
|
||||
self.size = size
|
||||
self.center_crop = center_crop
|
||||
self.tokenizer = tokenizer
|
||||
self.encoder_hidden_states = encoder_hidden_states
|
||||
self.instance_prompt_encoder_hidden_states = instance_prompt_encoder_hidden_states
|
||||
self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
|
||||
self.tokenizer_max_length = tokenizer_max_length
|
||||
|
||||
self.instance_data_root = Path(instance_data_root)
|
||||
@@ -562,8 +562,8 @@ class DreamBoothDataset(Dataset):
|
||||
class_image = class_image.convert("RGB")
|
||||
example["class_images"] = self.image_transforms(class_image)
|
||||
|
||||
if self.instance_prompt_encoder_hidden_states is not None:
|
||||
example["class_prompt_ids"] = self.instance_prompt_encoder_hidden_states
|
||||
if self.class_prompt_encoder_hidden_states is not None:
|
||||
example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
|
||||
else:
|
||||
class_text_inputs = tokenize_prompt(
|
||||
self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
|
||||
@@ -993,10 +993,10 @@ def main(args):
|
||||
else:
|
||||
validation_prompt_encoder_hidden_states = None
|
||||
|
||||
if args.instance_prompt is not None:
|
||||
pre_computed_instance_prompt_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
|
||||
if args.class_prompt is not None:
|
||||
pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
|
||||
else:
|
||||
pre_computed_instance_prompt_encoder_hidden_states = None
|
||||
pre_computed_class_prompt_encoder_hidden_states = None
|
||||
|
||||
text_encoder = None
|
||||
tokenizer = None
|
||||
@@ -1007,7 +1007,7 @@ def main(args):
|
||||
pre_computed_encoder_hidden_states = None
|
||||
validation_prompt_encoder_hidden_states = None
|
||||
validation_prompt_negative_prompt_embeds = None
|
||||
pre_computed_instance_prompt_encoder_hidden_states = None
|
||||
pre_computed_class_prompt_encoder_hidden_states = None
|
||||
|
||||
# Dataset and DataLoaders creation:
|
||||
train_dataset = DreamBoothDataset(
|
||||
@@ -1020,7 +1020,7 @@ def main(args):
|
||||
size=args.resolution,
|
||||
center_crop=args.center_crop,
|
||||
encoder_hidden_states=pre_computed_encoder_hidden_states,
|
||||
instance_prompt_encoder_hidden_states=pre_computed_instance_prompt_encoder_hidden_states,
|
||||
class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
|
||||
tokenizer_max_length=args.tokenizer_max_length,
|
||||
)
|
||||
|
||||
|
||||
@@ -357,10 +357,29 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
# extract them here
|
||||
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
|
||||
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
||||
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
|
||||
|
||||
init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||
|
||||
init_kwargs = {}
|
||||
# define init kwargs
|
||||
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
|
||||
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
|
||||
|
||||
# remove `null` components
|
||||
def load_module(name, value):
|
||||
if value[0] is None:
|
||||
return False
|
||||
if name in passed_class_obj and passed_class_obj[name] is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
|
||||
|
||||
# Throw nice warnings / errors for fast accelerate loading
|
||||
if len(unused_kwargs) > 0:
|
||||
logger.warning(
|
||||
f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
|
||||
)
|
||||
|
||||
# inference_params
|
||||
params = {}
|
||||
|
||||
@@ -1186,7 +1186,6 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
StableUnCLIPImg2ImgPipeline,
|
||||
StableUnCLIPPipeline,
|
||||
)
|
||||
@@ -1543,7 +1542,7 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs
|
||||
)
|
||||
|
||||
pipe = StableDiffusionXLPipeline(
|
||||
pipe = pipeline_class(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
|
||||
@@ -528,6 +528,7 @@ class PipelineTesterMixin:
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe_fp16 = self.pipeline_class(**components)
|
||||
pipe_fp16.to(torch_device, torch.float16)
|
||||
pipe_fp16.set_progress_bar_config(disable=None)
|
||||
|
||||
Reference in New Issue
Block a user