Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-14 08:24:32 +08:00)

Compare commits: modular-do ... adv-flux (4 commits)

- 6eca44e1a7
- 53476bfca9
- 44126bd77e
- 8997e88d85
examples/community/adversarial-tts/flux_adversarial_latents.py (new file, 244 lines)
@@ -0,0 +1,244 @@

```python
import argparse
import json
from pathlib import Path
from typing import Optional, Union

import torch
from PIL import Image
from reward_scorers import BaseRewardScorer, available_scorers, build_scorer

from diffusers import FluxPipeline
from diffusers.utils import make_image_grid


class AdversarialFluxPipeline(FluxPipeline):
    def adversarial_refinement(
        self,
        prompt: Union[str, list[str]],
        reward_model: BaseRewardScorer,
        reward_prompt: Optional[Union[str, list[str]]] = None,
        num_rounds: int = 1,
        step_size: float = 0.1,
        epsilon: Optional[float] = None,
        attack_type: str = "pgd",
        record_intermediate: bool = False,
        **generate_kwargs,
    ):
        """Sample with the base Flux pipeline, then run gradient-based refinement
        (PGD or FGSM) on the decoded images to increase the reward model's score."""
        if num_rounds < 0:
            raise ValueError("`num_rounds` must be non-negative.")
        if attack_type not in {"pgd", "fgsm"}:
            raise ValueError("`attack_type` must be either 'pgd' or 'fgsm'.")

        generate_kwargs = dict(generate_kwargs)
        height, width = self._resolve_height_width(generate_kwargs.get("height"), generate_kwargs.get("width"))
        generate_kwargs["height"] = height
        generate_kwargs["width"] = width
        generate_kwargs["output_type"] = "latent"
        generate_kwargs.setdefault("return_dict", True)

        # Run the regular Flux sampling loop but keep the packed latents instead of decoded images.
        flux_output = super().__call__(prompt=prompt, **generate_kwargs)
        latents = flux_output.images
        device = latents.device

        reward_model = reward_model.to(device)
        reward_model.eval()
        if getattr(reward_model, "supports_gradients", True) is False and (num_rounds != 0 or attack_type == "fgsm"):
            raise ValueError(
                f"Scorer `{reward_model.__class__.__name__}` does not support gradients required for adversarial refinement."
            )

        reward_prompts = self._expand_prompts(reward_prompt if reward_prompt is not None else prompt, latents.shape[0])

        with torch.no_grad():
            current_images = self._decode_packed_latents(latents, height, width).to(dtype=torch.float32)
            current_scores = reward_model(current_images, reward_prompts)

        intermediate_images = []
        if record_intermediate:
            intermediate_images.append(self.image_processor.postprocess(current_images, output_type="pil"))

        score_trace = [current_scores.mean().item()]
        per_sample_scores = [current_scores.detach().cpu().tolist()]

        if num_rounds == 0:
            max_rounds = 0
        else:
            max_rounds = 1 if attack_type == "fgsm" else num_rounds

        for round_index in range(max_rounds):
            current_images.requires_grad_(True)
            scores = reward_model(current_images, reward_prompts)
            total_score = scores.mean()

            grad = torch.autograd.grad(total_score, current_images, retain_graph=False, create_graph=False)[0]

            # Gradient ascent on the reward: FGSM takes a single signed step,
            # "pgd" takes `num_rounds` raw gradient steps clamped to the valid image range.
            if attack_type == "fgsm":
                step = epsilon if epsilon is not None else step_size
                update = step * grad.sign()
            else:
                update = step_size * grad

            with torch.no_grad():
                current_images = current_images + update
                current_images = current_images.clamp_(-1.0, 1.0)

            current_images = current_images.detach()

            with torch.no_grad():
                current_scores = reward_model(current_images, reward_prompts)

            score_trace.append(current_scores.mean().item())
            per_sample_scores.append(current_scores.detach().cpu().tolist())

            if record_intermediate:
                intermediate_images.append(self.image_processor.postprocess(current_images, output_type="pil"))

        final_images = self.image_processor.postprocess(current_images, output_type="pil")

        return {
            "images": final_images,
            "latents": latents.detach(),
            "score_trace": score_trace,
            "score_trace_per_sample": per_sample_scores,
            "final_scores": current_scores.detach().cpu().tolist(),
            "intermediate_images": intermediate_images,
        }

    def _decode_packed_latents(self, latents: torch.Tensor, height: int, width: int) -> torch.Tensor:
        unpacked = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        unpacked = (unpacked / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        decoded = self.vae.decode(unpacked, return_dict=False)[0]
        return decoded

    def _resolve_height_width(self, height: Optional[int], width: Optional[int]) -> tuple[int, int]:
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor
        return height, width

    @staticmethod
    def _expand_prompts(prompts: Union[str, list[str]], batch_size: int) -> list[str]:
        if isinstance(prompts, str):
            return [prompts] * batch_size
        if len(prompts) != batch_size:
            raise ValueError(f"Expected {batch_size} reward prompts, got {len(prompts)}.")
        return prompts
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-id", type=str, default="black-forest-labs/FLUX.1-dev")
    parser.add_argument(
        "--prompt", type=str, default="Photo of a dog sitting near a sea waiting for its companion to come."
    )
    parser.add_argument("--reward-prompt", type=str, default=None)
    parser.add_argument("--output", default="flux_adversarial.png")
    parser.add_argument("--num-rounds", type=int, default=3)
    parser.add_argument("--step-size", type=float, default=0.1)
    parser.add_argument("--epsilon", type=float, help="FGSM epsilon. Falls back to step size when omitted.")
    parser.add_argument("--attack-type", choices=["pgd", "fgsm"], default="pgd")
    parser.add_argument("--num-inference-steps", type=int, default=30)
    parser.add_argument("--guidance-scale", type=float, default=3.5)
    parser.add_argument("--height", type=int, default=1024)
    parser.add_argument("--width", type=int, default=1024)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--scorer", choices=available_scorers(), default="clip")
    parser.add_argument("--scorer-model-id", type=str, default=None)
    parser.add_argument("--record-intermediates", action="store_true")
    parser.add_argument("--intermediate-dir", type=str, default=None)
    parser.add_argument("--metadata-output", type=str, default=None)
    parser.add_argument("--output-root", type=str, required=True)
    return parser.parse_args()


def save_intermediates(intermediate_dir: Path, rounds: list[list[Image.Image]]) -> None:
    intermediate_dir.mkdir(parents=True, exist_ok=True)
    for round_index, images in enumerate(rounds):
        for sample_index, image in enumerate(images):
            filename = intermediate_dir / f"round_{round_index:02d}_sample_{sample_index:02d}.png"
            image.save(filename)

    if rounds and len(rounds[0]) == 1:
        grid = make_image_grid([imgs[0] for imgs in rounds], cols=len(rounds), rows=1)
        grid.save(intermediate_dir / "grid.png")


def dump_metadata(metadata_path: Path, payload: dict[str, object]) -> None:
    metadata_path.parent.mkdir(parents=True, exist_ok=True)
    with metadata_path.open("w", encoding="utf-8") as file:
        json.dump(payload, file)
def main() -> None:
    args = parse_args()

    generator = None
    if args.seed is not None:
        generator = torch.Generator(device=args.device)
        generator.manual_seed(args.seed)

    dtype = torch.bfloat16 if args.device.startswith("cuda") else torch.float32
    pipe = AdversarialFluxPipeline.from_pretrained(args.model_id, torch_dtype=dtype)
    pipe.to(args.device)

    reward_model = build_scorer(name=args.scorer, model_id=args.scorer_model_id, device=args.device)

    record_intermediate = args.record_intermediates or args.intermediate_dir is not None

    result = pipe.adversarial_refinement(
        prompt=args.prompt,
        reward_prompt=args.reward_prompt,
        reward_model=reward_model,
        num_rounds=args.num_rounds,
        step_size=args.step_size,
        epsilon=args.epsilon,
        attack_type=args.attack_type,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        height=args.height,
        width=args.width,
        generator=generator,
        record_intermediate=record_intermediate,
    )

    images = result["images"]
    # `--output-root` is the directory that holds the final image (`--output`),
    # the optional intermediates, and the optional metadata file.
    output_root = Path(args.output_root)
    output_root.mkdir(parents=True, exist_ok=True)
    images[0].save(output_root / args.output)

    if args.intermediate_dir and result["intermediate_images"]:
        save_intermediates(output_root / args.intermediate_dir, result["intermediate_images"])

    if args.metadata_output:
        metadata_payload = {
            "prompt": args.prompt,
            "reward_prompt": args.reward_prompt or "",
            "scorer": {
                "name": args.scorer,
                "model_id": getattr(reward_model, "model_id", args.scorer_model_id or ""),
            },
            "attack": {
                "type": args.attack_type,
                "num_rounds": args.num_rounds,
                "step_size": args.step_size,
                "epsilon": (
                    args.epsilon
                    if args.epsilon is not None
                    else (args.step_size if args.attack_type == "fgsm" else None)
                ),
                "rounds_executed": len(result["score_trace"]) - 1,
            },
            "score_trace": result["score_trace"],
            "score_trace_per_sample": result["score_trace_per_sample"],
            "final_scores": result["final_scores"],
        }
        dump_metadata(output_root / args.metadata_output, metadata_payload)

    print("Mean score trace:", result["score_trace"])
    print("Final per-sample scores:", result["final_scores"])


if __name__ == "__main__":
    main()
```
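For readers who want to drive the pipeline from Python rather than through `main()`, the sketch below mirrors the same flow with the script's default values. It is a minimal sketch, assuming the FLUX.1-dev weights can be fetched from the Hub, a CUDA device with enough memory is available, and `flux_adversarial_latents.py` / `reward_scorers.py` are importable from the working directory; the output filename is just the script's default, not a required path.

```python
import torch

from flux_adversarial_latents import AdversarialFluxPipeline
from reward_scorers import build_scorer

# Flux pipeline with the adversarial-refinement helper mixed in.
pipe = AdversarialFluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Differentiable CLIP text-image similarity acts as the reward.
scorer = build_scorer(name="clip", device="cuda")

result = pipe.adversarial_refinement(
    prompt="Photo of a dog sitting near a sea waiting for its companion to come.",
    reward_model=scorer,
    num_rounds=3,  # "pgd": three gradient-ascent steps on the decoded image
    step_size=0.1,
    attack_type="pgd",
    num_inference_steps=30,
    guidance_scale=3.5,
    generator=torch.Generator("cuda").manual_seed(0),
)

result["images"][0].save("flux_adversarial.png")
print(result["score_trace"])  # mean reward before and after each round
```

With `attack_type="fgsm"` only a single signed-gradient step is taken, sized by `epsilon` (falling back to `step_size` when `epsilon` is omitted).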
examples/community/adversarial-tts/reward_scorers.py (new file, 196 lines)
@@ -0,0 +1,196 @@

```python
import warnings
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPModel, CLIPTokenizer, SiglipModel, SiglipProcessor


class BaseRewardScorer(nn.Module):
    """
    Base interface for reward scorers.

    Subclasses are expected to implement a differentiable `forward` method that
    accepts a batch of images in the `[-1, 1]` range and a batch of prompt strings
    with the same batch dimension.
    """

    name: str = "base"
    default_model_id: Optional[str] = None
    supports_gradients: bool = True

    def __init__(self, model_id: Optional[str] = None, device: Optional[torch.device] = None):
        super().__init__()
        self.model_id = model_id or self.default_model_id
        if self.model_id is None:
            raise ValueError(f"{self.__class__.__name__} requires `model_id` to be specified.")
        self._requested_device = torch.device(device) if device is not None else None

    @property
    def device(self) -> torch.device:
        parameters = list(self.parameters())
        if parameters:
            return parameters[0].device
        return self._requested_device or torch.device("cpu")

    def ensure_device(self) -> None:
        if self._requested_device is not None:
            self.to(self._requested_device)

    def forward(self, images: torch.Tensor, prompts: Sequence[str]) -> torch.Tensor:
        raise NotImplementedError
class ClipScorer(BaseRewardScorer):
    name = "clip"
    default_model_id = "openai/clip-vit-large-patch14"

    def __init__(self, model_id: Optional[str] = None, device: Optional[torch.device] = None):
        super().__init__(model_id=model_id, device=device)

        self.model = CLIPModel.from_pretrained(self.model_id)
        self.tokenizer = CLIPTokenizer.from_pretrained(self.model_id)
        self.image_processor = CLIPImageProcessor.from_pretrained(self.model_id)
        if self._requested_device is not None:
            self.model = self.model.to(self._requested_device)
        self.model = self.model.to(dtype=torch.float32)
        self.model.eval()
        for parameter in self.model.parameters():
            parameter.requires_grad_(False)
        self.eval()

    def forward(self, images: torch.Tensor, prompts: Sequence[str]) -> torch.Tensor:
        device = self.model.device
        pixel_values = self._preprocess_images(images).to(device=device, dtype=torch.float32)
        text_inputs = self.tokenizer(list(prompts), padding=True, truncation=True, return_tensors="pt").to(device)

        image_embeds = self.model.get_image_features(pixel_values=pixel_values)
        text_embeds = self.model.get_text_features(**text_inputs)

        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
        return (image_embeds * text_embeds).sum(dim=-1)

    def _preprocess_images(self, images: torch.Tensor) -> torch.Tensor:
        pixel_values = (images + 1) / 2
        pixel_values = torch.clamp(pixel_values, 0, 1)

        crop_size = self.image_processor.crop_size
        if isinstance(crop_size, dict):
            target_height = crop_size["height"]
            target_width = crop_size["width"]
        else:
            target_height = target_width = crop_size

        pixel_values = F.interpolate(
            pixel_values, size=(target_height, target_width), mode="bilinear", align_corners=False
        )

        mean = torch.tensor(
            self.image_processor.image_mean, device=pixel_values.device, dtype=pixel_values.dtype
        ).view(1, -1, 1, 1)
        std = torch.tensor(self.image_processor.image_std, device=pixel_values.device, dtype=pixel_values.dtype).view(
            1, -1, 1, 1
        )
        return (pixel_values - mean) / std
class SiglipScorer(BaseRewardScorer):
    name = "siglip"
    default_model_id = "google/siglip-so400m-patch14-384"

    def __init__(self, model_id: Optional[str] = None, device: Optional[torch.device] = None):
        super().__init__(model_id=model_id, device=device)

        self.processor = SiglipProcessor.from_pretrained(self.model_id)
        self.image_processor = self.processor.image_processor
        self.text_tokenizer = self.processor.tokenizer
        self.model = SiglipModel.from_pretrained(self.model_id)
        if self._requested_device is not None:
            self.model = self.model.to(self._requested_device)
        self.model = self.model.to(dtype=torch.float32)
        self.model.eval()
        for parameter in self.model.parameters():
            parameter.requires_grad_(False)
        self.eval()

    def forward(self, images: torch.Tensor, prompts: Sequence[str]) -> torch.Tensor:  # type: ignore[override]
        device = self.model.device
        pixel_values = self._preprocess_images(images).to(device=device, dtype=torch.float32)
        text_inputs = self.text_tokenizer(list(prompts), padding=True, truncation=True, return_tensors="pt").to(device)

        image_embeds = self.model.get_image_features(pixel_values=pixel_values)
        text_embeds = self.model.get_text_features(**text_inputs)

        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

        return (image_embeds * text_embeds).sum(dim=-1)

    def _preprocess_images(self, images: torch.Tensor) -> torch.Tensor:
        pixel_values = (images + 1) / 2
        pixel_values = torch.clamp(pixel_values, 0, 1)

        size = self.image_processor.size
        if isinstance(size, dict):
            target_height = size.get("shortest_edge") or size.get("height") or size.get("width")
            target_width = size.get("width") or target_height
            target_height = target_height or target_width
        else:
            target_height = target_width = size

        pixel_values = F.interpolate(
            pixel_values, size=(target_height, target_width), mode="bilinear", align_corners=False
        )

        mean = torch.tensor(
            self.image_processor.image_mean, device=pixel_values.device, dtype=pixel_values.dtype
        ).view(1, -1, 1, 1)
        std = torch.tensor(self.image_processor.image_std, device=pixel_values.device, dtype=pixel_values.dtype).view(
            1, -1, 1, 1
        )
        return (pixel_values - mean) / std
class PlaceholderScorer(BaseRewardScorer):
    """
    Helper scorer that surfaces a friendly error for scorers that require external dependencies.
    """

    name = "placeholder"
    supports_gradients = False

    def __init__(self, *args: Any, required_package: str, scorer_name: str, **kwargs: Any):
        self.required_package = required_package
        self.scorer_name = scorer_name
        raise ImportError(f"{scorer_name} requires the external package `{required_package}` which is not installed.")


SCORER_REGISTRY: Dict[str, Type[BaseRewardScorer]] = {ClipScorer.name: ClipScorer, SiglipScorer.name: SiglipScorer}


def available_scorers() -> Tuple[str, ...]:
    return tuple(sorted(SCORER_REGISTRY.keys()))


def build_scorer(
    name: str,
    model_id: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
    **kwargs: Any,
) -> BaseRewardScorer:
    if name not in SCORER_REGISTRY:
        raise ValueError(f"Unknown scorer `{name}`. Available scorers: {', '.join(available_scorers())}.")
    scorer_cls = SCORER_REGISTRY[name]
    device_obj = torch.device(device) if device is not None else None

    scorer = scorer_cls(model_id=model_id, device=device_obj, **kwargs)

    if not scorer.supports_gradients:
        warnings.warn(
            f"Scorer `{name}` does not declare gradient support. Adversarial refinement may not work as expected.",
            UserWarning,
        )

    return scorer
```
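The registry is the only extension point the scripts rely on: `--scorer` choices come from `available_scorers()`, and `build_scorer` instantiates whatever class is registered under that name. As a rough sketch of what a third scorer could look like, a subclass only needs a differentiable `forward` and an entry in `SCORER_REGISTRY`. The `BrightnessScorer` below is purely illustrative (not part of the example), and `default_model_id = "none"` is just a placeholder to satisfy the base class's non-None `model_id` check.

```python
from typing import Sequence

import torch

from reward_scorers import SCORER_REGISTRY, BaseRewardScorer, available_scorers, build_scorer


class BrightnessScorer(BaseRewardScorer):
    """Toy reward: mean brightness of each image, mapped to [0, 1]. Illustrative only."""

    name = "brightness"
    default_model_id = "none"  # placeholder; the base class requires a non-None model_id

    def forward(self, images: torch.Tensor, prompts: Sequence[str]) -> torch.Tensor:
        # Images arrive in [-1, 1]; rescale to [0, 1] and average over C, H, W per sample.
        return ((images + 1) / 2).mean(dim=(1, 2, 3))


# Registering the class makes it visible to `available_scorers()` and hence to `--scorer`.
SCORER_REGISTRY[BrightnessScorer.name] = BrightnessScorer

print(available_scorers())  # ('brightness', 'clip', 'siglip')
scorer = build_scorer(name="brightness", device="cpu")
scores = scorer(torch.zeros(2, 3, 64, 64), ["unused", "unused"])
print(scores)  # tensor([0.5000, 0.5000])
```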