Compare commits

...

3 Commits

Author SHA1 Message Date
Sayak Paul
e35f377a3a Merge branch 'main' into latte/chunked-ffn 2024-07-12 09:53:32 +02:00
Aryan
239e5107b3 Merge branch 'main' into latte/chunked-ffn 2024-07-11 21:52:16 +05:30
Aryan
953f4e6667 add feed forward chunking to latte 2024-07-11 18:18:48 +02:00
2 changed files with 62 additions and 0 deletions

View File

@@ -165,6 +165,48 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
Sets the attention processor to use [feed forward
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
Parameters:
chunk_size (`int`, *optional*):
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
over each tensor of dim=`dim`.
dim (`int`, *optional*, defaults to `0`):
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
or dim=1 (sequence length).
"""
if dim not in [0, 1]:
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
# By default chunk size is 1
chunk_size = chunk_size or 1
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)
for module in self.children():
fn_recursive_feed_forward(module, chunk_size, dim)
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
def disable_forward_chunking(self):
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)
for module in self.children():
fn_recursive_feed_forward(module, None, 0)
def forward(
self,
hidden_states: torch.Tensor,

View File

@@ -256,6 +256,26 @@ class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(max_diff, 1.0)
def test_feed_forward_chunking(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs)[0]
image_slice_no_chunking = image[0, -3:, -3:, -1]
pipe.transformer.enable_forward_chunking(chunk_size=1, dim=0)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs)[0]
image_slice_chunking = image[0, -3:, -3:, -1]
max_diff = np.abs(to_np(image_slice_no_chunking) - to_np(image_slice_chunking)).max()
self.assertLess(max_diff, 1e-4)
@slow
@require_torch_gpu