Files
diffusers/examples/vqgan/discriminator.py
Isamu Isozaki d27e996ccd Adding VQGAN Training script (#5483)
* Init commit

* Removed einops

* Added default movq config for training

* Update explanation of prompts

* Fixed inheritance of discriminator and init_tracker

* Fixed incompatible api between muse and here

* Fixed output

* Setup init training

* Basic structure done

* Removed attention for quick tests

* Style fixes

* Fixed vae/vqgan styles

* Removed redefinition of wandb

* Fixed log_validation and tqdm

* Nothing commit

* Added commit loss to lookup_from_codebook

* Update src/diffusers/models/vq_model.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Adding preliminary README

* Fixed one typo

* Local changes

* Fixed main issues

* Merging

* Update src/diffusers/models/vq_model.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Testing+Fixed bugs in training script

* Some style fixes

* Added wandb to docs

* Fixed timm test

* get testing suite ready.

* remove return loss

* remove return_loss

* Remove diffs

* Remove diffs

* fix ruff format

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2024-05-15 08:47:12 +05:30

49 lines
1.7 KiB
Python

"""
Ported from Paella
"""
import torch
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
# Discriminator model ported from Paella https://github.com/dome272/Paella/blob/main/src_distributed/vqgan.py
class Discriminator(ModelMixin, ConfigMixin):
    """Patch-based convolutional discriminator for VQGAN adversarial training.

    The encoder is a stack of ``depth`` spectrally normalized 3x3 convolutions with stride 2 and padding 1. Each one
    halves the spatial resolution (rounding up) and is followed by a LeakyReLU(0.2). Every convolution except the
    first is also followed by an ``InstanceNorm2d``. A 1x1 convolution then maps each spatial position to one value,
    and a sigmoid is applied. The result is a map of per-patch scores in ``(0, 1)``, not a single scalar per image.

    Channel widths start at ``hidden_channels // 2**max(depth - 3, 3)``. They double after every convolution until
    they reach ``hidden_channels``, then stay there.

    Args:
        in_channels: Number of channels of the input images.
        cond_channels: Number of channels of the optional conditioning vector passed to ``forward``. Use ``0`` (the
            default) for an unconditional discriminator.
        hidden_channels: Maximum channel width of the encoder. It is reached when ``depth >= 4``; shallower encoders
            end at a smaller width.
        depth: Total number of stride-2 convolutions in the encoder.
    """

    @register_to_config
    def __init__(self, in_channels=3, cond_channels=0, hidden_channels=512, depth=6):
        super().__init__()
        # Number of halvings applied to hidden_channels for the first layer. It is at least 3, so the first width
        # is at most hidden_channels // 8.
        d = max(depth - 3, 3)
        first_channels = hidden_channels // (2**d)
        layers = [
            nn.utils.spectral_norm(nn.Conv2d(in_channels, first_channels, kernel_size=3, stride=2, padding=1)),
            nn.LeakyReLU(0.2),
        ]
        # Track the width the encoder actually ends with, so the head below matches it for every depth.
        out_channels = first_channels
        for i in range(depth - 1):
            # c_in of step i equals c_out of step i - 1. Both are computed directly from hidden_channels so that
            # integer division never drifts between consecutive layers.
            c_in = hidden_channels // (2 ** max(d - i, 0))
            c_out = hidden_channels // (2 ** max(d - 1 - i, 0))
            layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
            layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2))
            out_channels = c_out
        self.encoder = nn.Sequential(*layers)
        # 1x1 "shuffle" head over encoder features (plus broadcast conditioning channels, if any). For depth >= 4,
        # out_channels == hidden_channels, which matches the previous behavior exactly. For depth <= 3, sizing the
        # head with hidden_channels made forward() fail with a channel mismatch.
        self.shuffle = nn.Conv2d(
            (out_channels + cond_channels) if cond_channels > 0 else out_channels, 1, kernel_size=1
        )
        # Despite the attribute name, this is a Sigmoid, so outputs are probabilities rather than raw logits.
        # The name is kept for state-dict / attribute compatibility.
        self.logits = nn.Sigmoid()

    def forward(self, x, cond=None):
        """Score image patches as real/fake.

        Args:
            x: Image batch; assumed to be ``(batch, in_channels, H, W)``.
            cond: Optional conditioning tensor, expected to hold ``cond_channels`` values per sample, i.e.
                ``(batch, cond_channels)``. It must be given if and only if the model was built with
                ``cond_channels > 0``. Otherwise the 1x1 head receives the wrong number of channels.

        Returns:
            Tensor of shape ``(batch, 1, H', W')`` with values in ``(0, 1)``, where ``H'``/``W'`` are ``H``/``W``
            halved (rounding up) ``depth`` times.
        """
        x = self.encoder(x)
        if cond is not None:
            # Broadcast the per-sample conditioning vector to every spatial position, then stack it channel-wise
            # with the encoder features. reshape (unlike view) also accepts non-contiguous cond tensors.
            cond = cond.reshape(cond.size(0), cond.size(1), 1, 1).expand(-1, -1, x.size(-2), x.size(-1))
            x = torch.cat([x, cond], dim=1)
        x = self.shuffle(x)
        return self.logits(x)