Compare commits

..

2 Commits

Author SHA1 Message Date
sayakpaul
4e01e02395 add mslk for additional dependencies. 2026-03-25 09:41:04 +05:30
sayakpaul
5e5b575fb3 fix torchao tests 2026-03-25 09:38:49 +05:30
3 changed files with 51 additions and 35 deletions

View File

@@ -341,7 +341,7 @@ jobs:
additional_deps: ["peft", "kernels"]
- backend: "torchao"
test_location: "torchao"
additional_deps: []
additional_deps: [mslk-cuda]
- backend: "optimum_quanto"
test_location: "quanto"
additional_deps: []

View File

@@ -1,45 +1,73 @@
# Adapted from https://blog.deepjyoti30.dev/pypi-release-github-action
name: PyPI release
on:
workflow_dispatch:
push:
tags:
- "v*"
- "*"
jobs:
build-and-test:
find-and-checkout-latest-branch:
runs-on: ubuntu-22.04
outputs:
latest_branch: ${{ steps.set_latest_branch.outputs.latest_branch }}
steps:
- name: Checkout repo
- name: Checkout Repo
uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: "3.10"
python-version: '3.10'
- name: Fetch and checkout latest release branch
- name: Fetch latest branch
id: fetch_latest_branch
run: |
pip install -U requests packaging
LATEST_BRANCH=$(python utils/fetch_latest_release_branch.py)
echo "Latest branch: $LATEST_BRANCH"
git fetch origin "$LATEST_BRANCH"
git checkout "$LATEST_BRANCH"
echo "latest_branch=$LATEST_BRANCH" >> $GITHUB_ENV
- name: Install build dependencies
- name: Set latest branch output
id: set_latest_branch
run: echo "::set-output name=latest_branch::${{ env.latest_branch }}"
release:
needs: find-and-checkout-latest-branch
runs-on: ubuntu-22.04
steps:
- name: Checkout Repo
uses: actions/checkout@v6
with:
ref: ${{ needs.find-and-checkout-latest-branch.outputs.latest_branch }}
- name: Setup Python
uses: actions/setup-python@v6
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -U build
pip install -U setuptools wheel twine
pip install -U torch --index-url https://download.pytorch.org/whl/cpu
- name: Build the dist files
run: python -m build
run: python setup.py bdist_wheel && python setup.py sdist
- name: Install from built wheel
run: pip install dist/*.whl
- name: Publish to the test PyPI
env:
TWINE_USERNAME: ${{ secrets.TEST_PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.TEST_PYPI_PASSWORD }}
run: twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
- name: Test installing diffusers and importing
run: |
pip install diffusers && pip uninstall diffusers -y
pip install -i https://test.pypi.org/simple/ diffusers
pip install -U transformers
python utils/print_env.py
python -c "from diffusers import __version__; print(__version__)"
@@ -47,26 +75,8 @@ jobs:
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('hf-internal-testing/tiny-stable-diffusion-pipe', safety_checker=None); pipe('ah suh du')"
python -c "from diffusers import *"
- name: Upload build artifacts
uses: actions/upload-artifact@v4
with:
name: python-dist
path: dist/
publish-to-pypi:
needs: build-and-test
if: startsWith(github.ref, 'refs/tags/')
runs-on: ubuntu-22.04
environment: pypi-release
permissions:
id-token: write
steps:
- name: Download build artifacts
uses: actions/download-artifact@v4
with:
name: python-dist
path: dist/
- name: Publish to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
env:
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
run: twine upload dist/* -r pypi

View File

@@ -177,6 +177,11 @@ class QuantizationTesterMixin:
model_quantized.to(torch_device)
inputs = self.get_dummy_inputs()
model_dtype = next(model_quantized.parameters()).dtype
inputs = {
k: v.to(dtype=model_dtype) if torch.is_tensor(v) and torch.is_floating_point(v) else v
for k, v in inputs.items()
}
output = model_quantized(**inputs, return_dict=False)[0]
assert output is not None, "Model output is None"
@@ -930,6 +935,7 @@ class TorchAoTesterMixin(TorchAoConfigMixin, QuantizationTesterMixin):
"""Test that device_map='auto' works correctly with quantization."""
self._test_quantization_device_map(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"])
@pytest.mark.xfail(reason="dequantize is not implemented in torchao")
def test_torchao_dequantize(self):
"""Test that dequantize() works correctly."""
self._test_dequantize(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"])