Compare commits

...

5 Commits

Author SHA1 Message Date
Dhruv Nair
54234652cb Merge branch 'main' into sd3-t5 2024-07-04 15:43:29 +05:30
Dhruv Nair
96880d498d update 2024-07-03 05:45:46 +00:00
Dhruv Nair
991174e704 update 2024-07-03 04:46:06 +00:00
Dhruv Nair
8502ea8ae1 update 2024-07-03 04:16:45 +00:00
Dhruv Nair
12d0258c96 update 2024-07-03 03:49:18 +00:00
8 changed files with 85 additions and 3 deletions

View File

@@ -555,7 +555,4 @@ class FromSingleFileMixin:
pipe = pipeline_class(**init_kwargs)
if torch_dtype is not None:
pipe.to(dtype=torch_dtype)
return pipe

View File

@@ -1808,4 +1808,17 @@ def create_diffusers_t5_model_from_checkpoint(
else:
model.load_state_dict(diffusers_format_checkpoint)
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (torch_dtype == torch.float16)
if use_keep_in_fp32_modules:
keep_in_fp32_modules = model._keep_in_fp32_modules
else:
keep_in_fp32_modules = []
if keep_in_fp32_modules is not None:
for name, param in model.named_parameters():
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
# param = param.to(torch.float32) does not work here as only in the local scope.
param.data = param.data.to(torch.float32)
return model

View File

@@ -201,6 +201,20 @@ class SDSingleFileTesterMixin:
self._compare_component_configs(pipe, single_file_pipe)
def test_single_file_setting_pipeline_dtype_to_fp16(
self,
single_file_pipe=None,
):
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
self.ckpt_path, torch_dtype=torch.float16
)
for component_name, component in single_file_pipe.components.items():
if not isinstance(component, torch.nn.Module):
continue
assert component.dtype == torch.float16
class SDXLSingleFileTesterMixin:
def _compare_component_configs(self, pipe, single_file_pipe):
@@ -378,3 +392,17 @@ class SDXLSingleFileTesterMixin:
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())
assert max_diff < expected_max_diff
def test_single_file_setting_pipeline_dtype_to_fp16(
self,
single_file_pipe=None,
):
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
self.ckpt_path, torch_dtype=torch.float16
)
for component_name, component in single_file_pipe.components.items():
if not isinstance(component, torch.nn.Module):
continue
assert component.dtype == torch.float16

View File

@@ -180,3 +180,12 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD
local_files_only=True,
)
super()._compare_component_configs(pipe, pipe_single_file)
def test_single_file_setting_pipeline_dtype_to_fp16(self):
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
)
single_file_pipe = self.pipeline_class.from_single_file(
self.ckpt_path, controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
)
super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe)

View File

@@ -181,3 +181,12 @@ class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestC
local_files_only=True,
)
super()._compare_component_configs(pipe, pipe_single_file)
def test_single_file_setting_pipeline_dtype_to_fp16(self):
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
)
single_file_pipe = self.pipeline_class.from_single_file(
self.ckpt_path, controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
)
super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe)

View File

@@ -169,3 +169,12 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD
local_files_only=True,
)
super()._compare_component_configs(pipe, pipe_single_file)
def test_single_file_setting_pipeline_dtype_to_fp16(self):
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
)
single_file_pipe = self.pipeline_class.from_single_file(
self.ckpt_path, controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
)
super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe)

View File

@@ -200,3 +200,11 @@ class StableDiffusionXLAdapterPipelineSingleFileSlowTests(unittest.TestCase, SDX
local_files_only=True,
)
self._compare_component_configs(pipe, pipe_single_file)
def test_single_file_setting_pipeline_dtype_to_fp16(self):
adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16)
single_file_pipe = self.pipeline_class.from_single_file(
self.ckpt_path, adapter=adapter, torch_dtype=torch.float16
)
super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe)

View File

@@ -195,3 +195,12 @@ class StableDiffusionXLControlNetPipelineSingleFileSlowTests(unittest.TestCase,
local_files_only=True,
)
super()._compare_component_configs(pipe, pipe_single_file)
def test_single_file_setting_pipeline_dtype_to_fp16(self):
controlnet = ControlNetModel.from_pretrained(
"diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
)
single_file_pipe = self.pipeline_class.from_single_file(
self.ckpt_path, controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
)
super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe)