Compare commits

...

1 Commits

Author SHA1 Message Date
DN6
7e13c18b29 update 2025-04-08 20:38:58 +05:30

View File

@@ -1377,3 +1377,88 @@ class Expectations(DevicePropertiesUserDict):
def __repr__(self):
return f"{self.data}"
def dynamic_slice_test(func):
"""
Decorator that injects an expected_slice parameter into a test function.
On the first run, it will capture the actual slice output and cache it.
On subsequent runs, it provides the cached slice as the expected slice.
Example:
```python
@dynamic_slice_test
def test_stable_diffusion_ddim(self, expected_slice=None):
# Run the pipeline
components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
inputs = self.get_dummy_inputs("cpu")
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
# If expected_slice is provided (from cache), assert against it
if expected_slice is not None:
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
# Always return the current slice for caching
return image_slice
```
"""
# Check if the function has the expected_slice parameter
sig = inspect.signature(func)
if "expected_slice" not in sig.parameters:
raise ValueError("The decorated function must have an 'expected_slice' parameter")
@functools.wraps(func)
def wrapper(*args, **kwargs):
# Get the test name from pytest
# pytest sets this environment variable to the current test
test_name = os.environ.get("PYTEST_CURRENT_TEST", "")
if test_name:
# Format is: test_file.py::TestClass::test_method (call)
test_name = test_name.split(" ")[0]
else:
# Fallback if not running in pytest
test_name = f"{func.__module__}.{func.__qualname__}"
# Create a unique filename based on hardware details
device_props = get_device_properties()
device_str = f"{device_props[0]}{device_props[1] if device_props[1] is not None else ''}"
# Setup cache directory
cache_dir = os.environ.get("DIFFUSERS_TEST_CACHE_DIR", ".test_cache")
os.makedirs(cache_dir, exist_ok=True)
cache_path = os.path.join(cache_dir, f"{test_name}_{device_str}.npy")
# Check for cached expected slice
cached_slice = None
if os.path.exists(cache_path):
try:
cached_slice = np.load(cache_path)
print(f"Using cached slice from {cache_path}")
except Exception as e:
print(f"Error loading cached slice: {e}")
# Run the test function with the expected slice injected
kwargs["expected_slice"] = cached_slice
actual_slice = func(*args, **kwargs)
# If the function returned a slice and there's no cached slice yet, cache it
if actual_slice is not None and cached_slice is None:
# Convert torch tensor to numpy if needed
if hasattr(actual_slice, "detach") and hasattr(actual_slice, "cpu") and hasattr(actual_slice, "numpy"):
actual_slice_np = actual_slice.detach().cpu().numpy()
else:
actual_slice_np = actual_slice
# Save the slice
try:
np.save(cache_path, actual_slice_np)
print(f"Saved slice to cache: {cache_path}")
except Exception as e:
print(f"Error saving slice to cache: {e}")
return actual_slice
return wrapper