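"""Asynchronous Diffusers inference server example (examples/server-async).

A FastAPI application that serves text-to-image requests against a single loaded
pipeline (default: stabilityai/stable-diffusion-3.5-medium), using the local
RequestScopedPipeline and Utils helpers that ship with this example.
"""
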
import asyncio
import gc
import logging
import os
import random
import threading
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, Dict, Optional, Type

import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from Pipelines import ModelPipelineInitializer
from pydantic import BaseModel

from utils import RequestScopedPipeline, Utils


@dataclass
class ServerConfigModels:
    model: str = "stabilityai/stable-diffusion-3.5-medium"
    type_models: str = "t2im"
    constructor_pipeline: Optional[Type] = None
    custom_pipeline: Optional[Type] = None
    components: Optional[Dict[str, Any]] = None
    torch_dtype: Optional[torch.dtype] = None
    host: str = "0.0.0.0"
    port: int = 8500


server_config = ServerConfigModels()

@asynccontextmanager
async def lifespan(app: FastAPI):
    logging.basicConfig(level=logging.INFO)
    app.state.logger = logging.getLogger("diffusers-server")
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
    os.environ["CUDA_LAUNCH_BLOCKING"] = "0"

    app.state.total_requests = 0
    app.state.active_inferences = 0
    app.state.metrics_lock = asyncio.Lock()
    app.state.metrics_task = None

    app.state.utils_app = Utils(
        host=server_config.host,
        port=server_config.port,
    )

    async def metrics_loop():
        try:
            while True:
                async with app.state.metrics_lock:
                    total = app.state.total_requests
                    active = app.state.active_inferences
                app.state.logger.info(f"[METRICS] total_requests={total} active_inferences={active}")
                await asyncio.sleep(5)
        except asyncio.CancelledError:
            app.state.logger.info("Metrics loop cancelled")
            raise

    app.state.metrics_task = asyncio.create_task(metrics_loop())

    try:
        yield
    finally:
        task = app.state.metrics_task
        if task:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        try:
            stop_fn = getattr(model_pipeline, "stop", None) or getattr(model_pipeline, "close", None)
            if callable(stop_fn):
                await run_in_threadpool(stop_fn)
        except Exception as e:
            app.state.logger.warning(f"Error during pipeline shutdown: {e}")

        app.state.logger.info("Lifespan shutdown complete")

app = FastAPI(lifespan=lifespan)

logger = logging.getLogger("DiffusersServer.Pipelines")


initializer = ModelPipelineInitializer(
    model=server_config.model,
    type_models=server_config.type_models,
)
model_pipeline = initializer.initialize_pipeline()
model_pipeline.start()

request_pipe = RequestScopedPipeline(model_pipeline.pipeline)
pipeline_lock = threading.Lock()

logger.info(f"Pipeline initialized and ready to receive requests (model={server_config.model})")

app.state.MODEL_INITIALIZER = initializer
app.state.MODEL_PIPELINE = model_pipeline
app.state.REQUEST_PIPE = request_pipe
app.state.PIPELINE_LOCK = pipeline_lock

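# Note: RequestScopedPipeline and Utils come from this example's local utils module
# (examples/server-async/utils). Per the example's history, the wrapper is meant to let
# each request generate against the shared pipeline with its own per-request state,
# serializing tokenizer access with a lock to avoid race conditions; see that module
# for the actual behavior.
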
class JSONBodyQueryAPI(BaseModel):
    model: str | None = None
    prompt: str
    negative_prompt: str | None = None
    num_inference_steps: int = 28
    num_images_per_prompt: int = 1


@app.middleware("http")
async def count_requests_middleware(request: Request, call_next):
    async with app.state.metrics_lock:
        app.state.total_requests += 1
    response = await call_next(request)
    return response


@app.get("/")
async def root():
    return {"message": "Welcome to the Diffusers Server"}

@app.post("/api/diffusers/inference")
async def api(json: JSONBodyQueryAPI):
    prompt = json.prompt
    negative_prompt = json.negative_prompt or ""
    num_steps = json.num_inference_steps
    num_images_per_prompt = json.num_images_per_prompt

    wrapper = app.state.MODEL_PIPELINE
    initializer = app.state.MODEL_INITIALIZER

    utils_app = app.state.utils_app

    if not wrapper or not wrapper.pipeline:
        raise HTTPException(500, "Model not initialized correctly")
    if not prompt.strip():
        raise HTTPException(400, "No prompt provided")

    def make_generator():
        g = torch.Generator(device=initializer.device)
        return g.manual_seed(random.randint(0, 10_000_000))

    req_pipe = app.state.REQUEST_PIPE

    def infer():
        gen = make_generator()
        return req_pipe.generate(
            prompt=prompt,
            negative_prompt=negative_prompt,
            generator=gen,
            num_inference_steps=num_steps,
            num_images_per_prompt=num_images_per_prompt,
            device=initializer.device,
            output_type="pil",
        )

    try:
        async with app.state.metrics_lock:
            app.state.active_inferences += 1

        output = await run_in_threadpool(infer)

        async with app.state.metrics_lock:
            app.state.active_inferences = max(0, app.state.active_inferences - 1)

        urls = [utils_app.save_image(img) for img in output.images]
        return {"response": urls}

    except Exception as e:
        async with app.state.metrics_lock:
            app.state.active_inferences = max(0, app.state.active_inferences - 1)
        logger.error(f"Error during inference: {e}")
        raise HTTPException(500, f"Error in processing: {e}")

    finally:
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.ipc_collect()
        gc.collect()

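# Example request (illustrative, assuming the default host/port from ServerConfigModels):
#
#   curl -X POST http://localhost:8500/api/diffusers/inference \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "a photo of an astronaut riding a horse", "num_inference_steps": 28}'
#
# The endpoint responds with {"response": [...]}, one entry per generated image as
# returned by utils_app.save_image; the images themselves are served by the
# /images/{filename} route below.
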
@app.get("/images/{filename}")
async def serve_image(filename: str):
    utils_app = app.state.utils_app
    file_path = os.path.join(utils_app.image_dir, filename)
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="Image not found")
    return FileResponse(file_path, media_type="image/png")


@app.get("/api/status")
async def get_status():
    memory_info = {}
    if torch.cuda.is_available():
        memory_allocated = torch.cuda.memory_allocated() / 1024**3  # GB
        memory_reserved = torch.cuda.memory_reserved() / 1024**3  # GB
        memory_info = {
            "memory_allocated_gb": round(memory_allocated, 2),
            "memory_reserved_gb": round(memory_reserved, 2),
            "device": torch.cuda.get_device_name(0),
        }

    return {"current_model": server_config.model, "type_models": server_config.type_models, "memory": memory_info}

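# Illustrative /api/status response on a CUDA machine (values are made up):
#   {"current_model": "stabilityai/stable-diffusion-3.5-medium", "type_models": "t2im",
#    "memory": {"memory_allocated_gb": 5.12, "memory_reserved_gb": 6.0, "device": "NVIDIA A100"}}
# On CPU-only machines "memory" is an empty dict.
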
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

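# To start the server: `python serverasync.py` (uses the host/port from server_config).
# Assuming the file keeps its serverasync.py name, it can also be launched with the
# uvicorn CLI, e.g. `uvicorn serverasync:app --host 0.0.0.0 --port 8500`.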
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host=server_config.host, port=server_config.port)