mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-15 08:54:20 +08:00
Compare commits
1 Commits
support-gr
...
dynamic-te
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7e13c18b29 |
@@ -1377,3 +1377,88 @@ class Expectations(DevicePropertiesUserDict):
|
|||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"{self.data}"
|
return f"{self.data}"
|
||||||
|
|
||||||
|
|
||||||
|
def dynamic_slice_test(func):
|
||||||
|
"""
|
||||||
|
Decorator that injects an expected_slice parameter into a test function.
|
||||||
|
|
||||||
|
On the first run, it will capture the actual slice output and cache it.
|
||||||
|
On subsequent runs, it provides the cached slice as the expected slice.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
@dynamic_slice_test
|
||||||
|
def test_stable_diffusion_ddim(self, expected_slice=None):
|
||||||
|
# Run the pipeline
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
sd_pipe = StableDiffusionPipeline(**components)
|
||||||
|
inputs = self.get_dummy_inputs("cpu")
|
||||||
|
image = sd_pipe(**inputs).images
|
||||||
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
|
# If expected_slice is provided (from cache), assert against it
|
||||||
|
if expected_slice is not None:
|
||||||
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
|
# Always return the current slice for caching
|
||||||
|
return image_slice
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
# Check if the function has the expected_slice parameter
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
if "expected_slice" not in sig.parameters:
|
||||||
|
raise ValueError("The decorated function must have an 'expected_slice' parameter")
|
||||||
|
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
# Get the test name from pytest
|
||||||
|
# pytest sets this environment variable to the current test
|
||||||
|
test_name = os.environ.get("PYTEST_CURRENT_TEST", "")
|
||||||
|
if test_name:
|
||||||
|
# Format is: test_file.py::TestClass::test_method (call)
|
||||||
|
test_name = test_name.split(" ")[0]
|
||||||
|
else:
|
||||||
|
# Fallback if not running in pytest
|
||||||
|
test_name = f"{func.__module__}.{func.__qualname__}"
|
||||||
|
|
||||||
|
# Create a unique filename based on hardware details
|
||||||
|
device_props = get_device_properties()
|
||||||
|
device_str = f"{device_props[0]}{device_props[1] if device_props[1] is not None else ''}"
|
||||||
|
|
||||||
|
# Setup cache directory
|
||||||
|
cache_dir = os.environ.get("DIFFUSERS_TEST_CACHE_DIR", ".test_cache")
|
||||||
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
|
cache_path = os.path.join(cache_dir, f"{test_name}_{device_str}.npy")
|
||||||
|
|
||||||
|
# Check for cached expected slice
|
||||||
|
cached_slice = None
|
||||||
|
if os.path.exists(cache_path):
|
||||||
|
try:
|
||||||
|
cached_slice = np.load(cache_path)
|
||||||
|
print(f"Using cached slice from {cache_path}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading cached slice: {e}")
|
||||||
|
|
||||||
|
# Run the test function with the expected slice injected
|
||||||
|
kwargs["expected_slice"] = cached_slice
|
||||||
|
actual_slice = func(*args, **kwargs)
|
||||||
|
|
||||||
|
# If the function returned a slice and there's no cached slice yet, cache it
|
||||||
|
if actual_slice is not None and cached_slice is None:
|
||||||
|
# Convert torch tensor to numpy if needed
|
||||||
|
if hasattr(actual_slice, "detach") and hasattr(actual_slice, "cpu") and hasattr(actual_slice, "numpy"):
|
||||||
|
actual_slice_np = actual_slice.detach().cpu().numpy()
|
||||||
|
else:
|
||||||
|
actual_slice_np = actual_slice
|
||||||
|
|
||||||
|
# Save the slice
|
||||||
|
try:
|
||||||
|
np.save(cache_path, actual_slice_np)
|
||||||
|
print(f"Saved slice to cache: {cache_path}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error saving slice to cache: {e}")
|
||||||
|
|
||||||
|
return actual_slice
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|||||||
Reference in New Issue
Block a user