[examples] add controlnet sd3 example (#9249)

* add controlnet sd3 example

* add controlnet sd3 example

* update controlnet sd3 example

* add controlnet sd3 example test

* fix quality and style

* update test

* update test

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
Yu Zheng
2024-09-11 09:34:37 +08:00
committed by GitHub
parent adf1f911f0
commit c002731d93
4 changed files with 1596 additions and 0 deletions

View File

@@ -115,3 +115,24 @@ class ControlNetSDXL(ExamplesTestsAccelerate):
run_command(self._launch_args + test_args)
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
class ControlNetSD3(ExamplesTestsAccelerate):
def test_controlnet_sd3(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/controlnet/train_controlnet_sd3.py
--pretrained_model_name_or_path=DavyMorgan/tiny-sd3-pipe
--dataset_name=hf-internal-testing/fill10
--output_dir={tmpdir}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--controlnet_model_name_or_path=DavyMorgan/tiny-controlnet-sd3
--max_train_steps=4
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))