Compare commits

...

7 Commits

Author SHA1 Message Date
patil-suraj
a61c6079c9 fix vqf 2022-12-02 16:29:19 +01:00
patil-suraj
a7e651c75e fix factor 2022-12-02 16:19:23 +01:00
patil-suraj
87e39484b8 fix uf 2022-12-02 14:38:25 +01:00
patil-suraj
e2bc5e54b5 fix decodeing 2022-12-02 14:30:47 +01:00
patil-suraj
1a773f6d74 meshgrid 2022-12-02 14:28:54 +01:00
patil-suraj
2df84a57da delta border 2022-12-02 14:25:07 +01:00
patil-suraj
b67d30e95b split decode 2022-12-02 14:17:15 +01:00

View File

@@ -603,17 +603,163 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
self.use_slicing = False
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
# if self.use_slicing and z.shape[0] > 1:
# decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
# decoded = torch.cat(decoded_slices)
# else:
# decoded = self._decode(z).sample
decoded = self.split_decode(z)
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def meshgrid(self, h, w):
y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
arr = torch.cat([y, x], dim=-1)
return arr
def delta_border(self, h, w):
"""
:param h: height :param w: width :return: normalized distance to image border,
wtith min distance = 0 at border and max dist = 0.5 at image center
"""
lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
arr = self.meshgrid(h, w) / lower_right_corner
dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
return edge_dist
def get_weighting(self, h, w, Ly, Lx, device):
weighting = self.delta_border(h, w)
weighting = torch.clip(
weighting,
self.split_input_params["clip_min_weight"],
self.split_input_params["clip_max_weight"],
)
weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
if self.split_input_params["tie_braker"]:
L_weighting = self.delta_border(Ly, Lx)
L_weighting = torch.clip(
L_weighting,
self.split_input_params["clip_min_tie_weight"],
self.split_input_params["clip_max_tie_weight"],
)
L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
weighting = weighting * L_weighting
return weighting
def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
"""
:param x: img of size (bs, c, h, w) :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
"""
bs, nc, h, w = x.shape
# number of crops in image
Ly = (h - kernel_size[0]) // stride[0] + 1
Lx = (w - kernel_size[1]) // stride[1] + 1
if uf == 1 and df == 1:
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
unfold = torch.nn.Unfold(**fold_params)
fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
elif uf > 1 and df == 1:
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
unfold = torch.nn.Unfold(**fold_params)
fold_params2 = dict(
kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
dilation=1,
padding=0,
stride=(stride[0] * uf, stride[1] * uf),
)
fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
elif df > 1 and uf == 1:
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
unfold = torch.nn.Unfold(**fold_params)
fold_params2 = dict(
kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
dilation=1,
padding=0,
stride=(stride[0] // df, stride[1] // df),
)
fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
else:
raise NotImplementedError
return fold, unfold, normalization, weighting
def split_decode(self, z: torch.FloatTensor) -> torch.FloatTensor:
ks = 128
stride = 64
vqf = 2 ** (len(self.config.block_out_channels) - 1)
self.split_input_params = {
"ks": (ks, ks),
"stride": (stride, stride),
"vqf": vqf,
"patch_distributed_vq": True,
"tie_braker": False,
"clip_max_weight": 0.5,
"clip_min_weight": 0.01,
"clip_max_tie_weight": 0.5,
"clip_min_tie_weight": 0.01,
}
ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64)
uf = self.split_input_params["vqf"]
bs, nc, h, w = z.shape
if ks[0] > h or ks[1] > w:
ks = (min(ks[0], h), min(ks[1], w))
print("reducing Kernel")
if stride[0] > h or stride[1] > w:
stride = (min(stride[0], h), min(stride[1], w))
print("reducing stride")
fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=vqf)
z = unfold(z) # (bn, nc * prod(**ks), L)
# 1. Reshape to img shape
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
# 2. apply model loop over last dim
output_list = [self._decode(z[:, :, :, :, i]).sample for i in range(z.shape[-1])]
o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
o = o * weighting
# Reverse 1. reshape to img shape
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
# stitch crops together
decoded = fold(o)
decoded = decoded / normalization # norm is shape (1, 1, h, w)
return decoded
def forward(
self,
sample: torch.FloatTensor,