Mirror of https://github.com/huggingface/diffusers.git
Synced 2025-12-15 17:04:52 +08:00

Compare commits: modular-do...adv-flux (4 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | 6eca44e1a7 |  |
|  | 53476bfca9 |  |
|  | 44126bd77e |  |
|  | 8997e88d85 |  |
examples/community/adversarial-tts/flux_adversarial_latents.py (new file, 244 lines)
@@ -0,0 +1,244 @@
import argparse
import json
from pathlib import Path
from typing import Optional, Union

import torch
from PIL import Image
from reward_scorers import BaseRewardScorer, available_scorers, build_scorer

from diffusers import FluxPipeline
from diffusers.utils import make_image_grid


class AdversarialFluxPipeline(FluxPipeline):
    def adversarial_refinement(
        self,
        prompt: Union[str, list[str]],
        reward_model: BaseRewardScorer,
        reward_prompt: Optional[Union[str, list[str]]] = None,
        num_rounds: int = 1,
        step_size: float = 0.1,
        epsilon: Optional[float] = None,
        attack_type: str = "pgd",
        record_intermediate: bool = False,
        **generate_kwargs,
    ):
        if num_rounds < 0:
            raise ValueError("`num_rounds` must be non-negative.")
        if attack_type not in {"pgd", "fgsm"}:
            raise ValueError("`attack_type` must be either 'pgd' or 'fgsm'.")

        # Run the regular FLUX sampling loop but keep the packed latents so we
        # can decode them ourselves and refine the decoded images.
        generate_kwargs = dict(generate_kwargs)
        height, width = self._resolve_height_width(generate_kwargs.get("height"), generate_kwargs.get("width"))
        generate_kwargs["height"] = height
        generate_kwargs["width"] = width
        generate_kwargs["output_type"] = "latent"
        generate_kwargs.setdefault("return_dict", True)

        flux_output = super().__call__(prompt=prompt, **generate_kwargs)
        latents = flux_output.images
        device = latents.device

        reward_model = reward_model.to(device)
        reward_model.eval()
        # Gradients are only required when at least one refinement round will run.
        if getattr(reward_model, "supports_gradients", True) is False and num_rounds > 0:
            raise ValueError(
                f"Scorer `{reward_model.__class__.__name__}` does not support gradients required for adversarial refinement."
            )

        reward_prompts = self._expand_prompts(reward_prompt if reward_prompt is not None else prompt, latents.shape[0])

        with torch.no_grad():
            current_images = self._decode_packed_latents(latents, height, width).to(dtype=torch.float32)
            current_scores = reward_model(current_images, reward_prompts)

        intermediate_images = []
        if record_intermediate:
            intermediate_images.append(self.image_processor.postprocess(current_images, output_type="pil"))

        score_trace = [current_scores.mean().item()]
        per_sample_scores = [current_scores.detach().cpu().tolist()]

        # FGSM is a single-step attack; PGD iterates for `num_rounds` steps.
        if num_rounds == 0:
            max_rounds = 0
        else:
            max_rounds = 1 if attack_type == "fgsm" else num_rounds

        for round_index in range(max_rounds):
            current_images.requires_grad_(True)
            scores = reward_model(current_images, reward_prompts)
            total_score = scores.mean()

            grad = torch.autograd.grad(total_score, current_images, retain_graph=False, create_graph=False)[0]

            # Gradient *ascent* on the reward: FGSM takes one signed step of
            # size `epsilon`; PGD repeatedly follows the raw gradient.
            if attack_type == "fgsm":
                step = epsilon if epsilon is not None else step_size
                update = step * grad.sign()
            else:
                update = step_size * grad

            with torch.no_grad():
                current_images = current_images + update
                current_images = current_images.clamp_(-1.0, 1.0)

            current_images = current_images.detach()

            with torch.no_grad():
                current_scores = reward_model(current_images, reward_prompts)

            score_trace.append(current_scores.mean().item())
            per_sample_scores.append(current_scores.detach().cpu().tolist())

            if record_intermediate:
                intermediate_images.append(self.image_processor.postprocess(current_images, output_type="pil"))

        final_images = self.image_processor.postprocess(current_images, output_type="pil")

        return {
            "images": final_images,
            "latents": latents.detach(),
            "score_trace": score_trace,
            "score_trace_per_sample": per_sample_scores,
            "final_scores": current_scores.detach().cpu().tolist(),
            "intermediate_images": intermediate_images,
        }

    def _decode_packed_latents(self, latents: torch.Tensor, height: int, width: int) -> torch.Tensor:
        # Unpack FLUX's patch-packed latents, undo the VAE scaling/shift, and
        # decode to pixel space in the [-1, 1] range.
        unpacked = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        unpacked = (unpacked / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        decoded = self.vae.decode(unpacked, return_dict=False)[0]
        return decoded

    def _resolve_height_width(self, height: Optional[int], width: Optional[int]) -> tuple[int, int]:
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor
        return height, width

    @staticmethod
    def _expand_prompts(prompts: Union[str, list[str]], batch_size: int) -> list[str]:
        if isinstance(prompts, str):
            return [prompts] * batch_size
        if len(prompts) != batch_size:
            raise ValueError(f"Expected {batch_size} reward prompts, got {len(prompts)}.")
        return prompts


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-id", type=str, default="black-forest-labs/FLUX.1-dev")
    parser.add_argument(
        "--prompt", type=str, default="Photo of a dog sitting near a sea waiting for its companion to come."
    )
    parser.add_argument("--reward-prompt", type=str, default=None)
    parser.add_argument("--output", default="flux_adversarial.png")
    parser.add_argument("--num-rounds", type=int, default=3)
    parser.add_argument("--step-size", type=float, default=0.1)
    parser.add_argument("--epsilon", type=float, help="FGSM epsilon. Falls back to step size when omitted.")
    parser.add_argument("--attack-type", choices=["pgd", "fgsm"], default="pgd")
    parser.add_argument("--num-inference-steps", type=int, default=30)
    parser.add_argument("--guidance-scale", type=float, default=3.5)
    parser.add_argument("--height", type=int, default=1024)
    parser.add_argument("--width", type=int, default=1024)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--scorer", choices=available_scorers(), default="clip")
    parser.add_argument("--scorer-model-id", type=str, default=None)
    parser.add_argument("--record-intermediates", action="store_true")
    parser.add_argument("--intermediate-dir", type=str, default=None)
    parser.add_argument("--metadata-output", type=str, default=None)
    parser.add_argument("--output-root", type=str, required=True)
    return parser.parse_args()


def save_intermediates(intermediate_dir: Path, rounds: list[list[Image.Image]]) -> None:
    intermediate_dir.mkdir(parents=True, exist_ok=True)
    for round_index, images in enumerate(rounds):
        for sample_index, image in enumerate(images):
            filename = intermediate_dir / f"round_{round_index:02d}_sample_{sample_index:02d}.png"
            image.save(filename)

    # For single-sample runs, also save a side-by-side grid of all rounds.
    if rounds and len(rounds[0]) == 1:
        grid = make_image_grid([imgs[0] for imgs in rounds], cols=len(rounds), rows=1)
        grid.save(intermediate_dir / "grid.png")


def dump_metadata(metadata_path: Path, payload: dict[str, object]) -> None:
    metadata_path.parent.mkdir(parents=True, exist_ok=True)
    with metadata_path.open("w", encoding="utf-8") as file:
        json.dump(payload, file)


def main() -> None:
    args = parse_args()

    generator = None
    if args.seed is not None:
        generator = torch.Generator(device=args.device)
        generator.manual_seed(args.seed)

    dtype = torch.bfloat16 if args.device.startswith("cuda") else torch.float32
    pipe = AdversarialFluxPipeline.from_pretrained(args.model_id, torch_dtype=dtype)
    pipe.to(args.device)

    reward_model = build_scorer(name=args.scorer, model_id=args.scorer_model_id, device=args.device)

    record_intermediate = args.record_intermediates or args.intermediate_dir is not None

    result = pipe.adversarial_refinement(
        prompt=args.prompt,
        reward_prompt=args.reward_prompt,
        reward_model=reward_model,
        num_rounds=args.num_rounds,
        step_size=args.step_size,
        epsilon=args.epsilon,
        attack_type=args.attack_type,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        height=args.height,
        width=args.width,
        generator=generator,
        record_intermediate=record_intermediate,
    )

    # `--output-root` is a directory; the final image (`--output`), the
    # intermediates, and the metadata file all live beneath it.
    images = result["images"]
    output_root = Path(args.output_root)
    output_root.mkdir(parents=True, exist_ok=True)
    images[0].save(output_root / args.output)

    if args.intermediate_dir and result["intermediate_images"]:
        save_intermediates(output_root / args.intermediate_dir, result["intermediate_images"])

    if args.metadata_output:
        metadata_payload = {
            "prompt": args.prompt,
            "reward_prompt": args.reward_prompt or "",
            "scorer": {
                "name": args.scorer,
                "model_id": getattr(reward_model, "model_id", args.scorer_model_id or ""),
            },
            "attack": {
                "type": args.attack_type,
                "num_rounds": args.num_rounds,
                "step_size": args.step_size,
                "epsilon": args.epsilon
                if args.epsilon is not None
                else args.step_size
                if args.attack_type == "fgsm"
                else None,
                "rounds_executed": len(result["score_trace"]) - 1,
            },
            "score_trace": result["score_trace"],
            "score_trace_per_sample": result["score_trace_per_sample"],
            "final_scores": result["final_scores"],
        }
        dump_metadata(output_root / args.metadata_output, metadata_payload)

    print("Mean score trace:", result["score_trace"])
    print("Final per-sample scores:", result["final_scores"])


if __name__ == "__main__":
    main()
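For reference, a minimal usage sketch of the refinement API added above (not part of the diff). The model id and scorer name follow the script's defaults; the CUDA/bfloat16 setup and the assumption that both files are importable from the working directory are mine:

```python
# Hedged sketch: exercises AdversarialFluxPipeline directly, skipping the CLI.
import torch

from flux_adversarial_latents import AdversarialFluxPipeline
from reward_scorers import build_scorer

# Assumes a CUDA device with enough memory for FLUX.1-dev in bfloat16.
pipe = AdversarialFluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
scorer = build_scorer(name="clip", device="cuda")

result = pipe.adversarial_refinement(
    prompt="Photo of a dog sitting near a sea",
    reward_model=scorer,
    num_rounds=3,        # PGD iterations on the decoded image
    step_size=0.1,       # gradient-ascent step size
    attack_type="pgd",
    num_inference_steps=30,
    guidance_scale=3.5,
)
result["images"][0].save("refined.png")
print(result["score_trace"])  # mean reward before and after each round
```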
examples/community/adversarial-tts/reward_scorers.py (new file, 196 lines)
@@ -0,0 +1,196 @@
import warnings
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPModel, CLIPTokenizer, SiglipModel, SiglipProcessor


class BaseRewardScorer(nn.Module):
    """
    Base interface for reward scorers.

    Subclasses are expected to implement a differentiable `forward` method that
    accepts a batch of images in the `[-1, 1]` range and a batch of prompt strings
    with the same batch dimension.
    """

    name: str = "base"
    default_model_id: Optional[str] = None
    supports_gradients: bool = True

    def __init__(self, model_id: Optional[str] = None, device: Optional[torch.device] = None):
        super().__init__()
        self.model_id = model_id or self.default_model_id
        if self.model_id is None:
            raise ValueError(f"{self.__class__.__name__} requires `model_id` to be specified.")
        self._requested_device = torch.device(device) if device is not None else None

    @property
    def device(self) -> torch.device:
        parameters = list(self.parameters())
        if parameters:
            return parameters[0].device
        return self._requested_device or torch.device("cpu")

    def ensure_device(self) -> None:
        if self._requested_device is not None:
            self.to(self._requested_device)

    def forward(self, images: torch.Tensor, prompts: Sequence[str]) -> torch.Tensor:
        raise NotImplementedError


class ClipScorer(BaseRewardScorer):
    name = "clip"
    default_model_id = "openai/clip-vit-large-patch14"

    def __init__(self, model_id: Optional[str] = None, device: Optional[torch.device] = None):
        super().__init__(model_id=model_id, device=device)

        self.model = CLIPModel.from_pretrained(self.model_id)
        self.tokenizer = CLIPTokenizer.from_pretrained(self.model_id)
        self.image_processor = CLIPImageProcessor.from_pretrained(self.model_id)
        if self._requested_device is not None:
            self.model = self.model.to(self._requested_device)
        self.model = self.model.to(dtype=torch.float32)
        self.model.eval()
        # Freeze CLIP: gradients should flow to the *images*, not the scorer.
        for parameter in self.model.parameters():
            parameter.requires_grad_(False)
        self.eval()

    def forward(self, images: torch.Tensor, prompts: Sequence[str]) -> torch.Tensor:
        device = self.model.device
        pixel_values = self._preprocess_images(images).to(device=device, dtype=torch.float32)
        text_inputs = self.tokenizer(list(prompts), padding=True, truncation=True, return_tensors="pt").to(device)

        image_embeds = self.model.get_image_features(pixel_values=pixel_values)
        text_embeds = self.model.get_text_features(**text_inputs)

        # Cosine similarity between L2-normalized image and text embeddings.
        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
        return (image_embeds * text_embeds).sum(dim=-1)

    def _preprocess_images(self, images: torch.Tensor) -> torch.Tensor:
        # Tensor-only reimplementation of CLIP preprocessing so the whole path
        # stays differentiable (the HF image processor operates on PIL/numpy).
        pixel_values = (images + 1) / 2
        pixel_values = torch.clamp(pixel_values, 0, 1)

        crop_size = self.image_processor.crop_size
        if isinstance(crop_size, dict):
            target_height = crop_size["height"]
            target_width = crop_size["width"]
        else:
            target_height = target_width = crop_size

        pixel_values = F.interpolate(
            pixel_values, size=(target_height, target_width), mode="bilinear", align_corners=False
        )

        mean = torch.tensor(
            self.image_processor.image_mean, device=pixel_values.device, dtype=pixel_values.dtype
        ).view(1, -1, 1, 1)
        std = torch.tensor(self.image_processor.image_std, device=pixel_values.device, dtype=pixel_values.dtype).view(
            1, -1, 1, 1
        )
        return (pixel_values - mean) / std


class SiglipScorer(BaseRewardScorer):
    name = "siglip"
    default_model_id = "google/siglip-so400m-patch14-384"

    def __init__(self, model_id: Optional[str] = None, device: Optional[torch.device] = None):
        super().__init__(model_id=model_id, device=device)

        self.processor = SiglipProcessor.from_pretrained(self.model_id)
        self.image_processor = self.processor.image_processor
        self.text_tokenizer = self.processor.tokenizer
        self.model = SiglipModel.from_pretrained(self.model_id)
        if self._requested_device is not None:
            self.model = self.model.to(self._requested_device)
        self.model = self.model.to(dtype=torch.float32)
        self.model.eval()
        for parameter in self.model.parameters():
            parameter.requires_grad_(False)
        self.eval()

    def forward(self, images: torch.Tensor, prompts: Sequence[str]) -> torch.Tensor:  # type: ignore[override]
        device = self.model.device
        pixel_values = self._preprocess_images(images).to(device=device, dtype=torch.float32)
        text_inputs = self.text_tokenizer(list(prompts), padding=True, truncation=True, return_tensors="pt").to(device)

        image_embeds = self.model.get_image_features(pixel_values=pixel_values)
        text_embeds = self.model.get_text_features(**text_inputs)

        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

        return (image_embeds * text_embeds).sum(dim=-1)

    def _preprocess_images(self, images: torch.Tensor) -> torch.Tensor:
        pixel_values = (images + 1) / 2
        pixel_values = torch.clamp(pixel_values, 0, 1)

        size = self.image_processor.size
        if isinstance(size, dict):
            target_height = size.get("shortest_edge") or size.get("height") or size.get("width")
            target_width = size.get("width") or target_height
            target_height = target_height or target_width
        else:
            target_height = target_width = size

        pixel_values = F.interpolate(
            pixel_values, size=(target_height, target_width), mode="bilinear", align_corners=False
        )

        mean = torch.tensor(
            self.image_processor.image_mean, device=pixel_values.device, dtype=pixel_values.dtype
        ).view(1, -1, 1, 1)
        std = torch.tensor(self.image_processor.image_std, device=pixel_values.device, dtype=pixel_values.dtype).view(
            1, -1, 1, 1
        )
        return (pixel_values - mean) / std


class PlaceholderScorer(BaseRewardScorer):
    """
    Helper scorer that surfaces a friendly error for scorers that require external dependencies.
    """

    name = "placeholder"
    supports_gradients = False

    def __init__(self, *args: Any, required_package: str, scorer_name: str, **kwargs: Any):
        # Fail fast at construction time; the instance is never usable.
        raise ImportError(f"{scorer_name} requires the external package `{required_package}` which is not installed.")


SCORER_REGISTRY: Dict[str, Type[BaseRewardScorer]] = {ClipScorer.name: ClipScorer, SiglipScorer.name: SiglipScorer}


def available_scorers() -> Tuple[str, ...]:
    return tuple(sorted(SCORER_REGISTRY.keys()))


def build_scorer(
    name: str,
    model_id: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
    **kwargs: Any,
) -> BaseRewardScorer:
    if name not in SCORER_REGISTRY:
        raise ValueError(f"Unknown scorer `{name}`. Available scorers: {', '.join(available_scorers())}.")
    scorer_cls = SCORER_REGISTRY[name]
    device_obj = torch.device(device) if device is not None else None

    scorer = scorer_cls(model_id=model_id, device=device_obj, **kwargs)

    if not scorer.supports_gradients:
        warnings.warn(
            f"Scorer `{name}` does not declare gradient support. Adversarial refinement may not work as expected.",
            UserWarning,
        )

    return scorer
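To illustrate the registry pattern, here is a hedged sketch of plugging a custom scorer into `SCORER_REGISTRY`. `BrightnessScorer` and its toy reward are hypothetical; only the `BaseRewardScorer` contract (a differentiable `forward` over `[-1, 1]` images plus prompts) comes from the file above:

```python
# Hypothetical example, not part of the diff: a minimal custom scorer.
from typing import Sequence

import torch

from reward_scorers import SCORER_REGISTRY, BaseRewardScorer, build_scorer


class BrightnessScorer(BaseRewardScorer):
    name = "brightness"
    default_model_id = "none"  # no pretrained weights needed for this toy scorer

    def forward(self, images: torch.Tensor, prompts: Sequence[str]) -> torch.Tensor:
        # Reward brighter images: mean pixel value per sample, differentiable.
        return images.mean(dim=(1, 2, 3))


# Registering makes it resolvable through build_scorer, like "clip" or "siglip".
SCORER_REGISTRY[BrightnessScorer.name] = BrightnessScorer
scorer = build_scorer(name="brightness")
```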