mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-02 09:51:06 +08:00
* Add server example. * Minor updates to README. * Add fixes after local testing. * Apply suggestions from code review Updates to README from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * More doc updates. * Maybe this will work to build the docs correctly? * Fix style issues. * Fix toc. * Minor reformatting. * Move docs to proper loc. * Fix missing tick. * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Sync docs changes back to README. * Very minor update to docs to add space. --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
134 lines
4.1 KiB
Python
134 lines
4.1 KiB
Python
import asyncio
|
|
import logging
|
|
import os
|
|
import random
|
|
import tempfile
|
|
import traceback
|
|
import uuid
|
|
|
|
import aiohttp
|
|
import torch
|
|
from fastapi import FastAPI, HTTPException
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.staticfiles import StaticFiles
|
|
from pydantic import BaseModel
|
|
|
|
from diffusers.pipelines.stable_diffusion_3 import StableDiffusion3Pipeline
|
|
|
|
|
|
# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TextToImageInput(BaseModel):
    """Request body for the OpenAI-style ``/v1/images/generations`` endpoint."""

    # Model identifier sent by the client. Accepted for API compatibility but
    # not used to switch pipelines -- this server loads one shared pipeline.
    model: str
    # Text prompt to render; the only field generate_image actually reads.
    prompt: str
    # Optional size string -- accepted but never read downstream in this file.
    size: str | None = None
    # Optional image count -- accepted but never read downstream in this file.
    n: int | None = None
|
|
|
|
|
|
class HttpClient:
    """Holder for a single process-wide :class:`aiohttp.ClientSession`.

    ``start`` is invoked from the FastAPI startup hook; ``stop`` closes the
    session on shutdown. Calling the instance returns the live session.
    """

    # Shared session; ``None`` until ``start`` has run (annotation fixed to
    # reflect the ``None`` default).
    session: aiohttp.ClientSession | None = None

    def start(self):
        """Create the shared client session."""
        self.session = aiohttp.ClientSession()

    async def stop(self):
        """Close and discard the shared session (safe to call when not started)."""
        # Guard against a stop-before-start (or double stop), which previously
        # raised AttributeError on ``None.close()``.
        if self.session is not None:
            await self.session.close()
            self.session = None

    def __call__(self) -> aiohttp.ClientSession:
        """Return the live session; fails loudly if ``start`` was never called."""
        assert self.session is not None
        return self.session
|
|
|
|
|
|
class TextToImagePipeline:
    """Holder for the process-wide Stable Diffusion 3 pipeline.

    ``start`` picks the best available accelerator, loads the checkpoint named
    by the ``MODEL_PATH`` env var (with a per-device default), and stores the
    ready pipeline on the instance.
    """

    # Loaded pipeline; ``None`` until ``start`` has run.
    pipeline: StableDiffusion3Pipeline | None = None
    # Torch device string ("cuda" or "mps"); ``None`` until ``start`` has run.
    device: str | None = None

    def start(self):
        """Select a device and load the model.

        Raises:
            Exception: when neither CUDA nor MPS is available.
        """
        if torch.cuda.is_available():
            logger.info("Loading CUDA")
            self.device = "cuda"
            # The large checkpoint is the default on CUDA hardware.
            default_model = "stabilityai/stable-diffusion-3.5-large"
        elif torch.backends.mps.is_available():
            logger.info("Loading MPS for Mac M Series")
            self.device = "mps"
            # The medium checkpoint fits Apple-silicon memory better.
            default_model = "stabilityai/stable-diffusion-3.5-medium"
        else:
            raise Exception("No CUDA or MPS device available")
        # Both branches previously duplicated the whole load sequence; the
        # only real difference was the default model name, hoisted above.
        model_path = os.getenv("MODEL_PATH", default_model)
        self.pipeline = StableDiffusion3Pipeline.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
        ).to(device=self.device)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Application wiring (runs at import time).
# ---------------------------------------------------------------------------
app = FastAPI()

# Public base URL used to build the image links returned to clients.
service_url = os.getenv("SERVICE_URL", "http://localhost:8000")

# Generated images live under the system temp dir and are served statically.
image_dir = os.path.join(tempfile.gettempdir(), "images")
# exist_ok avoids the check-then-create race of the old exists()/makedirs pair.
os.makedirs(image_dir, exist_ok=True)
app.mount("/images", StaticFiles(directory=image_dir), name="images")

http_client = HttpClient()
shared_pipeline = TextToImagePipeline()

# Configure CORS settings
# NOTE(review): wildcard origins combined with allow_credentials=True is
# rejected by browsers per the CORS spec -- confirm the intended audience.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods, e.g., GET, POST, OPTIONS, etc.
    allow_headers=["*"],  # Allows all headers
)
|
|
|
|
|
|
@app.on_event("startup")
def startup():
    """Initialize the shared singletons once the server process boots.

    NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in
    favor of lifespan handlers -- confirm the target FastAPI version before
    migrating.
    """
    http_client.start()
    # Loads the diffusion model onto the accelerator; may take a while.
    shared_pipeline.start()
|
|
|
|
|
|
def save_image(image):
    """Persist an image under ``image_dir`` and return its public URL.

    Args:
        image: object exposing ``save(path)`` (e.g. a ``PIL.Image.Image``).

    Returns:
        str: absolute URL where the saved PNG can be fetched.
    """
    # Short random name: the first group of a uuid4 is enough entropy here.
    filename = "draw" + str(uuid.uuid4()).split("-")[0] + ".png"
    image_path = os.path.join(image_dir, filename)
    # write image to disk at image_path
    logger.info(f"Saving image to {image_path}")
    image.save(image_path)
    # Bug fix: os.path.join is for filesystem paths and would emit
    # backslashes on Windows; build the URL explicitly instead. On POSIX the
    # result is identical to the old os.path.join(service_url, ...) form.
    return f"{service_url}/images/{filename}"
|
|
|
|
|
|
@app.get("/")
|
|
@app.post("/")
|
|
@app.options("/")
|
|
async def base():
|
|
return "Welcome to Diffusers! Where you can use diffusion models to generate images"
|
|
|
|
|
|
@app.post("/v1/images/generations")
|
|
async def generate_image(image_input: TextToImageInput):
|
|
try:
|
|
loop = asyncio.get_event_loop()
|
|
scheduler = shared_pipeline.pipeline.scheduler.from_config(shared_pipeline.pipeline.scheduler.config)
|
|
pipeline = StableDiffusion3Pipeline.from_pipe(shared_pipeline.pipeline, scheduler=scheduler)
|
|
generator = torch.Generator(device=shared_pipeline.device)
|
|
generator.manual_seed(random.randint(0, 10000000))
|
|
output = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator=generator))
|
|
logger.info(f"output: {output}")
|
|
image_url = save_image(output.images[0])
|
|
return {"data": [{"url": image_url}]}
|
|
except Exception as e:
|
|
if isinstance(e, HTTPException):
|
|
raise e
|
|
elif hasattr(e, "message"):
|
|
raise HTTPException(status_code=500, detail=e.message + traceback.format_exc())
|
|
raise HTTPException(status_code=500, detail=str(e) + traceback.format_exc())
|
|
|
|
|
|
if __name__ == "__main__":
    # Imported lazily so the module can be served by an external ASGI runner
    # without requiring uvicorn at import time.
    import uvicorn

    # Bind on all interfaces so containers / LAN clients can reach the server.
    uvicorn.run(app, host="0.0.0.0", port=8000)
|