mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'lstein/model-manager-router-api'
This commit is contained in:
commit
f78f10bef6
@ -206,10 +206,10 @@ async def convert_model(
|
|||||||
async def merge_models(
|
async def merge_models(
|
||||||
base_model: BaseModelType = Path(description="Base model"),
|
base_model: BaseModelType = Path(description="Base model"),
|
||||||
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
|
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
|
||||||
merged_model_name: Optional[str] = Body(description = "Name of destination model"),
|
merged_model_name: Optional[str] = Body(description="Name of destination model"),
|
||||||
alpha: Optional[float] = Body(description = "Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||||
interp: Union[MergeInterpolationMethod, None] = Body(description = "Interpolation method"),
|
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
|
||||||
force: Optional[bool] = Body(description = "Force merging of models created with different versions of diffusers", default=False),
|
force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
|
||||||
) -> MergeModelResponse:
|
) -> MergeModelResponse:
|
||||||
"""Convert a checkpoint model into a diffusers model"""
|
"""Convert a checkpoint model into a diffusers model"""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
@ -546,7 +546,8 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
:param interp: Interpolation method. None (default)
|
:param interp: Interpolation method. None (default)
|
||||||
"""
|
"""
|
||||||
merger = ModelMerger(self.mgr)
|
merger = ModelMerger(self.mgr)
|
||||||
return merger.merge_diffusion_models_and_save(
|
try:
|
||||||
|
result = merger.merge_diffusion_models_and_save(
|
||||||
model_names = model_names,
|
model_names = model_names,
|
||||||
base_model = base_model,
|
base_model = base_model,
|
||||||
merged_model_name = merged_model_name,
|
merged_model_name = merged_model_name,
|
||||||
@ -554,3 +555,6 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
interp = interp,
|
interp = interp,
|
||||||
force = force,
|
force = force,
|
||||||
)
|
)
|
||||||
|
except AssertionError as e:
|
||||||
|
raise ValueError(e)
|
||||||
|
return result
|
||||||
|
@ -15,14 +15,13 @@ from typing import List, Union
|
|||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
|
from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
|
||||||
|
|
||||||
class MergeInterpolationMethod(str, Enum):
|
class MergeInterpolationMethod(str, Enum):
|
||||||
|
WeightedSum = "weighted_sum"
|
||||||
Sigmoid = "sigmoid"
|
Sigmoid = "sigmoid"
|
||||||
InvSigmoid = "inv_sigmoid"
|
InvSigmoid = "inv_sigmoid"
|
||||||
AddDifference = "add_difference"
|
AddDifference = "add_difference"
|
||||||
WeightedSum = "weighted_sum"
|
|
||||||
|
|
||||||
class ModelMerger(object):
|
class ModelMerger(object):
|
||||||
def __init__(self, manager: ModelManager):
|
def __init__(self, manager: ModelManager):
|
||||||
@ -99,13 +98,16 @@ class ModelMerger(object):
|
|||||||
info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main)
|
info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main)
|
||||||
assert info, f"model {mod}, base_model {base_model}, is unknown"
|
assert info, f"model {mod}, base_model {base_model}, is unknown"
|
||||||
assert info["model_format"] == "diffusers", f"{mod} is not a diffusers model. It must be optimized before merging"
|
assert info["model_format"] == "diffusers", f"{mod} is not a diffusers model. It must be optimized before merging"
|
||||||
assert info["variant"] == "normal", (f"{mod} is a {info['variant']} model, which cannot currently be merged")
|
assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged"
|
||||||
|
assert len(model_names) <= 2 or \
|
||||||
|
interp==MergeInterpolationMethod.AddDifference, "When merging three models, only the 'add_difference' merge method is supported"
|
||||||
# pick up the first model's vae
|
# pick up the first model's vae
|
||||||
if mod == model_names[0]:
|
if mod == model_names[0]:
|
||||||
vae = info.get("vae")
|
vae = info.get("vae")
|
||||||
model_paths.extend([config.root_path / info["path"]])
|
model_paths.extend([config.root_path / info["path"]])
|
||||||
|
|
||||||
merge_method = None if interp == 'weighted_sum' else MergeInterpolationMethod(interp)
|
merge_method = None if interp == 'weighted_sum' else MergeInterpolationMethod(interp)
|
||||||
|
logger.debug(f'interp = {interp}, merge_method={merge_method}')
|
||||||
merged_pipe = self.merge_diffusion_models(
|
merged_pipe = self.merge_diffusion_models(
|
||||||
model_paths, alpha, merge_method, force, **kwargs
|
model_paths, alpha, merge_method, force, **kwargs
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user