Compare commits

...

18 Commits

Author SHA1 Message Date
DN6
7fd1a8205b update 2025-08-14 14:03:01 +05:30
Sayak Paul
09e063c145 Merge branch 'main' into local-model-info 2025-08-13 21:19:54 +05:30
sayakpaul
2a9734f014 empty 2025-08-13 20:46:04 +05:30
sayakpaul
1b939e570c up 2025-08-13 14:56:52 +05:30
sayakpaul
1c528a4166 up 2025-08-13 14:55:18 +05:30
sayakpaul
04cd2dc451 reviewer feedback. 2025-08-13 14:50:50 +05:30
sayakpaul
b7af5111c4 reviewer feedback. 2025-08-13 14:31:05 +05:30
Sayak Paul
01784c39cb Merge branch 'main' into local-model-info 2025-08-13 14:16:43 +05:30
Sayak Paul
832de66a8d Merge branch 'main' into local-model-info 2025-08-13 08:02:21 +05:30
sayakpaul
fb2397f3fe up 2025-08-12 20:26:54 +05:30
Sayak Paul
71843a0c8b Merge branch 'main' into local-model-info 2025-08-12 20:20:33 +05:30
Sayak Paul
d1174740bb Merge branch 'main' into local-model-info 2025-08-07 10:08:33 +05:30
Sayak Paul
85279dfeee Merge branch 'main' into local-model-info 2025-08-01 08:13:57 +05:30
Sayak Paul
2d993b71d5 Merge branch 'main' into local-model-info 2025-07-29 13:58:33 +05:30
sayakpaul
f38a64443f Revert "tighten compilation tests for quantization"
This reverts commit 8d431dc967.
2025-07-28 20:19:38 +05:30
sayakpaul
d5c1772dc3 up 2025-07-28 20:17:24 +05:30
sayakpaul
69920eff3e feat: model_info but local. 2025-07-28 15:16:53 +05:30
sayakpaul
8d431dc967 tighten compilation tests for quantization 2025-07-28 13:27:20 +05:30
2 changed files with 66 additions and 11 deletions

View File

@@ -402,15 +402,17 @@ def _get_checkpoint_shard_files(
         allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns]

     ignore_patterns = ["*.json", "*.md"]
-    # `model_info` call must guarded with the above condition.
-    model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
-    for shard_file in original_shard_filenames:
-        shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
-        if not shard_file_present:
-            raise EnvironmentError(
-                f"{shards_path} does not appear to have a file named {shard_file} which is "
-                "required according to the checkpoint index."
-            )
+
+    # If the repo doesn't have the required shards, error out early even before downloading anything.
+    if not local_files_only:
+        model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
+        for shard_file in original_shard_filenames:
+            shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
+            if not shard_file_present:
+                raise EnvironmentError(
+                    f"{shards_path} does not appear to have a file named {shard_file} which is "
+                    "required according to the checkpoint index."
+                )

     try:
         # Load from URL
@@ -437,6 +439,11 @@ def _get_checkpoint_shard_files(
             ) from e

     cached_filenames = [os.path.join(cached_folder, f) for f in original_shard_filenames]
+    for cached_file in cached_filenames:
+        if not os.path.isfile(cached_file):
+            raise EnvironmentError(
+                f"{cached_folder} does not have a file named {cached_file} which is required according to the checkpoint index."
+            )

     return cached_filenames, sharded_metadata
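
Taken together, the two hunks above make sharded checkpoints loadable without network access: the `model_info` call, which always hits the Hub, now only runs when `local_files_only` is False, and a purely local `os.path.isfile` check verifies that every shard named in the index is actually present after download. A minimal usage sketch of the resulting behavior (the repo id is borrowed from the test added below; any sharded checkpoint works):

    from diffusers.models import FluxTransformer2DModel

    repo_id = "hf-internal-testing/tiny-flux-sharded"

    # First load: network allowed, shards are downloaded and cached.
    model = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer")

    # Later loads resolve every shard from the local cache; the `model_info`
    # Hub call is skipped, so this also works fully offline.
    model = FluxTransformer2DModel.from_pretrained(
        repo_id, subfolder="transformer", local_files_only=True
    )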

View File

@@ -36,12 +36,12 @@ import safetensors.torch
 import torch
 import torch.nn as nn
 from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
-from huggingface_hub import ModelCard, delete_repo, snapshot_download
+from huggingface_hub import ModelCard, delete_repo, snapshot_download, try_to_load_from_cache
 from huggingface_hub.utils import is_jinja_available
 from parameterized import parameterized
 from requests.exceptions import HTTPError

-from diffusers.models import SD3Transformer2DModel, UNet2DConditionModel
+from diffusers.models import FluxTransformer2DModel, SD3Transformer2DModel, UNet2DConditionModel
 from diffusers.models.attention_processor import (
     AttnProcessor,
     AttnProcessor2_0,
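
The newly imported `try_to_load_from_cache` resolves a file inside the local Hub cache without making any network request; the test below uses it to locate a shard on disk so it can delete it. A quick sketch of its contract (repo id and filename taken from that test):

    from huggingface_hub import try_to_load_from_cache

    path = try_to_load_from_cache(
        "hf-internal-testing/tiny-flux-sharded",
        filename="transformer/diffusion_pytorch_model-00001-of-00002.safetensors",
    )
    # Returns the path (str) if the file is cached, a _CACHED_NO_EXIST sentinel
    # if the cache records that the file does not exist on the repo, and None
    # if nothing is known; no HTTP request is made in any case.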
@@ -291,6 +291,54 @@ class ModelUtilsTest(unittest.TestCase):
             if p1.data.ne(p2.data).sum() > 0:
                 assert False, "Parameters not the same!"

+    def test_local_files_only_with_sharded_checkpoint(self):
+        repo_id = "hf-internal-testing/tiny-flux-sharded"
+        error_response = mock.Mock(
+            status_code=500,
+            headers={},
+            raise_for_status=mock.Mock(side_effect=HTTPError),
+            json=mock.Mock(return_value={}),
+        )
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=tmpdir)
+
+            with mock.patch("requests.Session.get", return_value=error_response):
+                # Should fail with local_files_only=False: `model_info` makes a network
+                # call, and the network is mocked to error out.
+                with self.assertRaises(OSError):
+                    FluxTransformer2DModel.from_pretrained(
+                        repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=False
+                    )
+
+                # Should succeed with local_files_only=True: the `model_info` call is
+                # skipped and everything resolves from the cache.
+                local_model = FluxTransformer2DModel.from_pretrained(
+                    repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
+                )
+                assert all(torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())), (
+                    "Model parameters don't match!"
+                )
+
+            # Remove a shard file from the cache.
+            cached_shard_file = try_to_load_from_cache(
+                repo_id, filename="transformer/diffusion_pytorch_model-00001-of-00002.safetensors", cache_dir=tmpdir
+            )
+            os.remove(cached_shard_file)
+
+            # Attempting to load from the cache should now raise an error.
+            with self.assertRaises(OSError) as context:
+                FluxTransformer2DModel.from_pretrained(
+                    repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
+                )
+
+            # Verify that the error mentions the missing shard.
+            error_msg = str(context.exception)
+            assert cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg, (
+                f"Expected error about missing shard, got: {error_msg}"
+            )
+
     @unittest.skip("Flaky behaviour on CI. Re-enable after migrating to new runners")
     @unittest.skipIf(torch_device == "mps", reason="Test not supported for MPS.")
     def test_one_request_upon_cached(self):
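
A condensed sketch of the failure path the end of the test exercises, assuming the checkpoint was downloaded earlier (for instance via the sketch after the first file): deleting a shard from the cache makes an offline load fail fast with the checkpoint-index error added in the first file, instead of producing a partially loaded model.

    import os

    from diffusers.models import FluxTransformer2DModel
    from huggingface_hub import try_to_load_from_cache

    repo_id = "hf-internal-testing/tiny-flux-sharded"
    shard = try_to_load_from_cache(
        repo_id, filename="transformer/diffusion_pytorch_model-00001-of-00002.safetensors"
    )
    os.remove(shard)  # simulate an incomplete cache

    try:
        FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", local_files_only=True)
    except OSError as err:
        print(err)  # names the shard required by the checkpoint index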