Compare commits

...

3 Commits

Author SHA1 Message Date
Sayak Paul
1ec9ab9559 Merge branch 'main' into fix/lora-loading-offline 2023-12-16 08:18:23 +05:30
sayakpaul
d4772beea2 simplify condition 2023-12-15 15:58:47 +05:30
sayakpaul
88b93c92f5 add an error message when dealing with _best_guess_weight_name ofline 2023-12-15 13:39:54 +05:30

View File

@@ -18,6 +18,7 @@ from typing import Callable, Dict, List, Optional, Union
import safetensors
import torch
from huggingface_hub import model_info
from huggingface_hub.constants import HF_HUB_OFFLINE
from huggingface_hub.utils import validate_hf_hub_args
from packaging import version
from torch import nn
@@ -229,7 +230,9 @@ class LoraLoaderMixin:
# determine `weight_name`.
if weight_name is None:
weight_name = cls._best_guess_weight_name(
pretrained_model_name_or_path_or_dict, file_extension=".safetensors"
pretrained_model_name_or_path_or_dict,
file_extension=".safetensors",
local_files_only=local_files_only,
)
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
@@ -255,7 +258,7 @@ class LoraLoaderMixin:
if model_file is None:
if weight_name is None:
weight_name = cls._best_guess_weight_name(
pretrained_model_name_or_path_or_dict, file_extension=".bin"
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
)
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
@@ -294,7 +297,12 @@ class LoraLoaderMixin:
return state_dict, network_alphas
@classmethod
def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors"):
def _best_guess_weight_name(
cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
):
if local_files_only or HF_HUB_OFFLINE:
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
targeted_files = []
if os.path.isfile(pretrained_model_name_or_path_or_dict):