Compare commits

8 Commits

Author       SHA1        Message                                                   Date
Sayak Paul   dacae33c1b  Merge branch 'main' into hidream-torch-compile            2025-06-06 10:36:09 +05:30
Sayak Paul   4bac6acce6  Merge branch 'main' into hidream-torch-compile            2025-05-14 20:26:44 +05:30
sayakpaul    f9662ed8ce  resolve conflicts.                                        2025-05-13 21:03:59 +05:30
Sayak Paul   612770b04b  Merge branch 'main' into hidream-torch-compile            2025-05-03 08:30:07 +05:30
Sayak Paul   6e6ccf7d7c  Merge branch 'main' into hidream-torch-compile            2025-05-01 19:32:08 +05:30
sayakpaul    6e3d988279  get hidream transformer fully torch.compile compatible.  2025-05-01 19:08:08 +05:30
sayakpaul    1d1e7157ca  fix                                                       2025-05-01 18:22:28 +05:30
sayakpaul    4caa6e819b  add tests for hidream transformer model.                  2025-05-01 18:04:04 +05:30
2 changed files with 24 additions and 1 deletion

@@ -389,7 +389,9 @@ class MOEFeedForwardSwiGLU(nn.Module):
     def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
         expert_cache = torch.zeros_like(x)
         idxs = flat_expert_indices.argsort()
-        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
+        count_freq = torch.bincount(flat_expert_indices, minlength=self.num_activated_experts)
+        tokens_per_expert = count_freq.cumsum(dim=0)
         token_idxs = idxs // self.num_activated_experts
         for i, end_idx in enumerate(tokens_per_expert):
             start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
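
Why this change matters for torch.compile: the removed line detours through .cpu().numpy(), which forces a host synchronization and leaves the tensor graph, so Dynamo hits a graph break; the replacement computes the same cumulative token counts as a tensor on the model's device. A minimal standalone sketch of the equivalence, with toy values that are not from the PR:

import torch

num_activated_experts = 4
# Toy routing output that covers every expert, so both paths yield
# vectors of the same length.
flat_expert_indices = torch.tensor([0, 2, 1, 3, 0, 1, 2, 3, 1, 0])

# Old path: leaves the tensor world via CPU/NumPy, breaking the graph.
old = flat_expert_indices.bincount().cpu().numpy().cumsum(0)

# New path: stays a tensor end to end; minlength pads the histogram so
# experts that received no tokens still get a (zero-count) slot.
count_freq = torch.bincount(flat_expert_indices, minlength=num_activated_experts)
new = count_freq.cumsum(dim=0)

assert (new.numpy() == old).all()

Note that torch.bincount still has a data-dependent output shape as far as the compiler is concerned, which is why the new test below enables capture_dynamic_output_shape_ops.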

@@ -20,6 +20,10 @@ import torch
 from diffusers import HiDreamImageTransformer2DModel
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
+    is_torch_compile,
+    require_torch_2,
+    require_torch_gpu,
+    slow,
     torch_device,
 )
@@ -94,3 +98,20 @@ class HiDreamTransformerTests(ModelTesterMixin, unittest.TestCase):
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"HiDreamImageTransformer2DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+    @require_torch_gpu
+    @require_torch_2
+    @is_torch_compile
+    @slow
+    def test_torch_compile_recompilation_and_graph_break(self):
+        torch._dynamo.reset()
+        torch._dynamo.config.capture_dynamic_output_shape_ops = True
+
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model = torch.compile(model, fullgraph=True)
+
+        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
+            _ = model(**inputs_dict)
+            _ = model(**inputs_dict)
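
What the new test verifies: fullgraph=True makes torch.compile raise on any graph break, and error_on_recompile=True turns a recompilation on the second, identical forward pass into a hard error instead of a silent slowdown. A minimal sketch of the same check on a toy module (not the HiDream transformer):

import torch


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, x):
        return torch.relu(self.proj(x))


torch._dynamo.reset()
model = torch.compile(TinyModel(), fullgraph=True)

x = torch.randn(4, 16)
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
    _ = model(x)  # first call traces and compiles
    _ = model(x)  # same shapes and dtypes: must reuse the compiled graph

The decorators keep the test opt-in: it requires a GPU and PyTorch 2.x, and, assuming the usual diffusers conventions, only runs when the slow and compile suites are enabled via environment variables.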