This commit is contained in:
yiyixuxu
2026-03-09 03:28:10 +01:00
parent 921a7e77c8
commit 6cd5fb7d6c

View File

@@ -39,7 +39,7 @@ from .modular_pipeline import HeliosModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def sample_block_noise(batch_size, channel, num_frames, height, width, gamma, patch_size=(1, 2, 2)):
def sample_block_noise(batch_size, channel, num_frames, height, width, gamma, patch_size=(1, 2, 2), device=None):
"""Generate spatially-correlated block noise for pyramid upsampling correction.
Uses a multivariate normal distribution with covariance based on `gamma` to produce noise with block structure,
@@ -48,9 +48,11 @@ def sample_block_noise(batch_size, channel, num_frames, height, width, gamma, pa
_, ph, pw = patch_size
block_size = ph * pw
cov = torch.eye(block_size) * (1 + gamma) - torch.ones(block_size, block_size) * gamma
cov += torch.eye(block_size) * 1e-6
dist = torch.distributions.MultivariateNormal(torch.zeros(block_size, device=cov.device), covariance_matrix=cov)
cov = (
torch.eye(block_size, device=device) * (1 + gamma) - torch.ones(block_size, block_size, device=device) * gamma
)
cov += torch.eye(block_size, device=device) * 1e-6
dist = torch.distributions.MultivariateNormal(torch.zeros(block_size, device=device), covariance_matrix=cov)
block_number = batch_size * channel * num_frames * (height // ph) * (width // pw)
noise = dist.sample((block_number,)) # [block number, block_size]