This commit is contained in:
sayakpaul
2026-02-14 15:55:16 +05:30
parent 6c2e10adf6
commit afafb247cc

View File

@@ -112,10 +112,14 @@ def _load_transformers_model_from_dduf(
tensors = safetensors.torch.load(mmap)
# Update the state dictionary with tensors
state_dict.update(tensors)
return cls.from_pretrained(
model = cls.from_pretrained(
pretrained_model_name_or_path=None,
config=config,
generation_config=generation_config,
state_dict=state_dict,
**kwargs,
)
# from_pretrained returns the model in eval mode by default, so switch it
# to training mode explicitly for consistency with non-DDUF loading.
model.train()
return model