Compare commits

...

3 Commits

Author SHA1 Message Date
Marc Sun
93c38d2094 style 2025-05-20 18:43:52 +02:00
Marc Sun
cad495446d quick fix 2025-05-20 18:39:46 +02:00
Marc Sun
a99663d37a load tensors on cuda 2025-05-20 18:24:38 +02:00

View File

@@ -1185,8 +1185,18 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
state_dict = None
if not is_sharded:
map_location = "cpu"
if (
device_map is not None
and hf_quantizer is not None
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
):
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
# Time to load the checkpoint
state_dict = load_state_dict(resolved_model_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries)
state_dict = load_state_dict(
resolved_model_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries, map_location=map_location
)
# We only fix it for non sharded checkpoints as we don't need it yet for sharded one.
model._fix_state_dict_keys_on_load(state_dict)
@@ -1438,8 +1448,16 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
if len(resolved_model_file) > 1:
resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
map_location = "cpu"
if (
device_map is not None
and hf_quantizer is not None
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
):
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
for shard_file in resolved_model_file:
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, map_location=map_location)
def _find_mismatched_keys(
state_dict,