mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-24 05:14:55 +08:00
Compare commits
3 Commits
remove-unn
...
latte/chun
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e35f377a3a | ||
|
|
239e5107b3 | ||
|
|
953f4e6667 |
@@ -165,6 +165,48 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
|
|||||||
def _set_gradient_checkpointing(self, module, value=False):
|
def _set_gradient_checkpointing(self, module, value=False):
|
||||||
self.gradient_checkpointing = value
|
self.gradient_checkpointing = value
|
||||||
|
|
||||||
|
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
||||||
|
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
|
||||||
|
"""
|
||||||
|
Sets the attention processor to use [feed forward
|
||||||
|
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
chunk_size (`int`, *optional*):
|
||||||
|
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
|
||||||
|
over each tensor of dim=`dim`.
|
||||||
|
dim (`int`, *optional*, defaults to `0`):
|
||||||
|
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
|
||||||
|
or dim=1 (sequence length).
|
||||||
|
"""
|
||||||
|
if dim not in [0, 1]:
|
||||||
|
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
|
||||||
|
|
||||||
|
# By default chunk size is 1
|
||||||
|
chunk_size = chunk_size or 1
|
||||||
|
|
||||||
|
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
|
||||||
|
if hasattr(module, "set_chunk_feed_forward"):
|
||||||
|
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
|
||||||
|
|
||||||
|
for child in module.children():
|
||||||
|
fn_recursive_feed_forward(child, chunk_size, dim)
|
||||||
|
|
||||||
|
for module in self.children():
|
||||||
|
fn_recursive_feed_forward(module, chunk_size, dim)
|
||||||
|
|
||||||
|
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
|
||||||
|
def disable_forward_chunking(self):
|
||||||
|
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
|
||||||
|
if hasattr(module, "set_chunk_feed_forward"):
|
||||||
|
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
|
||||||
|
|
||||||
|
for child in module.children():
|
||||||
|
fn_recursive_feed_forward(child, chunk_size, dim)
|
||||||
|
|
||||||
|
for module in self.children():
|
||||||
|
fn_recursive_feed_forward(module, None, 0)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
|||||||
@@ -256,6 +256,26 @@ class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
|||||||
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
|
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
|
||||||
self.assertLess(max_diff, 1.0)
|
self.assertLess(max_diff, 1.0)
|
||||||
|
|
||||||
|
def test_feed_forward_chunking(self):
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
pipe = self.pipeline_class(**components)
|
||||||
|
pipe.to(device)
|
||||||
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
inputs = self.get_dummy_inputs(device)
|
||||||
|
image = pipe(**inputs)[0]
|
||||||
|
image_slice_no_chunking = image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
|
pipe.transformer.enable_forward_chunking(chunk_size=1, dim=0)
|
||||||
|
inputs = self.get_dummy_inputs(device)
|
||||||
|
image = pipe(**inputs)[0]
|
||||||
|
image_slice_chunking = image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
|
max_diff = np.abs(to_np(image_slice_no_chunking) - to_np(image_slice_chunking)).max()
|
||||||
|
self.assertLess(max_diff, 1e-4)
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
|
|||||||
Reference in New Issue
Block a user