Mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-11 15:04:45 +08:00.
Compare commits: custom-cod...hidream-to (8 commits)

| SHA1 |
|---|
| dacae33c1b |
| 4bac6acce6 |
| f9662ed8ce |
| 612770b04b |
| 6e6ccf7d7c |
| 6e3d988279 |
| 1d1e7157ca |
| 4caa6e819b |
```diff
@@ -389,7 +389,9 @@ class MOEFeedForwardSwiGLU(nn.Module):
     def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
         expert_cache = torch.zeros_like(x)
         idxs = flat_expert_indices.argsort()
-        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
+        count_freq = torch.bincount(flat_expert_indices, minlength=self.num_activated_experts)
+        tokens_per_expert = count_freq.cumsum(dim=0)
+
         token_idxs = idxs // self.num_activated_experts
         for i, end_idx in enumerate(tokens_per_expert):
             start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
```
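The hunk above replaces a host-side NumPy round trip with pure torch ops: `flat_expert_indices.bincount().cpu().numpy()` forces a device-to-host sync and hands back a NumPy array, which `torch.compile` cannot trace through, while `torch.bincount(..., minlength=self.num_activated_experts)` keeps the counts on device as a tensor, with `minlength` guaranteeing one slot per expert even when an expert receives no tokens. A minimal standalone sketch of the equivalence; `num_experts` and the sample indices are illustrative, standing in for `self.num_activated_experts` and real router output:

```python
import torch

num_experts = 4  # stands in for self.num_activated_experts
flat_expert_indices = torch.tensor([2, 0, 1, 2, 0, 3, 1, 2])

# Old path: leaves torch for NumPy, syncing the device and breaking compile tracing.
old = flat_expert_indices.bincount().cpu().numpy().cumsum(0)

# New path: stays a torch tensor on the original device, traceable by torch.compile.
count_freq = torch.bincount(flat_expert_indices, minlength=num_experts)
new = count_freq.cumsum(dim=0)

# Both yield the running token count per expert id: [2, 4, 7, 8] here.
assert (new.cpu().numpy() == old).all()
```

The per-expert loop that follows is unchanged: `start_idx`/`end_idx` become 0-dim tensors instead of NumPy scalars, and PyTorch accepts either for slicing the sorted token indices.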
```diff
@@ -20,6 +20,10 @@ import torch
 from diffusers import HiDreamImageTransformer2DModel
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
+    is_torch_compile,
+    require_torch_2,
+    require_torch_gpu,
+    slow,
     torch_device,
 )
 
```
```diff
@@ -94,3 +98,20 @@ class HiDreamTransformerTests(ModelTesterMixin, unittest.TestCase):
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"HiDreamImageTransformer2DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+    @require_torch_gpu
+    @require_torch_2
+    @is_torch_compile
+    @slow
+    def test_torch_compile_recompilation_and_graph_break(self):
+        torch._dynamo.reset()
+        torch._dynamo.config.capture_dynamic_output_shape_ops = True
+
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict).to(torch_device)
+        model = torch.compile(model, fullgraph=True)
+
+        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
+            _ = model(**inputs_dict)
+            _ = model(**inputs_dict)
```
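The new test, gated by the diffusers test markers imported in the previous hunk (`require_torch_gpu`, `require_torch_2`, `is_torch_compile`, `slow`), follows a general pattern for asserting compile stability: compile with `fullgraph=True` so any graph break raises, then run identical inputs twice under `torch._dynamo.config.patch(error_on_recompile=True)` so a second trace fails loudly instead of silently recompiling. `capture_dynamic_output_shape_ops = True` is needed because the MoE dispatch (`argsort`, `bincount`, index-based gathering) produces data-dependent shapes that dynamo refuses to capture by default. A minimal sketch of the same pattern on a toy module; `TinyModel` is hypothetical, not part of the diffusers test suite:

```python
import torch


class TinyModel(torch.nn.Module):
    """Hypothetical stand-in for HiDreamImageTransformer2DModel."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)


torch._dynamo.reset()
model = torch.compile(TinyModel(), fullgraph=True)  # any graph break becomes a hard error

x = torch.randn(2, 8)
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
    _ = model(x)  # first call triggers compilation
    _ = model(x)  # second call must reuse the cached graph, or raise
```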