mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-08 21:44:27 +08:00
Compare commits
18 Commits
pr-tests-f
...
local-mode
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7fd1a8205b | ||
|
|
09e063c145 | ||
|
|
2a9734f014 | ||
|
|
1b939e570c | ||
|
|
1c528a4166 | ||
|
|
04cd2dc451 | ||
|
|
b7af5111c4 | ||
|
|
01784c39cb | ||
|
|
832de66a8d | ||
|
|
fb2397f3fe | ||
|
|
71843a0c8b | ||
|
|
d1174740bb | ||
|
|
85279dfeee | ||
|
|
2d993b71d5 | ||
|
|
f38a64443f | ||
|
|
d5c1772dc3 | ||
|
|
69920eff3e | ||
|
|
8d431dc967 |
@@ -402,15 +402,17 @@ def _get_checkpoint_shard_files(
|
||||
allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns]
|
||||
|
||||
ignore_patterns = ["*.json", "*.md"]
|
||||
# `model_info` call must guarded with the above condition.
|
||||
model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
|
||||
for shard_file in original_shard_filenames:
|
||||
shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
|
||||
if not shard_file_present:
|
||||
raise EnvironmentError(
|
||||
f"{shards_path} does not appear to have a file named {shard_file} which is "
|
||||
"required according to the checkpoint index."
|
||||
)
|
||||
|
||||
# If the repo doesn't have the required shards, error out early even before downloading anything.
|
||||
if not local_files_only:
|
||||
model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
|
||||
for shard_file in original_shard_filenames:
|
||||
shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
|
||||
if not shard_file_present:
|
||||
raise EnvironmentError(
|
||||
f"{shards_path} does not appear to have a file named {shard_file} which is "
|
||||
"required according to the checkpoint index."
|
||||
)
|
||||
|
||||
try:
|
||||
# Load from URL
|
||||
@@ -437,6 +439,11 @@ def _get_checkpoint_shard_files(
|
||||
) from e
|
||||
|
||||
cached_filenames = [os.path.join(cached_folder, f) for f in original_shard_filenames]
|
||||
for cached_file in cached_filenames:
|
||||
if not os.path.isfile(cached_file):
|
||||
raise EnvironmentError(
|
||||
f"{cached_folder} does not have a file named {cached_file} which is required according to the checkpoint index."
|
||||
)
|
||||
|
||||
return cached_filenames, sharded_metadata
|
||||
|
||||
|
||||
@@ -36,12 +36,12 @@ import safetensors.torch
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
|
||||
from huggingface_hub import ModelCard, delete_repo, snapshot_download
|
||||
from huggingface_hub import ModelCard, delete_repo, snapshot_download, try_to_load_from_cache
|
||||
from huggingface_hub.utils import is_jinja_available
|
||||
from parameterized import parameterized
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from diffusers.models import SD3Transformer2DModel, UNet2DConditionModel
|
||||
from diffusers.models import FluxTransformer2DModel, SD3Transformer2DModel, UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor,
|
||||
AttnProcessor2_0,
|
||||
@@ -291,6 +291,54 @@ class ModelUtilsTest(unittest.TestCase):
|
||||
if p1.data.ne(p2.data).sum() > 0:
|
||||
assert False, "Parameters not the same!"
|
||||
|
||||
def test_local_files_only_with_sharded_checkpoint(self):
|
||||
repo_id = "hf-internal-testing/tiny-flux-sharded"
|
||||
error_response = mock.Mock(
|
||||
status_code=500,
|
||||
headers={},
|
||||
raise_for_status=mock.Mock(side_effect=HTTPError),
|
||||
json=mock.Mock(return_value={}),
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=tmpdir)
|
||||
|
||||
with mock.patch("requests.Session.get", return_value=error_response):
|
||||
# Should fail with local_files_only=False (network required)
|
||||
# We would make a network call with model_info
|
||||
with self.assertRaises(OSError):
|
||||
FluxTransformer2DModel.from_pretrained(
|
||||
repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=False
|
||||
)
|
||||
|
||||
# Should succeed with local_files_only=True (uses cache)
|
||||
# model_info call skipped
|
||||
local_model = FluxTransformer2DModel.from_pretrained(
|
||||
repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
|
||||
)
|
||||
|
||||
assert all(torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())), (
|
||||
"Model parameters don't match!"
|
||||
)
|
||||
|
||||
# Remove a shard file
|
||||
cached_shard_file = try_to_load_from_cache(
|
||||
repo_id, filename="transformer/diffusion_pytorch_model-00001-of-00002.safetensors", cache_dir=tmpdir
|
||||
)
|
||||
os.remove(cached_shard_file)
|
||||
|
||||
# Attempting to load from cache should raise an error
|
||||
with self.assertRaises(OSError) as context:
|
||||
FluxTransformer2DModel.from_pretrained(
|
||||
repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
|
||||
)
|
||||
|
||||
# Verify error mentions the missing shard
|
||||
error_msg = str(context.exception)
|
||||
assert cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg, (
|
||||
f"Expected error about missing shard, got: {error_msg}"
|
||||
)
|
||||
|
||||
@unittest.skip("Flaky behaviour on CI. Re-enable after migrating to new runners")
|
||||
@unittest.skipIf(torch_device == "mps", reason="Test not supported for MPS.")
|
||||
def test_one_request_upon_cached(self):
|
||||
|
||||
Reference in New Issue
Block a user