Compare commits


4 Commits

Author     SHA1        Message    Date
sayakpaul  6eca44e1a7  up         2025-10-10 13:49:23 +05:30
sayakpaul  53476bfca9  up         2025-10-10 11:29:19 +05:30
sayakpaul  44126bd77e  up         2025-10-09 10:23:28 +05:30
sayakpaul  8997e88d85  adv flux.  2025-10-09 08:42:34 +05:30
2 changed files with 440 additions and 0 deletions


@@ -0,0 +1,244 @@
import argparse
import json
from pathlib import Path
from typing import Optional, Union
import torch
from PIL import Image
from reward_scorers import BaseRewardScorer, available_scorers, build_scorer
from diffusers import FluxPipeline
from diffusers.utils import make_image_grid
class AdversarialFluxPipeline(FluxPipeline):
def adversarial_refinement(
self,
prompt: Union[str, list[str]],
reward_model: BaseRewardScorer,
reward_prompt: Optional[Union[str, list[str]]] = None,
num_rounds: int = 1,
step_size: float = 0.1,
epsilon: Optional[float] = None,
attack_type: str = "pgd",
record_intermediate: bool = False,
**generate_kwargs,
):
if num_rounds < 0:
raise ValueError("`num_rounds` must be non-negative.")
if attack_type not in {"pgd", "fgsm"}:
raise ValueError("`attack_type` must be either 'pgd' or 'fgsm'.")
generate_kwargs = dict(generate_kwargs)
height, width = self._resolve_height_width(generate_kwargs.get("height"), generate_kwargs.get("width"))
generate_kwargs["height"] = height
generate_kwargs["width"] = width
generate_kwargs["output_type"] = "latent"
generate_kwargs.setdefault("return_dict", True)
flux_output = super().__call__(prompt=prompt, **generate_kwargs)
latents = flux_output.images
device = latents.device
reward_model = reward_model.to(device)
reward_model.eval()
if getattr(reward_model, "supports_gradients", True) is False and num_rounds > 0:
raise ValueError(
f"Scorer `{reward_model.__class__.__name__}` does not support the gradients required for adversarial refinement."
)
reward_prompts = self._expand_prompts(reward_prompt if reward_prompt is not None else prompt, latents.shape[0])
with torch.no_grad():
current_images = self._decode_packed_latents(latents, height, width).to(dtype=torch.float32)
current_scores = reward_model(current_images, reward_prompts)
intermediate_images = []
if record_intermediate:
intermediate_images.append(self.image_processor.postprocess(current_images, output_type="pil"))
score_trace = [current_scores.mean().item()]
per_sample_scores = [current_scores.detach().cpu().tolist()]
if num_rounds == 0:
max_rounds = 0
else:
max_rounds = 1 if attack_type == "fgsm" else num_rounds
for round_index in range(max_rounds):
current_images.requires_grad_(True)
scores = reward_model(current_images, reward_prompts)
total_score = scores.mean()
grad = torch.autograd.grad(total_score, current_images, retain_graph=False, create_graph=False)[0]
if attack_type == "fgsm":
step = epsilon if epsilon is not None else step_size
update = step * grad.sign()
else:
update = step_size * grad
with torch.no_grad():
current_images = current_images + update
current_images = current_images.clamp_(-1.0, 1.0)
current_images = current_images.detach()
with torch.no_grad():
current_scores = reward_model(current_images, reward_prompts)
score_trace.append(current_scores.mean().item())
per_sample_scores.append(current_scores.detach().cpu().tolist())
if record_intermediate:
intermediate_images.append(self.image_processor.postprocess(current_images, output_type="pil"))
final_images = self.image_processor.postprocess(current_images, output_type="pil")
return {
"images": final_images,
"latents": latents.detach(),
"score_trace": score_trace,
"score_trace_per_sample": per_sample_scores,
"final_scores": current_scores.detach().cpu().tolist(),
"intermediate_images": intermediate_images,
}
def _decode_packed_latents(self, latents: torch.Tensor, height: int, width: int) -> torch.Tensor:
unpacked = self._unpack_latents(latents, height, width, self.vae_scale_factor)
unpacked = (unpacked / self.vae.config.scaling_factor) + self.vae.config.shift_factor
decoded = self.vae.decode(unpacked, return_dict=False)[0]
return decoded
def _resolve_height_width(self, height: Optional[int], width: Optional[int]) -> tuple[int, int]:
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
return height, width
@staticmethod
def _expand_prompts(prompts: Union[str, list[str]], batch_size: int) -> list[str]:
if isinstance(prompts, str):
return [prompts] * batch_size
if len(prompts) != batch_size:
raise ValueError(f"Expected {batch_size} reward prompts, got {len(prompts)}.")
return prompts
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--model-id", type=str, default="black-forest-labs/FLUX.1-dev")
parser.add_argument(
"--prompt", type=str, default="Photo of a dog sitting near a sea waiting for its companion to come."
)
parser.add_argument("--reward-prompt", type=str, default=None)
parser.add_argument("--output", default="flux_adversarial.png")
parser.add_argument("--num-rounds", type=int, default=3)
parser.add_argument("--step-size", type=float, default=0.1)
parser.add_argument("--epsilon", type=float, help="FGSM epsilon. Falls back to step size when omitted.")
parser.add_argument("--attack-type", choices=["pgd", "fgsm"], default="pgd")
parser.add_argument("--num-inference-steps", type=int, default=30)
parser.add_argument("--guidance-scale", type=float, default=3.5)
parser.add_argument("--height", type=int, default=1024)
parser.add_argument("--width", type=int, default=1024)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
parser.add_argument("--scorer", choices=available_scorers(), default="clip")
parser.add_argument("--scorer-model-id", type=str, default=None)
parser.add_argument("--record-intermediates", action="store_true")
parser.add_argument("--intermediate-dir", type=str, default=None)
parser.add_argument("--metadata-output", type=str, default=None)
parser.add_argument("--output-root", type=str, required=True)
return parser.parse_args()
def save_intermediates(intermediate_dir: Path, rounds: list[list[Image.Image]]) -> None:
intermediate_dir.mkdir(parents=True, exist_ok=True)
for round_index, images in enumerate(rounds):
for sample_index, image in enumerate(images):
filename = intermediate_dir / f"round_{round_index:02d}_sample_{sample_index:02d}.png"
image.save(filename)
if rounds and len(rounds[0]) == 1:
grid = make_image_grid([imgs[0] for imgs in rounds], cols=len(rounds), rows=1)
grid.save(intermediate_dir / "grid.png")
def dump_metadata(metadata_path: Path, payload: dict[str, object]) -> None:
metadata_path.parent.mkdir(parents=True, exist_ok=True)
with metadata_path.open("w", encoding="utf-8") as file:
json.dump(payload, file)
def main() -> None:
args = parse_args()
generator = None
if args.seed is not None:
generator = torch.Generator(device=args.device)
generator.manual_seed(args.seed)
dtype = torch.bfloat16 if args.device.startswith("cuda") else torch.float32
pipe = AdversarialFluxPipeline.from_pretrained(args.model_id, torch_dtype=dtype)
pipe.to(args.device)
reward_model = build_scorer(name=args.scorer, model_id=args.scorer_model_id, device=args.device)
record_intermediate = args.record_intermediates or args.intermediate_dir is not None
result = pipe.adversarial_refinement(
prompt=args.prompt,
reward_prompt=args.reward_prompt,
reward_model=reward_model,
num_rounds=args.num_rounds,
step_size=args.step_size,
epsilon=args.epsilon,
attack_type=args.attack_type,
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
height=args.height,
width=args.width,
generator=generator,
record_intermediate=record_intermediate,
)
images = result["images"]
output_root = Path(args.output_root)
output_root.mkdir(parents=True, exist_ok=True)
output_path = output_root / args.output
images[0].save(output_path)
if args.intermediate_dir and result["intermediate_images"]:
intermediate_dir = output_root / args.intermediate_dir
save_intermediates(intermediate_dir, result["intermediate_images"])
if args.metadata_output:
metadata_payload = {
"prompt": args.prompt,
"reward_prompt": args.reward_prompt or "",
"scorer": {
"name": args.scorer,
"model_id": getattr(reward_model, "model_id", args.scorer_model_id or ""),
},
"attack": {
"type": args.attack_type,
"num_rounds": args.num_rounds,
"step_size": args.step_size,
"epsilon": args.epsilon
if args.epsilon is not None
else args.step_size
if args.attack_type == "fgsm"
else None,
"rounds_executed": len(result["score_trace"]) - 1,
},
"score_trace": result["score_trace"],
"score_trace_per_sample": result["score_trace_per_sample"],
"final_scores": result["final_scores"],
}
metadata_path = output_root / args.metadata_output
if metadata_path:
dump_metadata(metadata_path, metadata_payload)
print("Mean score trace:", result["score_trace"])
print("Final per-sample scores:", result["final_scores"])
if __name__ == "__main__":
main()
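
For reference, a minimal sketch of driving the refinement loop from Python instead of the CLI. The module name `adversarial_flux` is hypothetical (the compare view does not show the file paths); `reward_scorers` is the import name the script itself uses, and the example assumes the FLUX.1-dev weights and a CUDA device are available:

import torch

# Hypothetical module name for the script above; `reward_scorers` matches the import it uses.
from adversarial_flux import AdversarialFluxPipeline
from reward_scorers import build_scorer

pipe = AdversarialFluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
scorer = build_scorer(name="clip", device="cuda")

result = pipe.adversarial_refinement(
    prompt="Photo of a dog sitting near a sea waiting for its companion to come.",
    reward_model=scorer,
    num_rounds=3,               # PGD steps applied to the decoded images
    step_size=0.1,
    attack_type="pgd",
    num_inference_steps=30,
    guidance_scale=3.5,
    height=1024,
    width=1024,
    generator=torch.Generator("cuda").manual_seed(0),
)
result["images"][0].save("flux_adversarial.png")
print(result["score_trace"])    # mean reward before refinement and after each round

The equivalent script invocation would look roughly like `python adversarial_flux.py --output-root outputs --num-rounds 3 --scorer clip` (script filename again hypothetical).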


@@ -0,0 +1,196 @@
import warnings
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPModel, CLIPTokenizer, SiglipModel, SiglipProcessor
class BaseRewardScorer(nn.Module):
"""
Base interface for reward scorers.
Subclasses are expected to implement a differentiable `forward` method that
accepts a batch of images in the `[-1, 1]` range and a batch of prompt strings
with the same batch dimension.
"""
name: str = "base"
default_model_id: Optional[str] = None
supports_gradients: bool = True
def __init__(self, model_id: Optional[str] = None, device: Optional[torch.device] = None):
super().__init__()
self.model_id = model_id or self.default_model_id
if self.model_id is None:
raise ValueError(f"{self.__class__.__name__} requires `model_id` to be specified.")
self._requested_device = torch.device(device) if device is not None else None
@property
def device(self) -> torch.device:
parameters = list(self.parameters())
if parameters:
return parameters[0].device
return self._requested_device or torch.device("cpu")
def ensure_device(self) -> None:
if self._requested_device is not None:
self.to(self._requested_device)
def forward(self, images: torch.Tensor, prompts: Sequence[str]) -> torch.Tensor:
raise NotImplementedError
class ClipScorer(BaseRewardScorer):
name = "clip"
default_model_id = "openai/clip-vit-large-patch14"
def __init__(self, model_id: Optional[str] = None, device: Optional[torch.device] = None):
super().__init__(model_id=model_id, device=device)
self.model = CLIPModel.from_pretrained(self.model_id)
self.tokenizer = CLIPTokenizer.from_pretrained(self.model_id)
self.image_processor = CLIPImageProcessor.from_pretrained(self.model_id)
if self._requested_device is not None:
self.model = self.model.to(self._requested_device)
self.model = self.model.to(dtype=torch.float32)
self.model.eval()
for parameter in self.model.parameters():
parameter.requires_grad_(False)
self.eval()
def forward(self, images: torch.Tensor, prompts: Sequence[str]) -> torch.Tensor:
device = self.model.device
pixel_values = self._preprocess_images(images).to(device=device, dtype=torch.float32)
text_inputs = self.tokenizer(list(prompts), padding=True, truncation=True, return_tensors="pt").to(device)
image_embeds = self.model.get_image_features(pixel_values=pixel_values)
text_embeds = self.model.get_text_features(**text_inputs)
image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
return (image_embeds * text_embeds).sum(dim=-1)
def _preprocess_images(self, images: torch.Tensor) -> torch.Tensor:
pixel_values = (images + 1) / 2
pixel_values = torch.clamp(pixel_values, 0, 1)
crop_size = self.image_processor.crop_size
if isinstance(crop_size, dict):
target_height = crop_size["height"]
target_width = crop_size["width"]
else:
target_height = target_width = crop_size
pixel_values = F.interpolate(
pixel_values, size=(target_height, target_width), mode="bilinear", align_corners=False
)
mean = torch.tensor(
self.image_processor.image_mean, device=pixel_values.device, dtype=pixel_values.dtype
).view(1, -1, 1, 1)
std = torch.tensor(self.image_processor.image_std, device=pixel_values.device, dtype=pixel_values.dtype).view(
1, -1, 1, 1
)
return (pixel_values - mean) / std
class SiglipScorer(BaseRewardScorer):
name = "siglip"
default_model_id = "google/siglip-so400m-patch14-384"
def __init__(self, model_id: Optional[str] = None, device: Optional[torch.device] = None):
super().__init__(model_id=model_id, device=device)
self.processor = SiglipProcessor.from_pretrained(self.model_id)
self.image_processor = self.processor.image_processor
self.text_tokenizer = self.processor.tokenizer
self.model = SiglipModel.from_pretrained(self.model_id)
if self._requested_device is not None:
self.model = self.model.to(self._requested_device)
self.model = self.model.to(dtype=torch.float32)
self.model.eval()
for parameter in self.model.parameters():
parameter.requires_grad_(False)
self.eval()
def forward(self, images: torch.Tensor, prompts: Sequence[str]) -> torch.Tensor: # type: ignore[override]
device = self.model.device
pixel_values = self._preprocess_images(images).to(device=device, dtype=torch.float32)
text_inputs = self.text_tokenizer(list(prompts), padding=True, truncation=True, return_tensors="pt").to(device)
image_embeds = self.model.get_image_features(pixel_values=pixel_values)
text_embeds = self.model.get_text_features(**text_inputs)
image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
return (image_embeds * text_embeds).sum(dim=-1)
def _preprocess_images(self, images: torch.Tensor) -> torch.Tensor:
pixel_values = (images + 1) / 2
pixel_values = torch.clamp(pixel_values, 0, 1)
size = self.image_processor.size
if isinstance(size, dict):
target_height = size.get("shortest_edge") or size.get("height") or size.get("width")
target_width = size.get("width") or target_height
target_height = target_height or target_width
else:
target_height = target_width = size
pixel_values = F.interpolate(
pixel_values, size=(target_height, target_width), mode="bilinear", align_corners=False
)
mean = torch.tensor(
self.image_processor.image_mean, device=pixel_values.device, dtype=pixel_values.dtype
).view(1, -1, 1, 1)
std = torch.tensor(self.image_processor.image_std, device=pixel_values.device, dtype=pixel_values.dtype).view(
1, -1, 1, 1
)
return (pixel_values - mean) / std
class PlaceholderScorer(BaseRewardScorer):
"""
Helper scorer that surfaces a friendly error for scorers that require external dependencies.
"""
name = "placeholder"
supports_gradients = False
def __init__(self, *args: Any, required_package: str, scorer_name: str, **kwargs: Any):
self.required_package = required_package
self.scorer_name = scorer_name
raise ImportError(f"{scorer_name} requires the external package `{required_package}` which is not installed.")
SCORER_REGISTRY: Dict[str, Type[BaseRewardScorer]] = {ClipScorer.name: ClipScorer, SiglipScorer.name: SiglipScorer}
def available_scorers() -> Tuple[str, ...]:
return tuple(sorted(SCORER_REGISTRY.keys()))
def build_scorer(
name: str,
model_id: Optional[str] = None,
device: Optional[Union[str, torch.device]] = None,
**kwargs: Any,
) -> BaseRewardScorer:
if name not in SCORER_REGISTRY:
raise ValueError(f"Unknown scorer `{name}`. Available scorers: {', '.join(available_scorers())}.")
scorer_cls = SCORER_REGISTRY[name]
device_obj = torch.device(device) if device is not None else None
scorer = scorer_cls(model_id=model_id, device=device_obj, **kwargs)
if not scorer.supports_gradients:
warnings.warn(
f"Scorer `{name}` does not declare gradient support. Adversarial refinement may not work as expected.",
UserWarning,
)
return scorer
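
To make the interface described in the `BaseRewardScorer` docstring concrete, here is a minimal sketch of a custom differentiable scorer and its registration. The brightness reward is a toy example invented for illustration and is not part of this diff; only the `reward_scorers` names used below come from the file above:

from typing import Sequence

import torch

from reward_scorers import SCORER_REGISTRY, BaseRewardScorer, build_scorer

class BrightnessScorer(BaseRewardScorer):
    # Toy reward: brighter images score higher. Differentiable, so PGD/FGSM refinement can use it.
    name = "brightness"
    default_model_id = "none"  # no pretrained weights are needed for this toy scorer

    def forward(self, images: torch.Tensor, prompts: Sequence[str]) -> torch.Tensor:
        # Images arrive in [-1, 1]; map to [0, 1] and return the per-sample mean intensity.
        return ((images + 1) / 2).mean(dim=(1, 2, 3))

SCORER_REGISTRY[BrightnessScorer.name] = BrightnessScorer

scorer = build_scorer(name="brightness", device="cpu")
scores = scorer(torch.zeros(2, 3, 64, 64), ["prompt a", "prompt b"])  # tensor([0.5, 0.5])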