This commit is contained in:
sayakpaul
2025-10-29 17:10:10 +05:30
parent 316b71ff2b
commit ecdd843044

View File

@@ -1,7 +1,18 @@
from typing import List
import torch
from diffusers import FluxTransformer2DModel
from diffusers.modular_pipelines import ComponentSpec, InputParam, ModularPipelineBlocks, OutputParam, PipelineState
from diffusers.modular_pipelines import (
ComponentSpec,
InputParam,
ModularPipelineBlocks,
OutputParam,
PipelineState,
WanModularPipeline,
)
from ..testing_utils import nightly, require_torch, slow
class DummyCustomBlockSimple(ModularPipelineBlocks):
@@ -81,10 +92,7 @@ class TestModularCustomBlocks:
def test_custom_block_loads_from_hub(self):
repo_id = "hf-internal-testing/tiny-modular-diffusers-block"
block = ModularPipelineBlocks.from_pretrained(
repo_id,
trust_remote_code=True,
)
block = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True)
self._test_block_properties(block)
pipe = block.init_pipeline()
@@ -93,3 +101,19 @@ class TestModularCustomBlocks:
output = pipe(prompt=prompt)
output_prompt = output.values["output_prompt"]
assert output_prompt.startswith("Modular diffusers + ")
@slow
@nightly
@require_torch
class TestModularCustomBlocksIntegration:
def test_krea_realtime_video_loading(self):
repo_id = "krea/krea-realtime-video"
blocks = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True)
pipe = WanModularPipeline(blocks, repo_id)
pipe.load_components(
trust_remote_code=True,
device_map="cuda",
torch_dtype={"default": torch.bfloat16, "vae": torch.float16},
)