Compare commits

...

2 Commits

Author SHA1 Message Date
yiyixuxu
8d42a97a40 style 2026-01-22 03:27:46 +01:00
yiyixuxu
39a6a0c171 up 2026-01-22 03:27:24 +01:00

View File

@@ -1552,11 +1552,11 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
else:
logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")
self.blocks = blocks
self._blocks = blocks
self._components_manager = components_manager
self._collection = collection
self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components}
self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs}
self._component_specs = {spec.name: deepcopy(spec) for spec in self._blocks.expected_components}
self._config_specs = {spec.name: deepcopy(spec) for spec in self._blocks.expected_configs}
# update component_specs and config_specs based on modular_model_index.json
if modular_config_dict is not None:
@@ -1603,7 +1603,9 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
for name, config_spec in self._config_specs.items():
default_configs[name] = config_spec.default
self.register_to_config(**default_configs)
self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None)
self.register_to_config(
_blocks_class_name=self._blocks.__class__.__name__ if self._blocks is not None else None
)
@property
def default_call_parameters(self) -> Dict[str, Any]:
@@ -1612,7 +1614,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
- Dictionary mapping input names to their default values
"""
params = {}
for input_param in self.blocks.inputs:
for input_param in self._blocks.inputs:
params[input_param.name] = input_param.default
return params
@@ -1775,7 +1777,15 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
Returns:
- The docstring of the pipeline blocks
"""
return self.blocks.doc
return self._blocks.doc
@property
def blocks(self) -> ModularPipelineBlocks:
"""
Returns:
- A copy of the pipeline blocks
"""
return deepcopy(self._blocks)
def register_components(self, **kwargs):
"""
@@ -2509,7 +2519,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
)
def set_progress_bar_config(self, **kwargs):
for sub_block_name, sub_block in self.blocks.sub_blocks.items():
for sub_block_name, sub_block in self._blocks.sub_blocks.items():
if hasattr(sub_block, "set_progress_bar_config"):
sub_block.set_progress_bar_config(**kwargs)
@@ -2563,7 +2573,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
# Add inputs to state, using defaults if not provided in the kwargs or the state
# if same input already in the state, will override it if provided in the kwargs
for expected_input_param in self.blocks.inputs:
for expected_input_param in self._blocks.inputs:
name = expected_input_param.name
default = expected_input_param.default
kwargs_type = expected_input_param.kwargs_type
@@ -2582,9 +2592,9 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
# Run the pipeline
with torch.no_grad():
try:
_, state = self.blocks(self, state)
_, state = self._blocks(self, state)
except Exception:
error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n"
error_msg = f"Error in block: ({self._blocks.__class__.__name__}):\n"
logger.error(error_msg)
raise