mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-18 06:28:06 +08:00
up
This commit is contained in:
@@ -39,7 +39,7 @@ from .modular_pipeline import HeliosModularPipeline
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def sample_block_noise(batch_size, channel, num_frames, height, width, gamma, patch_size=(1, 2, 2)):
|
||||
def sample_block_noise(batch_size, channel, num_frames, height, width, gamma, patch_size=(1, 2, 2), device=None):
|
||||
"""Generate spatially-correlated block noise for pyramid upsampling correction.
|
||||
|
||||
Uses a multivariate normal distribution with covariance based on `gamma` to produce noise with block structure,
|
||||
@@ -48,9 +48,11 @@ def sample_block_noise(batch_size, channel, num_frames, height, width, gamma, pa
|
||||
_, ph, pw = patch_size
|
||||
block_size = ph * pw
|
||||
|
||||
cov = torch.eye(block_size) * (1 + gamma) - torch.ones(block_size, block_size) * gamma
|
||||
cov += torch.eye(block_size) * 1e-6
|
||||
dist = torch.distributions.MultivariateNormal(torch.zeros(block_size, device=cov.device), covariance_matrix=cov)
|
||||
cov = (
|
||||
torch.eye(block_size, device=device) * (1 + gamma) - torch.ones(block_size, block_size, device=device) * gamma
|
||||
)
|
||||
cov += torch.eye(block_size, device=device) * 1e-6
|
||||
dist = torch.distributions.MultivariateNormal(torch.zeros(block_size, device=device), covariance_matrix=cov)
|
||||
block_number = batch_size * channel * num_frames * (height // ph) * (width // pw)
|
||||
|
||||
noise = dist.sample((block_number,)) # [block number, block_size]
|
||||
|
||||
Reference in New Issue
Block a user