Compare commits

...

3 Commits

Author       SHA1        Message                               Date
Sayak Paul   687982e607  Merge branch 'main' into chroma-docs  2025-06-19 20:19:14 +05:30
DN6          802651e205  update                                2025-06-19 19:41:32 +05:30
DN6          907ecf72b1  update                                2025-06-19 14:20:40 +05:30
3 changed files with 62 additions and 34 deletions

View File

@@ -27,9 +27,36 @@ Chroma can use all the same optimizations as Flux.
</Tip>
## Inference (Single File)
## Inference
The `ChromaTransformer2DModel` supports loading checkpoints in the original format. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community.
The Diffusers version of Chroma is based on the [`unlocked-v37`](https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors) version of the original model, which is available in the [Chroma repository](https://huggingface.co/lodestones/Chroma).
```python
import torch
from diffusers import ChromaPipeline
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
prompt = [
"A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
]
negative_prompt = ["low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"]
image = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
generator=torch.Generator("cpu").manual_seed(433),
num_inference_steps=40,
guidance_scale=3.0,
num_images_per_prompt=1,
).images[0]
image.save("chroma.png")
```
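The Tip above notes that Chroma can use the same optimizations as Flux. As a minimal sketch (assuming the Flux-style memory-saving toggles carry over to Chroma unchanged), the most common ones look like this:
```python
import torch
from diffusers import ChromaPipeline

# Sketch only: assumes Chroma supports the same memory optimizations as Flux.
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)

# Keep submodules on CPU and move them to the GPU only when they are needed.
pipe.enable_model_cpu_offload()

# Decode latents in slices/tiles to lower peak VRAM during VAE decoding.
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
```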
## Loading from a single file
To use model checkpoints that are not in the Diffusers format, you can use the `ChromaTransformer2DModel` class to load the model from a single file in the original format. This is also useful when loading finetunes or quantized versions of the model that have been published by the community.
The following example demonstrates how to run Chroma from a single file.
@@ -38,30 +65,29 @@ Then run the following example
```python
import torch
from diffusers import ChromaTransformer2DModel, ChromaPipeline
from transformers import T5EncoderModel
bfl_repo = "black-forest-labs/FLUX.1-dev"
model_id = "lodestones/Chroma"
dtype = torch.bfloat16
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v35.safetensors", torch_dtype=dtype)
text_encoder = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
tokenizer = T5Tokenizer.from_pretrained(bfl_repo, subfolder="tokenizer_2", torch_dtype=dtype)
pipe = ChromaPipeline.from_pretrained(bfl_repo, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=dtype)
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors", torch_dtype=dtype)
pipe = ChromaPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=dtype)
pipe.enable_model_cpu_offload()
prompt = "A cat holding a sign that says hello world"
prompt = [
"A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
]
negative_prompt = ["low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"]
image = pipe(
prompt,
guidance_scale=4.0,
output_type="pil",
num_inference_steps=26,
generator=torch.Generator("cpu").manual_seed(0)
prompt=prompt,
negative_prompt=negative_prompt,
generator=torch.Generator("cpu").manual_seed(433),
num_inference_steps=40,
guidance_scale=3.0,
).images[0]
image.save("image.png")
image.save("chroma-single-file.png")
```
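The paragraph above also mentions community-published quantized checkpoints. Below is a minimal sketch of loading a GGUF-quantized transformer through `from_single_file`, assuming GGUF loading works for `ChromaTransformer2DModel` the same way it does for Flux; the `gguf_path` is a placeholder, not a published checkpoint, and the `gguf` package must be installed.
```python
import torch
from diffusers import ChromaPipeline, ChromaTransformer2DModel, GGUFQuantizationConfig

# Placeholder: substitute a community-published GGUF checkpoint of Chroma.
gguf_path = "path/to/chroma-unlocked-v37-Q8_0.gguf"

# Load the quantized transformer from the single GGUF file.
transformer = ChromaTransformer2DModel.from_single_file(
    gguf_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)

# Reuse the rest of the pipeline components from the Diffusers-format repository.
pipe = ChromaPipeline.from_pretrained(
    "lodestones/Chroma", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
```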
## ChromaPipeline
@@ -69,3 +95,9 @@ image.save("image.png")
[[autodoc]] ChromaPipeline
- all
- __call__
## ChromaImg2ImgPipeline
[[autodoc]] ChromaImg2ImgPipeline
- all
- __call__

View File

@@ -52,20 +52,21 @@ EXAMPLE_DOC_STRING = """
>>> import torch
>>> from diffusers import ChromaPipeline
>>> model_id = "lodestones/Chroma"
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
>>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
>>> text_encoder = AutoModel.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2")
>>> tokenizer = AutoTokenizer.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="tokenizer_2")
>>> pipe = ChromaImg2ImgPipeline.from_pretrained(
... "black-forest-labs/FLUX.1-schnell",
>>> pipe = ChromaPipeline.from_pretrained(
... model_id,
... transformer=transformer,
... text_encoder=text_encoder,
... tokenizer=tokenizer,
... torch_dtype=torch.bfloat16,
... )
>>> pipe.enable_model_cpu_offload()
>>> prompt = "A cat holding a sign that says hello world"
>>> negative_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
>>> prompt = [
... "A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
... ]
>>> negative_prompt = [
... "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
... ]
>>> image = pipe(prompt, negative_prompt=negative_prompt).images[0]
>>> image.save("chroma.png")
```

View File

@@ -51,26 +51,21 @@ EXAMPLE_DOC_STRING = """
```py
>>> import torch
>>> from diffusers import ChromaTransformer2DModel, ChromaImg2ImgPipeline
>>> from transformers import AutoModel, AutoTokenizer
>>> model_id = "lodestones/Chroma"
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
>>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
>>> text_encoder = AutoModel.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2")
>>> tokenizer = AutoTokenizer.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="tokenizer_2")
>>> pipe = ChromaImg2ImgPipeline.from_pretrained(
... "black-forest-labs/FLUX.1-schnell",
... model_id,
... transformer=transformer,
... text_encoder=text_encoder,
... tokenizer=tokenizer,
... torch_dtype=torch.bfloat16,
... )
>>> pipe.enable_model_cpu_offload()
>>> image = load_image(
>>> init_image = load_image(
... "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
... )
>>> prompt = "a scenic fastasy landscape with a river and mountains in the background, vibrant colors, detailed, high resolution"
>>> negative_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
>>> image = pipe(prompt, image=image, negative_prompt=negative_prompt).images[0]
>>> image = pipe(prompt, image=init_image, negative_prompt=negative_prompt).images[0]
>>> image.save("chroma-img2img.png")
```
"""