Compare commits

...

1 Commits

Author SHA1 Message Date
sayakpaul
f93c7807e4 fix small nits in pixart sigma 2024-04-24 15:57:46 +05:30
2 changed files with 3 additions and 21 deletions

View File

@@ -273,15 +273,6 @@ class PixArtAlphaPipeline(DiffusionPipeline):
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
def mask_text_embeddings(self, emb, mask):
if emb.shape[0] == 1:
keep_index = mask.sum().item()
return emb[:, :, :keep_index, :], keep_index
else:
masked_feature = emb * mask[:, None, :, None]
return masked_feature, emb.shape[2]
# Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
def encode_prompt(
self,

View File

@@ -199,16 +199,7 @@ class PixArtSigmaPipeline(DiffusionPipeline):
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
# copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.py
def mask_text_embeddings(self, emb, mask):
if emb.shape[0] == 1:
keep_index = mask.sum().item()
return emb[:, :, :keep_index, :], keep_index
else:
masked_feature = emb * mask[:, None, :, None]
return masked_feature, emb.shape[2]
# Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
# Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
@@ -369,7 +360,7 @@ class PixArtSigmaPipeline(DiffusionPipeline):
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.py
# Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.check_inputs
def check_inputs(
self,
prompt,
@@ -462,7 +453,7 @@ class PixArtSigmaPipeline(DiffusionPipeline):
return [process(t) for t in text]
# Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline._clean_caption
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
def _clean_caption(self, caption):
caption = str(caption)
caption = ul.unquote_plus(caption)