Compare commits

...

7 Commits

Author SHA1 Message Date
Sayak Paul
ea90a74ed4 Merge branch 'main' into transformers-v5-pr 2026-01-15 20:01:03 +05:30
sayakpaul
96f08043a3 fix group offloading. 2026-01-15 20:00:45 +05:30
sayakpaul
d0f279ce76 up 2026-01-15 16:59:41 +05:30
sayakpaul
c5e023fbe6 up 2026-01-15 13:03:11 +05:30
Sayak Paul
f8e50fab75 Merge branch 'main' into transformers-v5-pr 2026-01-15 12:47:21 +05:30
sayakpaul
c152b1831c more 2026-01-14 14:54:39 +05:30
sayakpaul
039324ae16 switch to transformers main again./ 2026-01-14 14:52:15 +05:30
7 changed files with 28 additions and 13 deletions

View File

@@ -115,8 +115,8 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality]"
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
# uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
- name: Environment
@@ -247,8 +247,8 @@ jobs:
uv pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
uv pip install -U tokenizers
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
# uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |

View File

@@ -14,6 +14,7 @@ on:
- "tests/pipelines/test_pipelines_common.py"
- "tests/models/test_modeling_common.py"
- "examples/**/*.py"
- ".github/**.yml"
workflow_dispatch:
concurrency:
@@ -131,8 +132,8 @@ jobs:
run: |
uv pip install -e ".[quality]"
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
# uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -202,8 +203,8 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
# uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -264,8 +265,8 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
# uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip install -e ".[quality,training]"
- name: Environment

View File

@@ -17,6 +17,9 @@ import logging
import os
import sys
import tempfile
import unittest
from diffusers.utils import is_transformers_version
sys.path.append("..")
@@ -30,6 +33,7 @@ stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
@unittest.skipIf(is_transformers_version(">=", "4.57.5"), "Size mismatch")
class CustomDiffusion(ExamplesTestsAccelerate):
def test_custom_diffusion(self):
with tempfile.TemporaryDirectory() as tmpdir:

View File

@@ -44,6 +44,7 @@ _GO_LC_SUPPORTED_PYTORCH_LAYERS = (
torch.nn.ConvTranspose2d,
torch.nn.ConvTranspose3d,
torch.nn.Linear,
torch.nn.Embedding,
# TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
# because of double invocation of the same norm layer in CogVideoXLayerNorm
)

View File

@@ -20,6 +20,8 @@ class MultilingualCLIP(PreTrainedModel):
self.LinearTransformation = torch.nn.Linear(
in_features=config.transformerDimensions, out_features=config.numDims
)
if hasattr(self, "post_init"):
self.post_init()
def forward(self, input_ids, attention_mask):
embs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)[0]

View File

@@ -782,6 +782,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
self.prefix_encoder = PrefixEncoder(config)
self.dropout = torch.nn.Dropout(0.1)
if hasattr(self, "post_init"):
self.post_init()
def get_input_embeddings(self):
return self.embedding.word_embeddings
@@ -811,7 +814,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
use_cache = use_cache if use_cache is not None else getattr(self.config, "use_cache", None)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
batch_size, seq_length = input_ids.shape

View File

@@ -20,7 +20,9 @@ class TestAutoModel(unittest.TestCase):
side_effect=[EnvironmentError("File not found"), {"model_type": "clip_text_model"}],
)
def test_load_from_config_transformers_with_subfolder(self, mock_load_config):
model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder")
model = AutoModel.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder", use_safetensors=False
)
assert isinstance(model, CLIPTextModel)
def test_load_from_config_without_subfolder(self):
@@ -28,5 +30,7 @@ class TestAutoModel(unittest.TestCase):
assert isinstance(model, LongformerModel)
def test_load_from_model_index(self):
model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder")
model = AutoModel.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder", use_safetensors=False
)
assert isinstance(model, CLIPTextModel)