diff --git a/src/diffusers/modular_pipelines/helios/denoise.py b/src/diffusers/modular_pipelines/helios/denoise.py index 748a8da2cf..bb78137a46 100644 --- a/src/diffusers/modular_pipelines/helios/denoise.py +++ b/src/diffusers/modular_pipelines/helios/denoise.py @@ -39,7 +39,7 @@ from .modular_pipeline import HeliosModularPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def sample_block_noise(batch_size, channel, num_frames, height, width, gamma, patch_size=(1, 2, 2)): +def sample_block_noise(batch_size, channel, num_frames, height, width, gamma, patch_size=(1, 2, 2), device=None): """Generate spatially-correlated block noise for pyramid upsampling correction. Uses a multivariate normal distribution with covariance based on `gamma` to produce noise with block structure, @@ -48,9 +48,11 @@ def sample_block_noise(batch_size, channel, num_frames, height, width, gamma, pa _, ph, pw = patch_size block_size = ph * pw - cov = torch.eye(block_size) * (1 + gamma) - torch.ones(block_size, block_size) * gamma - cov += torch.eye(block_size) * 1e-6 - dist = torch.distributions.MultivariateNormal(torch.zeros(block_size, device=cov.device), covariance_matrix=cov) + cov = ( + torch.eye(block_size, device=device) * (1 + gamma) - torch.ones(block_size, block_size, device=device) * gamma + ) + cov += torch.eye(block_size, device=device) * 1e-6 + dist = torch.distributions.MultivariateNormal(torch.zeros(block_size, device=device), covariance_matrix=cov) block_number = batch_size * channel * num_frames * (height // ph) * (width // pw) noise = dist.sample((block_number,)) # [block number, block_size]