Mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-13 07:54:45 +08:00
49 lines · 1.7 KiB · Python
"""
|
|
Ported from Paella
|
|
"""
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
|
from diffusers.models.modeling_utils import ModelMixin
|
|
|
|
|
|
# Discriminator model ported from Paella https://github.com/dome272/Paella/blob/main/src_distributed/vqgan.py
|
|
class Discriminator(ModelMixin, ConfigMixin):
|
|
@register_to_config
|
|
def __init__(self, in_channels=3, cond_channels=0, hidden_channels=512, depth=6):
|
|
super().__init__()
|
|
d = max(depth - 3, 3)
|
|
layers = [
|
|
nn.utils.spectral_norm(
|
|
nn.Conv2d(in_channels, hidden_channels // (2**d), kernel_size=3, stride=2, padding=1)
|
|
),
|
|
nn.LeakyReLU(0.2),
|
|
]
|
|
for i in range(depth - 1):
|
|
c_in = hidden_channels // (2 ** max((d - i), 0))
|
|
c_out = hidden_channels // (2 ** max((d - 1 - i), 0))
|
|
layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
|
|
layers.append(nn.InstanceNorm2d(c_out))
|
|
layers.append(nn.LeakyReLU(0.2))
|
|
self.encoder = nn.Sequential(*layers)
|
|
self.shuffle = nn.Conv2d(
|
|
(hidden_channels + cond_channels) if cond_channels > 0 else hidden_channels, 1, kernel_size=1
|
|
)
|
|
self.logits = nn.Sigmoid()
|
|
|
|
def forward(self, x, cond=None):
|
|
x = self.encoder(x)
|
|
if cond is not None:
|
|
cond = cond.view(
|
|
cond.size(0),
|
|
cond.size(1),
|
|
1,
|
|
1,
|
|
).expand(-1, -1, x.size(-2), x.size(-1))
|
|
x = torch.cat([x, cond], dim=1)
|
|
x = self.shuffle(x)
|
|
x = self.logits(x)
|
|
return x
|
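

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal smoke test showing
# the expected input/output shapes. The batch size (2), image size (256x256),
# and conditioning width (128) are illustrative assumptions, not values from
# the training script. With depth=6, six stride-2 convs reduce 256 -> 4, so
# the output is a 4x4 grid of per-patch realness probabilities.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    disc = Discriminator(in_channels=3, cond_channels=128, hidden_channels=512, depth=6)
    images = torch.randn(2, 3, 256, 256)
    cond = torch.randn(2, 128)
    scores = disc(images, cond)
    print(scores.shape)  # torch.Size([2, 1, 4, 4]), values in [0, 1] after the sigmoid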