mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-08 05:24:20 +08:00
Compare commits
5 Commits
chroma-fin
...
sd3-t5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
54234652cb | ||
|
|
96880d498d | ||
|
|
991174e704 | ||
|
|
8502ea8ae1 | ||
|
|
12d0258c96 |
@@ -555,7 +555,4 @@ class FromSingleFileMixin:
|
||||
|
||||
pipe = pipeline_class(**init_kwargs)
|
||||
|
||||
if torch_dtype is not None:
|
||||
pipe.to(dtype=torch_dtype)
|
||||
|
||||
return pipe
|
||||
|
||||
@@ -1808,4 +1808,17 @@ def create_diffusers_t5_model_from_checkpoint(
|
||||
|
||||
else:
|
||||
model.load_state_dict(diffusers_format_checkpoint)
|
||||
|
||||
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (torch_dtype == torch.float16)
|
||||
if use_keep_in_fp32_modules:
|
||||
keep_in_fp32_modules = model._keep_in_fp32_modules
|
||||
else:
|
||||
keep_in_fp32_modules = []
|
||||
|
||||
if keep_in_fp32_modules is not None:
|
||||
for name, param in model.named_parameters():
|
||||
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
|
||||
# param = param.to(torch.float32) does not work here as only in the local scope.
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
return model
|
||||
|
||||
@@ -201,6 +201,20 @@ class SDSingleFileTesterMixin:
|
||||
|
||||
self._compare_component_configs(pipe, single_file_pipe)
|
||||
|
||||
def test_single_file_setting_pipeline_dtype_to_fp16(
|
||||
self,
|
||||
single_file_pipe=None,
|
||||
):
|
||||
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
|
||||
self.ckpt_path, torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
for component_name, component in single_file_pipe.components.items():
|
||||
if not isinstance(component, torch.nn.Module):
|
||||
continue
|
||||
|
||||
assert component.dtype == torch.float16
|
||||
|
||||
|
||||
class SDXLSingleFileTesterMixin:
|
||||
def _compare_component_configs(self, pipe, single_file_pipe):
|
||||
@@ -378,3 +392,17 @@ class SDXLSingleFileTesterMixin:
|
||||
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())
|
||||
|
||||
assert max_diff < expected_max_diff
|
||||
|
||||
def test_single_file_setting_pipeline_dtype_to_fp16(
|
||||
self,
|
||||
single_file_pipe=None,
|
||||
):
|
||||
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
|
||||
self.ckpt_path, torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
for component_name, component in single_file_pipe.components.items():
|
||||
if not isinstance(component, torch.nn.Module):
|
||||
continue
|
||||
|
||||
assert component.dtype == torch.float16
|
||||
|
||||
@@ -180,3 +180,12 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD
|
||||
local_files_only=True,
|
||||
)
|
||||
super()._compare_component_configs(pipe, pipe_single_file)
|
||||
|
||||
def test_single_file_setting_pipeline_dtype_to_fp16(self):
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
"lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
|
||||
)
|
||||
single_file_pipe = self.pipeline_class.from_single_file(
|
||||
self.ckpt_path, controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
|
||||
)
|
||||
super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe)
|
||||
|
||||
@@ -181,3 +181,12 @@ class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestC
|
||||
local_files_only=True,
|
||||
)
|
||||
super()._compare_component_configs(pipe, pipe_single_file)
|
||||
|
||||
def test_single_file_setting_pipeline_dtype_to_fp16(self):
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
"lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
|
||||
)
|
||||
single_file_pipe = self.pipeline_class.from_single_file(
|
||||
self.ckpt_path, controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
|
||||
)
|
||||
super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe)
|
||||
|
||||
@@ -169,3 +169,12 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD
|
||||
local_files_only=True,
|
||||
)
|
||||
super()._compare_component_configs(pipe, pipe_single_file)
|
||||
|
||||
def test_single_file_setting_pipeline_dtype_to_fp16(self):
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
"lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
|
||||
)
|
||||
single_file_pipe = self.pipeline_class.from_single_file(
|
||||
self.ckpt_path, controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
|
||||
)
|
||||
super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe)
|
||||
|
||||
@@ -200,3 +200,11 @@ class StableDiffusionXLAdapterPipelineSingleFileSlowTests(unittest.TestCase, SDX
|
||||
local_files_only=True,
|
||||
)
|
||||
self._compare_component_configs(pipe, pipe_single_file)
|
||||
|
||||
def test_single_file_setting_pipeline_dtype_to_fp16(self):
|
||||
adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16)
|
||||
|
||||
single_file_pipe = self.pipeline_class.from_single_file(
|
||||
self.ckpt_path, adapter=adapter, torch_dtype=torch.float16
|
||||
)
|
||||
super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe)
|
||||
|
||||
@@ -195,3 +195,12 @@ class StableDiffusionXLControlNetPipelineSingleFileSlowTests(unittest.TestCase,
|
||||
local_files_only=True,
|
||||
)
|
||||
super()._compare_component_configs(pipe, pipe_single_file)
|
||||
|
||||
def test_single_file_setting_pipeline_dtype_to_fp16(self):
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
"diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
|
||||
)
|
||||
single_file_pipe = self.pipeline_class.from_single_file(
|
||||
self.ckpt_path, controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
|
||||
)
|
||||
super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe)
|
||||
|
||||
Reference in New Issue
Block a user