[docs] add float64 + runtime weight-dtype gotchas to models.md

Document two dtype pitfalls surfaced by Ernie-Image follow-up #13464:
unconditional torch.float64 in RoPE/precompute (breaks MPS/NPU) and
reading a child module's weight dtype at runtime (breaks gguf/quant).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
yiyi@huggingface.co
2026-04-14 20:26:26 +00:00
parent e9c092d886
commit c421712df1


@@ -74,3 +74,15 @@ Consult the implementations in `src/diffusers/models/transformers/` if you need
7. **Forgetting to update `_import_structure` and `_lazy_modules`.** The top-level `src/diffusers/__init__.py` has both -- missing either one causes partial import failures.
8. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16` in the model's forward pass. Use the dtype of the input tensors or `self.dtype` so the model works with any precision.
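   A minimal sketch of the safe pattern (the module and its names are hypothetical, not from any specific model): compute precision-sensitive intermediates in float32 if needed, but always cast back to the input tensor's dtype before returning.

   ```python
   import torch
   import torch.nn as nn

   class TimestepProj(nn.Module):
       """Hypothetical projection layer illustrating dtype handling."""

       def __init__(self, dim: int):
           super().__init__()
           self.linear = nn.Linear(dim, dim)

       def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
           # Bad: emb = emb.to(torch.float32)  -- pins the model to one precision.
           # Good: upcast only for the sensitive op, then follow the input's dtype
           # so the module works under fp32, fp16, and bf16 alike.
           emb = torch.sin(hidden_states.float()).to(hidden_states.dtype)
           return self.linear(emb)

   model = TimestepProj(8).to(torch.bfloat16)
   x = torch.randn(2, 8, dtype=torch.bfloat16)
   out = model(x)  # out.dtype matches the input dtype
   ```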
9. **`torch.float64` anywhere in the model.** MPS and several NPU backends don't support float64 -- ops either error out or silently fall back to lower precision. Reference repos commonly reach for float64 in RoPE frequency bases, timestep embeddings, sinusoidal position encodings, and similar "precision-sensitive" precompute code (`torch.arange(..., dtype=torch.float64)`, `.double()`, `torch.float64` literals). When porting a model, grep for `float64` / `double()` up front and resolve as follows:
- **Default: just use `torch.float32`.** For inference it is almost always sufficient -- the precision difference in RoPE angles, timestep embeddings, etc. is immaterial to image/video quality. Flip it and move on.
- **Only if float32 visibly degrades output, fall back to the device-gated pattern** we use in the repo:
```python
is_mps = hidden_states.device.type == "mps"
is_npu = hidden_states.device.type == "npu"
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
```
See `transformer_flux.py`, `transformer_flux2.py`, `transformer_wan.py`, `unet_2d_condition.py` for reference usages. Never leave an unconditional `torch.float64` in the model.
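As a sketch of how the gated dtype feeds the precompute (the `rope_freqs` helper and its signature are illustrative, not the repo's actual function), note that the float64 intermediate is cast back to float32 immediately, so no float64 tensor ever leaves the helper:

```python
import torch

def rope_freqs(pos: torch.Tensor, dim: int, theta: float = 10000.0) -> torch.Tensor:
    """Illustrative RoPE frequency precompute using the device-gated dtype."""
    is_mps = pos.device.type == "mps"
    is_npu = pos.device.type == "npu"
    # float64 only where the backend supports it; float32 on MPS/NPU.
    freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
    inv_freq = 1.0 / (
        theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim)
    )
    freqs = torch.outer(pos.to(freqs_dtype), inv_freq)
    # Cast back right away so float64 never leaks into the model graph.
    return torch.cat([freqs.cos(), freqs.sin()], dim=-1).float()

pos = torch.arange(16)
freqs = rope_freqs(pos, 64)  # float32 regardless of backend
```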
10. **Reading a weight's dtype at runtime to cast activations.** Patterns like `x = x.to(self.linear.weight.dtype)` break under gguf / quantized loading, where the stored weight dtype isn't the compute dtype. Cast activations using the input tensor's dtype or `self.dtype`, not by peeking at a child module's parameter.
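    A minimal sketch of the fix (hypothetical module, not from a specific model): align auxiliary tensors with the activation stream's dtype rather than a parameter's storage dtype, which under gguf/quantized loading may be a packed integer type.

    ```python
    import torch
    import torch.nn as nn

    class ProjBlock(nn.Module):
        """Hypothetical block showing how to cast activations safely."""

        def __init__(self, dim: int):
            super().__init__()
            self.linear = nn.Linear(dim, dim)

        def forward(self, hidden_states: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
            # Bad: context = context.to(self.linear.weight.dtype)
            #   -- with gguf/quantized checkpoints the stored weight dtype
            #   (e.g. a packed uint8) is not the compute dtype.
            # Good: follow the activation stream's dtype instead.
            context = context.to(hidden_states.dtype)
            return self.linear(hidden_states + context)

    block = ProjBlock(8).to(torch.bfloat16)
    x = torch.randn(2, 8, dtype=torch.bfloat16)
    ctx = torch.randn(2, 8)  # float32 side input gets aligned inside forward
    out = block(x, ctx)
    ```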