mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
multiple enhancements to model manager REACT API
1. add a /sync route for synchronizing the in-memory model lists to models.yaml, the models directory, and the autoimport directories. 2. add optional destination_directories to convert_model and merge_model operations. 3. add /ckpt_confs route for retrieving known legacy checkpoint configuration files. 4. add /search route for finding all models in a directory located in the server filesystem
This commit is contained in:
parent
ad076b1174
commit
8600aad12b
@ -132,13 +132,11 @@ async def import_model(
|
|||||||
"/{base_model}/{model_type}/{model_name}",
|
"/{base_model}/{model_type}/{model_name}",
|
||||||
operation_id="del_model",
|
operation_id="del_model",
|
||||||
responses={
|
responses={
|
||||||
204: {
|
204: { "description": "Model deleted successfully" },
|
||||||
"description": "Model deleted successfully"
|
404: { "description": "Model not found" }
|
||||||
},
|
|
||||||
404: {
|
|
||||||
"description": "Model not found"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
|
status_code = 204,
|
||||||
|
response_model = None,
|
||||||
)
|
)
|
||||||
async def delete_model(
|
async def delete_model(
|
||||||
base_model: BaseModelType = Path(description="Base model"),
|
base_model: BaseModelType = Path(description="Base model"),
|
||||||
@ -174,14 +172,17 @@ async def convert_model(
|
|||||||
base_model: BaseModelType = Path(description="Base model"),
|
base_model: BaseModelType = Path(description="Base model"),
|
||||||
model_type: ModelType = Path(description="The type of model"),
|
model_type: ModelType = Path(description="The type of model"),
|
||||||
model_name: str = Path(description="model name"),
|
model_name: str = Path(description="model name"),
|
||||||
|
convert_dest_directory: Optional[str] = Query(default=None, description="Save the converted model to the designated directory"),
|
||||||
) -> ConvertModelResponse:
|
) -> ConvertModelResponse:
|
||||||
"""Convert a checkpoint model into a diffusers model"""
|
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
try:
|
try:
|
||||||
logger.info(f"Converting model: {model_name}")
|
logger.info(f"Converting model: {model_name}")
|
||||||
|
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
|
||||||
ApiDependencies.invoker.services.model_manager.convert_model(model_name,
|
ApiDependencies.invoker.services.model_manager.convert_model(model_name,
|
||||||
base_model = base_model,
|
base_model = base_model,
|
||||||
model_type = model_type
|
model_type = model_type,
|
||||||
|
convert_dest_directory = dest,
|
||||||
)
|
)
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
|
||||||
base_model = base_model,
|
base_model = base_model,
|
||||||
@ -210,6 +211,36 @@ async def search_for_models(
|
|||||||
raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory")
|
raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory")
|
||||||
return ApiDependencies.invoker.services.model_manager.search_for_models([search_path])
|
return ApiDependencies.invoker.services.model_manager.search_for_models([search_path])
|
||||||
|
|
||||||
|
@models_router.get(
|
||||||
|
"/ckpt_confs",
|
||||||
|
operation_id="list_ckpt_configs",
|
||||||
|
responses={
|
||||||
|
200: { "description" : "paths retrieved successfully" },
|
||||||
|
},
|
||||||
|
status_code = 200,
|
||||||
|
response_model = List[pathlib.Path]
|
||||||
|
)
|
||||||
|
async def list_ckpt_configs(
|
||||||
|
)->List[pathlib.Path]:
|
||||||
|
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
|
||||||
|
return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
|
||||||
|
|
||||||
|
|
||||||
|
@models_router.get(
|
||||||
|
"/sync",
|
||||||
|
operation_id="sync_to_config",
|
||||||
|
responses={
|
||||||
|
201: { "description": "synchronization successful" },
|
||||||
|
},
|
||||||
|
status_code = 201,
|
||||||
|
response_model = None
|
||||||
|
)
|
||||||
|
async def sync_to_config(
|
||||||
|
)->None:
|
||||||
|
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize
|
||||||
|
in-memory data structures with disk data structures."""
|
||||||
|
return ApiDependencies.invoker.services.model_manager.sync_to_config()
|
||||||
|
|
||||||
@models_router.put(
|
@models_router.put(
|
||||||
"/merge/{base_model}",
|
"/merge/{base_model}",
|
||||||
operation_id="merge_models",
|
operation_id="merge_models",
|
||||||
@ -228,17 +259,21 @@ async def merge_models(
|
|||||||
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||||
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
|
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
|
||||||
force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
|
force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
|
||||||
|
merge_dest_directory: Optional[str] = Body(description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None)
|
||||||
) -> MergeModelResponse:
|
) -> MergeModelResponse:
|
||||||
"""Convert a checkpoint model into a diffusers model"""
|
"""Convert a checkpoint model into a diffusers model"""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
try:
|
try:
|
||||||
logger.info(f"Merging models: {model_names}")
|
logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||||
|
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||||
result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
|
result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
|
||||||
base_model,
|
base_model,
|
||||||
merged_model_name or "+".join(model_names),
|
merged_model_name=merged_model_name or "+".join(model_names),
|
||||||
alpha,
|
alpha=alpha,
|
||||||
interp,
|
interp=interp,
|
||||||
force)
|
force=force,
|
||||||
|
merge_dest_directory = dest
|
||||||
|
)
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
|
||||||
base_model = base_model,
|
base_model = base_model,
|
||||||
model_type = ModelType.Main,
|
model_type = ModelType.Main,
|
||||||
|
@ -167,6 +167,15 @@ class ModelManagerServiceBase(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def list_checkpoint_configs(
|
||||||
|
self
|
||||||
|
)->List[Path]:
|
||||||
|
"""
|
||||||
|
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def convert_model(
|
def convert_model(
|
||||||
self,
|
self,
|
||||||
@ -220,6 +229,7 @@ class ModelManagerServiceBase(ABC):
|
|||||||
alpha: Optional[float] = 0.5,
|
alpha: Optional[float] = 0.5,
|
||||||
interp: Optional[MergeInterpolationMethod] = None,
|
interp: Optional[MergeInterpolationMethod] = None,
|
||||||
force: Optional[bool] = False,
|
force: Optional[bool] = False,
|
||||||
|
merge_dest_directory: Optional[Path] = None
|
||||||
) -> AddModelResult:
|
) -> AddModelResult:
|
||||||
"""
|
"""
|
||||||
Merge two to three diffusrs pipeline models and save as a new model.
|
Merge two to three diffusrs pipeline models and save as a new model.
|
||||||
@ -228,6 +238,7 @@ class ModelManagerServiceBase(ABC):
|
|||||||
:param merged_model_name: Name of destination merged model
|
:param merged_model_name: Name of destination merged model
|
||||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||||
:param interp: Interpolation method. None (default)
|
:param interp: Interpolation method. None (default)
|
||||||
|
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -238,6 +249,15 @@ class ModelManagerServiceBase(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def sync_to_config(self):
|
||||||
|
"""
|
||||||
|
Re-read models.yaml, rescan the models directory, and reimport models
|
||||||
|
in the autoimport directories. Call after making changes outside the
|
||||||
|
model manager API.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def commit(self, conf_file: Optional[Path] = None) -> None:
|
def commit(self, conf_file: Optional[Path] = None) -> None:
|
||||||
"""
|
"""
|
||||||
@ -438,16 +458,18 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
"""
|
"""
|
||||||
Delete the named model from configuration. If delete_files is true,
|
Delete the named model from configuration. If delete_files is true,
|
||||||
then the underlying weight file or diffusers directory will be deleted
|
then the underlying weight file or diffusers directory will be deleted
|
||||||
as well. Call commit() to write to disk.
|
as well.
|
||||||
"""
|
"""
|
||||||
self.logger.debug(f'delete model {model_name}')
|
self.logger.debug(f'delete model {model_name}')
|
||||||
self.mgr.del_model(model_name, base_model, model_type)
|
self.mgr.del_model(model_name, base_model, model_type)
|
||||||
|
self.mgr.commit()
|
||||||
|
|
||||||
def convert_model(
|
def convert_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: Union[ModelType.Main,ModelType.Vae],
|
model_type: Union[ModelType.Main,ModelType.Vae],
|
||||||
|
convert_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"),
|
||||||
) -> AddModelResult:
|
) -> AddModelResult:
|
||||||
"""
|
"""
|
||||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||||
@ -456,13 +478,14 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
:param model_name: Name of the model to convert
|
:param model_name: Name of the model to convert
|
||||||
:param base_model: Base model type
|
:param base_model: Base model type
|
||||||
:param model_type: Type of model ['vae' or 'main']
|
:param model_type: Type of model ['vae' or 'main']
|
||||||
|
:param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)
|
||||||
|
|
||||||
This will raise a ValueError unless the model is not a checkpoint. It will
|
This will raise a ValueError unless the model is not a checkpoint. It will
|
||||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||||
directory already in place.
|
directory already in place.
|
||||||
"""
|
"""
|
||||||
self.logger.debug(f'convert model {model_name}')
|
self.logger.debug(f'convert model {model_name}')
|
||||||
return self.mgr.convert_model(model_name, base_model, model_type)
|
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
|
||||||
|
|
||||||
def commit(self, conf_file: Optional[Path]=None):
|
def commit(self, conf_file: Optional[Path]=None):
|
||||||
"""
|
"""
|
||||||
@ -543,6 +566,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
alpha: Optional[float] = 0.5,
|
alpha: Optional[float] = 0.5,
|
||||||
interp: Optional[MergeInterpolationMethod] = None,
|
interp: Optional[MergeInterpolationMethod] = None,
|
||||||
force: Optional[bool] = False,
|
force: Optional[bool] = False,
|
||||||
|
merge_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"),
|
||||||
) -> AddModelResult:
|
) -> AddModelResult:
|
||||||
"""
|
"""
|
||||||
Merge two to three diffusrs pipeline models and save as a new model.
|
Merge two to three diffusrs pipeline models and save as a new model.
|
||||||
@ -551,6 +575,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
:param merged_model_name: Name of destination merged model
|
:param merged_model_name: Name of destination merged model
|
||||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||||
:param interp: Interpolation method. None (default)
|
:param interp: Interpolation method. None (default)
|
||||||
|
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||||
"""
|
"""
|
||||||
merger = ModelMerger(self.mgr)
|
merger = ModelMerger(self.mgr)
|
||||||
try:
|
try:
|
||||||
@ -561,6 +586,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
alpha = alpha,
|
alpha = alpha,
|
||||||
interp = interp,
|
interp = interp,
|
||||||
force = force,
|
force = force,
|
||||||
|
merge_dest_directory=merge_dest_directory,
|
||||||
)
|
)
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
raise ValueError(e)
|
raise ValueError(e)
|
||||||
@ -572,3 +598,20 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
"""
|
"""
|
||||||
search = FindModels(directory,self.logger)
|
search = FindModels(directory,self.logger)
|
||||||
return search.list_models()
|
return search.list_models()
|
||||||
|
|
||||||
|
def sync_to_config(self):
|
||||||
|
"""
|
||||||
|
Re-read models.yaml, rescan the models directory, and reimport models
|
||||||
|
in the autoimport directories. Call after making changes outside the
|
||||||
|
model manager API.
|
||||||
|
"""
|
||||||
|
return self.mgr.sync_to_config()
|
||||||
|
|
||||||
|
def list_checkpoint_configs(self)->List[Path]:
|
||||||
|
"""
|
||||||
|
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
||||||
|
"""
|
||||||
|
config = self.mgr.app_config
|
||||||
|
conf_path = config.legacy_conf_path
|
||||||
|
root_path = config.root_path
|
||||||
|
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob('**/*.yaml')]
|
||||||
|
@ -324,15 +324,6 @@ class ModelManager(object):
|
|||||||
# TODO: metadata not found
|
# TODO: metadata not found
|
||||||
# TODO: version check
|
# TODO: version check
|
||||||
|
|
||||||
self.models = dict()
|
|
||||||
for model_key, model_config in config.items():
|
|
||||||
model_name, base_model, model_type = self.parse_key(model_key)
|
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
|
||||||
# alias for config file
|
|
||||||
model_config["model_format"] = model_config.pop("format")
|
|
||||||
self.models[model_key] = model_class.create_config(**model_config)
|
|
||||||
|
|
||||||
# check config version number and update on disk/RAM if necessary
|
|
||||||
self.app_config = InvokeAIAppConfig.get_config()
|
self.app_config = InvokeAIAppConfig.get_config()
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.cache = ModelCache(
|
self.cache = ModelCache(
|
||||||
@ -343,11 +334,41 @@ class ModelManager(object):
|
|||||||
sequential_offload = sequential_offload,
|
sequential_offload = sequential_offload,
|
||||||
logger = logger,
|
logger = logger,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._read_models(config)
|
||||||
|
|
||||||
|
def _read_models(self, config: Optional[DictConfig] = None):
|
||||||
|
if not config:
|
||||||
|
if self.config_path:
|
||||||
|
config = OmegaConf.load(self.config_path)
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.models = dict()
|
||||||
|
for model_key, model_config in config.items():
|
||||||
|
if model_key.startswith('_'):
|
||||||
|
continue
|
||||||
|
model_name, base_model, model_type = self.parse_key(model_key)
|
||||||
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
|
# alias for config file
|
||||||
|
model_config["model_format"] = model_config.pop("format")
|
||||||
|
self.models[model_key] = model_class.create_config(**model_config)
|
||||||
|
|
||||||
|
# check config version number and update on disk/RAM if necessary
|
||||||
self.cache_keys = dict()
|
self.cache_keys = dict()
|
||||||
|
|
||||||
# add controlnet, lora and textual_inversion models from disk
|
# add controlnet, lora and textual_inversion models from disk
|
||||||
self.scan_models_directory()
|
self.scan_models_directory()
|
||||||
|
|
||||||
|
def sync_to_config(self):
|
||||||
|
"""
|
||||||
|
Call this when `models.yaml` has been changed externally.
|
||||||
|
This will reinitialize internal data structures
|
||||||
|
"""
|
||||||
|
# Reread models directory; note that this will reinitialize the cache,
|
||||||
|
# causing otherwise unreferenced models to be removed from memory
|
||||||
|
self._read_models()
|
||||||
|
|
||||||
def model_exists(
|
def model_exists(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
@ -528,7 +549,10 @@ class ModelManager(object):
|
|||||||
model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold)
|
model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold)
|
||||||
models = []
|
models = []
|
||||||
for model_key in model_keys:
|
for model_key in model_keys:
|
||||||
model_config = self.models[model_key]
|
model_config = self.models.get(model_key)
|
||||||
|
if not model_config:
|
||||||
|
self.logger.error(f'Unknown model {model_name}')
|
||||||
|
raise KeyError(f'Unknown model {model_name}')
|
||||||
|
|
||||||
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
||||||
if base_model is not None and cur_base_model != base_model:
|
if base_model is not None and cur_base_model != base_model:
|
||||||
@ -651,6 +675,7 @@ class ModelManager(object):
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: Union[ModelType.Main,ModelType.Vae],
|
model_type: Union[ModelType.Main,ModelType.Vae],
|
||||||
|
dest_directory: Optional[Path]=None,
|
||||||
) -> AddModelResult:
|
) -> AddModelResult:
|
||||||
'''
|
'''
|
||||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||||
@ -677,14 +702,14 @@ class ModelManager(object):
|
|||||||
)
|
)
|
||||||
checkpoint_path = self.app_config.root_path / info["path"]
|
checkpoint_path = self.app_config.root_path / info["path"]
|
||||||
old_diffusers_path = self.app_config.models_path / model.location
|
old_diffusers_path = self.app_config.models_path / model.location
|
||||||
new_diffusers_path = self.app_config.models_path / base_model.value / model_type.value / model_name
|
new_diffusers_path = (dest_directory or self.app_config.models_path / base_model.value / model_type.value) / model_name
|
||||||
if new_diffusers_path.exists():
|
if new_diffusers_path.exists():
|
||||||
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
|
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
move(old_diffusers_path,new_diffusers_path)
|
move(old_diffusers_path,new_diffusers_path)
|
||||||
info["model_format"] = "diffusers"
|
info["model_format"] = "diffusers"
|
||||||
info["path"] = str(new_diffusers_path.relative_to(self.app_config.root_path))
|
info["path"] = str(new_diffusers_path) if dest_directory else str(new_diffusers_path.relative_to(self.app_config.root_path))
|
||||||
info.pop('config')
|
info.pop('config')
|
||||||
|
|
||||||
result = self.add_model(model_name, base_model, model_type,
|
result = self.add_model(model_name, base_model, model_type,
|
||||||
|
@ -11,7 +11,7 @@ from enum import Enum
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
from diffusers import logging as dlogging
|
from diffusers import logging as dlogging
|
||||||
from typing import List, Union
|
from typing import List, Union, Optional
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
@ -74,6 +74,7 @@ class ModelMerger(object):
|
|||||||
alpha: float = 0.5,
|
alpha: float = 0.5,
|
||||||
interp: MergeInterpolationMethod = None,
|
interp: MergeInterpolationMethod = None,
|
||||||
force: bool = False,
|
force: bool = False,
|
||||||
|
merge_dest_directory: Optional[Path] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> AddModelResult:
|
) -> AddModelResult:
|
||||||
"""
|
"""
|
||||||
@ -85,7 +86,7 @@ class ModelMerger(object):
|
|||||||
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
|
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
|
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
|
||||||
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||||
|
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||||
"""
|
"""
|
||||||
@ -111,7 +112,7 @@ class ModelMerger(object):
|
|||||||
merged_pipe = self.merge_diffusion_models(
|
merged_pipe = self.merge_diffusion_models(
|
||||||
model_paths, alpha, merge_method, force, **kwargs
|
model_paths, alpha, merge_method, force, **kwargs
|
||||||
)
|
)
|
||||||
dump_path = config.models_path / base_model.value / ModelType.Main.value
|
dump_path = Path(merge_dest_directory) if merge_dest_directory else config.models_path / base_model.value / ModelType.Main.value
|
||||||
dump_path.mkdir(parents=True, exist_ok=True)
|
dump_path.mkdir(parents=True, exist_ok=True)
|
||||||
dump_path = dump_path / merged_model_name
|
dump_path = dump_path / merged_model_name
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user