mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-16 09:24:54 +08:00
[train_text_to_image_lora] Better image interpolation in training scripts follow up (#11427)
* Update train_text_to_image_lora.py * update_train_text_to_image_lora
This commit is contained in:
@@ -418,6 +418,15 @@ def parse_args():
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_interpolation_mode",
|
||||
type=str,
|
||||
default="lanczos",
|
||||
choices=[
|
||||
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
|
||||
],
|
||||
help="The image interpolation method to use for resizing images.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
@@ -649,10 +658,17 @@ def main():
|
||||
)
|
||||
return inputs.input_ids
|
||||
|
||||
# Preprocessing the datasets.
|
||||
# Get the specified interpolation method from the args
|
||||
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
|
||||
|
||||
# Raise an error if the interpolation method is invalid
|
||||
if interpolation is None:
|
||||
raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.")
|
||||
|
||||
# Data preprocessing transformations
|
||||
train_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.Resize(args.resolution, interpolation=interpolation), # Use dynamic interpolation method
|
||||
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
|
||||
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
|
||||
transforms.ToTensor(),
|
||||
|
||||
Reference in New Issue
Block a user