Mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-06 20:44:33 +08:00
Compare commits: `single-fil...hires-upsa` (7 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | a61c6079c9 |  |
|  | a7e651c75e |  |
|  | 87e39484b8 |  |
|  | e2bc5e54b5 |  |
|  | 1a773f6d74 |  |
|  | 2df84a57da |  |
|  | b67d30e95b |  |
@@ -603,17 +603,163 @@ class AutoencoderKL(ModelMixin, ConfigMixin):

```python
        self.use_slicing = False

    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        # The sliced decode path is disabled on this branch; decoding is
        # delegated to the tiled split_decode below.
        # if self.use_slicing and z.shape[0] > 1:
        #     decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
        #     decoded = torch.cat(decoded_slices)
        # else:
        #     decoded = self._decode(z).sample

        decoded = self.split_decode(z)

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)
```
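For reference, the sliced path that this change comments out iterated the batch one sample at a time via `torch.Tensor.split`; a tiny standalone check of that idiom:

```python
import torch

z = torch.randn(4, 4, 64, 64)

# split(1) along dim 0 yields one (1, 4, 64, 64) view per batch element,
# which is what the commented-out sliced-decode path looped over.
slices = z.split(1)
print(len(slices), slices[0].shape)  # 4 torch.Size([1, 4, 64, 64])

# Concatenating the per-slice results restores the original batch layout.
assert torch.equal(torch.cat(slices), z)
```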
```python
    def meshgrid(self, h, w):
        # Build an (h, w, 2) grid where each entry holds its own (y, x) coordinate.
        y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
        x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)

        arr = torch.cat([y, x], dim=-1)
        return arr
```
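A quick standalone check of the grid layout (the function is re-declared free-standing here purely for illustration):

```python
import torch

def meshgrid(h, w):
    # Same logic as the method above, outside the class for a quick check.
    y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
    x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
    return torch.cat([y, x], dim=-1)

grid = meshgrid(2, 3)
print(grid.shape)  # torch.Size([2, 3, 2])
print(grid[1, 2])  # tensor([1, 2]) -> row 1, column 2
```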
```python
    def delta_border(self, h, w):
        """
        :param h: height :param w: width :return: normalized distance to image border,
        with min distance = 0 at border and max dist = 0.5 at image center
        """
        lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
        arr = self.meshgrid(h, w) / lower_right_corner
        dist_left_up = torch.min(arr, dim=-1, keepdim=True)[0]
        dist_right_down = torch.min(1 - arr, dim=-1, keepdim=True)[0]
        edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
        return edge_dist
```
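A matching standalone check (same re-declaration caveat as above): the distance is 0 on the border and 0.5 at the center, which is what later lets overlapping tile seams blend smoothly.

```python
import torch

def meshgrid(h, w):
    y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
    x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
    return torch.cat([y, x], dim=-1)

def delta_border(h, w):
    # Normalized distance to the nearest image border, as in the method above.
    lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
    arr = meshgrid(h, w) / lower_right_corner
    dist_left_up = torch.min(arr, dim=-1, keepdim=True)[0]
    dist_right_down = torch.min(1 - arr, dim=-1, keepdim=True)[0]
    return torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]

d = delta_border(5, 5)
print(d[0, 0].item(), d[2, 2].item())  # 0.0 at the corner, 0.5 at the center
```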
```python
    def get_weighting(self, h, w, Ly, Lx, device):
        # Per-pixel blending weight inside one tile: small near the tile border,
        # large at the center, clipped so no pixel ends up with zero weight.
        weighting = self.delta_border(h, w)
        weighting = torch.clip(
            weighting,
            self.split_input_params["clip_min_weight"],
            self.split_input_params["clip_max_weight"],
        )
        weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)

        if self.split_input_params["tie_braker"]:
            # Optional tie-breaker: additionally down-weight tiles that lie
            # near the border of the Ly x Lx tile grid itself.
            L_weighting = self.delta_border(Ly, Lx)
            L_weighting = torch.clip(
                L_weighting,
                self.split_input_params["clip_min_tie_weight"],
                self.split_input_params["clip_max_tie_weight"],
            )

            L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
            weighting = weighting * L_weighting
        return weighting
```
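The clipping step matters at tile edges: the raw border distance is exactly 0 there, which would zero those pixels out entirely and break the later division by the fold normalization. A minimal illustration, using the same clip bounds that `split_decode` configures below:

```python
import torch

# Border distances along one row of a tile, before clipping (0 at the edges).
d = torch.tensor([0.0, 0.1, 0.3, 0.5, 0.3, 0.1, 0.0])

# clip_min_weight=0.01 and clip_max_weight=0.5, as set in split_decode.
w = torch.clip(d, 0.01, 0.5)
print(w)  # tensor([0.0100, 0.1000, 0.3000, 0.5000, 0.3000, 0.1000, 0.0100])
```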
```python
    def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1):  # todo load once not every time, shorten code
        """
        :param x: img of size (bs, c, h, w) :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
        """
        bs, nc, h, w = x.shape

        # number of crops in image
        Ly = (h - kernel_size[0]) // stride[0] + 1
        Lx = (w - kernel_size[1]) // stride[1] + 1

        if uf == 1 and df == 1:
            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
            unfold = torch.nn.Unfold(**fold_params)

            fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)

            weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
            normalization = fold(weighting).view(1, 1, h, w)  # normalizes the overlap
            weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))

        elif uf > 1 and df == 1:
            # upscaling: crops are cut at the input resolution but stitched back
            # together at uf times that resolution (the VAE decode case)
            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
            unfold = torch.nn.Unfold(**fold_params)

            fold_params2 = dict(
                kernel_size=(kernel_size[0] * uf, kernel_size[1] * uf),
                dilation=1,
                padding=0,
                stride=(stride[0] * uf, stride[1] * uf),
            )
            fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)

            weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
            normalization = fold(weighting).view(1, 1, h * uf, w * uf)  # normalizes the overlap
            weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))

        elif df > 1 and uf == 1:
            # downscaling: the mirror image of the uf > 1 branch (the encode case)
            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
            unfold = torch.nn.Unfold(**fold_params)

            fold_params2 = dict(
                kernel_size=(kernel_size[0] // df, kernel_size[1] // df),
                dilation=1,
                padding=0,
                stride=(stride[0] // df, stride[1] // df),
            )
            fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)

            weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
            normalization = fold(weighting).view(1, 1, h // df, w // df)  # normalizes the overlap
            weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))

        else:
            raise NotImplementedError

        return fold, unfold, normalization, weighting
```
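To see why the returned `normalization` is needed: `torch.nn.Fold` sums overlapping patches, so an unfold/fold round trip over-counts every pixel covered by more than one tile, and dividing by the folded weight map undoes that. A self-contained sketch with uniform weights (the method above uses the smooth border weights instead):

```python
import torch

x = torch.ones(1, 1, 8, 8)
ks, stride = (4, 4), (2, 2)

unfold = torch.nn.Unfold(kernel_size=ks, stride=stride)
fold = torch.nn.Fold(output_size=(8, 8), kernel_size=ks, stride=stride)

patches = unfold(x)  # (1, 16, 9): nine overlapping 4x4 tiles
summed = fold(patches)  # overlaps are summed, so interior pixels reach 4.0

# Folding an all-ones weighting counts how often each pixel was covered.
normalization = fold(torch.ones_like(patches))
restored = summed / normalization

print(summed.max().item(), restored.max().item())  # 4.0 1.0
assert torch.allclose(restored, x)
```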
```python
    def split_decode(self, z: torch.FloatTensor) -> torch.FloatTensor:
        ks = 128
        stride = 64
        vqf = 2 ** (len(self.config.block_out_channels) - 1)  # spatial up-sampling factor of the decoder
        self.split_input_params = {
            "ks": (ks, ks),
            "stride": (stride, stride),
            "vqf": vqf,
            "patch_distributed_vq": True,
            "tie_braker": False,
            "clip_max_weight": 0.5,
            "clip_min_weight": 0.01,
            "clip_max_tie_weight": 0.5,
            "clip_min_tie_weight": 0.01,
        }

        ks = self.split_input_params["ks"]  # eg. (128, 128)
        stride = self.split_input_params["stride"]  # eg. (64, 64)
        vqf = self.split_input_params["vqf"]
        bs, nc, h, w = z.shape
        if ks[0] > h or ks[1] > w:
            ks = (min(ks[0], h), min(ks[1], w))
            print("reducing kernel")

        if stride[0] > h or stride[1] > w:
            stride = (min(stride[0], h), min(stride[1], w))
            print("reducing stride")

        fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=vqf)

        z = unfold(z)  # (bn, nc * ks[0] * ks[1], L)
        # 1. Reshape to img shape
        z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L)

        # 2. apply model loop over last dim
        output_list = [self._decode(z[:, :, :, :, i]).sample for i in range(z.shape[-1])]

        o = torch.stack(output_list, dim=-1)  # (bn, 3, ks[0] * vqf, ks[1] * vqf, L): tiles are now in pixel space
        o = o * weighting
        # Reverse 1. reshape to img shape
        o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, 3 * ks[0] * vqf * ks[1] * vqf, L)
        # stitch crops together
        decoded = fold(o)
        decoded = decoded / normalization  # norm is shape (1, 1, h * vqf, w * vqf)
        return decoded
```
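Finally, a hedged end-to-end sketch of how this tiled decode would be exercised; the checkpoint id and latent size are placeholders, and it assumes this branch, where `decode` routes through `split_decode`:

```python
import torch
from diffusers import AutoencoderKL

# Placeholder checkpoint; any KL autoencoder with the standard 4-level
# block_out_channels works, giving vqf = 2**3 = 8.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").eval()

# A 192x192 latent decodes to a 1536x1536 image; split_decode cuts it into
# 128x128 tiles with a 64-pixel stride instead of decoding all at once.
z = torch.randn(1, vae.config.latent_channels, 192, 192)

with torch.no_grad():
    image = vae.decode(z).sample

print(image.shape)  # torch.Size([1, 3, 1536, 1536])
```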
```python
    def forward(
        self,
        sample: torch.FloatTensor,
```