Compare commits

...

1 Commit

Author SHA1 Message Date
sayakpaul
b0b8008fa9 fix autoencoderkl qwenimage for xla 2026-04-15 15:36:06 +05:30

View File

@@ -180,7 +180,7 @@ class QwenImageResample(nn.Module):
feat_cache[idx] = "Rep"
feat_idx[0] += 1
else:
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ cache_x = x[:, :, -min(CACHE_T, x.shape[2]) :, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
# cache last frame of last two chunk
cache_x = torch.cat(
@@ -258,7 +258,7 @@ class QwenImageResidualBlock(nn.Module):
if feat_cache is not None:
idx = feat_idx[0]
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ cache_x = x[:, :, -min(CACHE_T, x.shape[2]) :, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
@@ -277,7 +277,7 @@ class QwenImageResidualBlock(nn.Module):
if feat_cache is not None:
idx = feat_idx[0]
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ cache_x = x[:, :, -min(CACHE_T, x.shape[2]) :, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
@@ -446,7 +446,7 @@ class QwenImageEncoder3d(nn.Module):
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ cache_x = x[:, :, -min(CACHE_T, x.shape[2]) :, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
@@ -471,7 +471,7 @@ class QwenImageEncoder3d(nn.Module):
x = self.nonlinearity(x)
if feat_cache is not None:
idx = feat_idx[0]
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ cache_x = x[:, :, -min(CACHE_T, x.shape[2]) :, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
@@ -636,7 +636,7 @@ class QwenImageDecoder3d(nn.Module):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ cache_x = x[:, :, -min(CACHE_T, x.shape[2]) :, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
@@ -658,7 +658,7 @@ class QwenImageDecoder3d(nn.Module):
x = self.nonlinearity(x)
if feat_cache is not None:
idx = feat_idx[0]
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ cache_x = x[:, :, -min(CACHE_T, x.shape[2]) :, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)