mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into feat/clip_skip
This commit is contained in:
commit
7aa918677e
@ -1,75 +1,30 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein
|
||||||
|
|
||||||
from typing import Literal, Optional, Union
|
|
||||||
|
|
||||||
from fastapi import Query, Body
|
from typing import Literal, List, Optional, Union
|
||||||
from fastapi.routing import APIRouter, HTTPException
|
|
||||||
from pydantic import BaseModel, Field, parse_obj_as
|
from fastapi import Body, Path, Query, Response
|
||||||
from ..dependencies import ApiDependencies
|
from fastapi.routing import APIRouter
|
||||||
|
from pydantic import BaseModel, parse_obj_as
|
||||||
|
from starlette.exceptions import HTTPException
|
||||||
|
|
||||||
from invokeai.backend import BaseModelType, ModelType
|
from invokeai.backend import BaseModelType, ModelType
|
||||||
from invokeai.backend.model_management import AddModelResult
|
from invokeai.backend.model_management.models import (
|
||||||
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
|
OPENAPI_MODEL_CONFIGS,
|
||||||
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
SchedulerPredictionType,
|
||||||
|
)
|
||||||
|
from invokeai.backend.model_management import MergeInterpolationMethod
|
||||||
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||||
|
|
||||||
class VaeRepo(BaseModel):
|
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
repo_id: str = Field(description="The repo ID to use for this VAE")
|
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
path: Optional[str] = Field(description="The path to the VAE")
|
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
subfolder: Optional[str] = Field(description="The subfolder to use for this VAE")
|
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
|
|
||||||
class ModelInfo(BaseModel):
|
|
||||||
description: Optional[str] = Field(description="A description of the model")
|
|
||||||
model_name: str = Field(description="The name of the model")
|
|
||||||
model_type: str = Field(description="The type of the model")
|
|
||||||
|
|
||||||
class DiffusersModelInfo(ModelInfo):
|
|
||||||
format: Literal['folder'] = 'folder'
|
|
||||||
|
|
||||||
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
|
|
||||||
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
|
|
||||||
path: Optional[str] = Field(description="The path to the model")
|
|
||||||
|
|
||||||
class CkptModelInfo(ModelInfo):
|
|
||||||
format: Literal['ckpt'] = 'ckpt'
|
|
||||||
|
|
||||||
config: str = Field(description="The path to the model config")
|
|
||||||
weights: str = Field(description="The path to the model weights")
|
|
||||||
vae: str = Field(description="The path to the model VAE")
|
|
||||||
width: Optional[int] = Field(description="The width of the model")
|
|
||||||
height: Optional[int] = Field(description="The height of the model")
|
|
||||||
|
|
||||||
class SafetensorsModelInfo(CkptModelInfo):
|
|
||||||
format: Literal['safetensors'] = 'safetensors'
|
|
||||||
|
|
||||||
class CreateModelRequest(BaseModel):
|
|
||||||
name: str = Field(description="The name of the model")
|
|
||||||
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
|
|
||||||
|
|
||||||
class CreateModelResponse(BaseModel):
|
|
||||||
name: str = Field(description="The name of the new model")
|
|
||||||
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
|
|
||||||
status: str = Field(description="The status of the API response")
|
|
||||||
|
|
||||||
class ImportModelResponse(BaseModel):
|
|
||||||
name: str = Field(description="The name of the imported model")
|
|
||||||
# base_model: str = Field(description="The base model")
|
|
||||||
# model_type: str = Field(description="The model type")
|
|
||||||
info: AddModelResult = Field(description="The model info")
|
|
||||||
status: str = Field(description="The status of the API response")
|
|
||||||
|
|
||||||
class ConversionRequest(BaseModel):
|
|
||||||
name: str = Field(description="The name of the new model")
|
|
||||||
info: CkptModelInfo = Field(description="The converted model info")
|
|
||||||
save_location: str = Field(description="The path to save the converted model weights")
|
|
||||||
|
|
||||||
class ConvertedModelResponse(BaseModel):
|
|
||||||
name: str = Field(description="The name of the new model")
|
|
||||||
info: DiffusersModelInfo = Field(description="The converted model info")
|
|
||||||
|
|
||||||
class ModelsList(BaseModel):
|
class ModelsList(BaseModel):
|
||||||
models: list[MODEL_CONFIGS]
|
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
||||||
|
|
||||||
|
|
||||||
@models_router.get(
|
@models_router.get(
|
||||||
"/",
|
"/",
|
||||||
@ -77,75 +32,103 @@ class ModelsList(BaseModel):
|
|||||||
responses={200: {"model": ModelsList }},
|
responses={200: {"model": ModelsList }},
|
||||||
)
|
)
|
||||||
async def list_models(
|
async def list_models(
|
||||||
base_model: Optional[BaseModelType] = Query(
|
base_model: Optional[BaseModelType] = Query(default=None, description="Base model"),
|
||||||
default=None, description="Base model"
|
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||||
),
|
|
||||||
model_type: Optional[ModelType] = Query(
|
|
||||||
default=None, description="The type of model to get"
|
|
||||||
),
|
|
||||||
) -> ModelsList:
|
) -> ModelsList:
|
||||||
"""Gets a list of models"""
|
"""Gets a list of models"""
|
||||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)
|
models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)
|
||||||
models = parse_obj_as(ModelsList, { "models": models_raw })
|
models = parse_obj_as(ModelsList, { "models": models_raw })
|
||||||
return models
|
return models
|
||||||
|
|
||||||
@models_router.post(
|
@models_router.patch(
|
||||||
"/",
|
"/{base_model}/{model_type}/{model_name}",
|
||||||
operation_id="update_model",
|
operation_id="update_model",
|
||||||
responses={200: {"status": "success"}},
|
responses={200: {"description" : "The model was updated successfully"},
|
||||||
|
404: {"description" : "The model could not be found"},
|
||||||
|
400: {"description" : "Bad request"}
|
||||||
|
},
|
||||||
|
status_code = 200,
|
||||||
|
response_model = UpdateModelResponse,
|
||||||
)
|
)
|
||||||
async def update_model(
|
async def update_model(
|
||||||
model_request: CreateModelRequest
|
base_model: BaseModelType = Path(description="Base model"),
|
||||||
) -> CreateModelResponse:
|
model_type: ModelType = Path(description="The type of model"),
|
||||||
|
model_name: str = Path(description="model name"),
|
||||||
|
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||||
|
) -> UpdateModelResponse:
|
||||||
""" Add Model """
|
""" Add Model """
|
||||||
model_request_info = model_request.info
|
try:
|
||||||
info_dict = model_request_info.dict()
|
ApiDependencies.invoker.services.model_manager.update_model(
|
||||||
model_response = CreateModelResponse(name=model_request.name, info=model_request.info, status="success")
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
ApiDependencies.invoker.services.model_manager.add_model(
|
model_type=model_type,
|
||||||
model_name=model_request.name,
|
model_attributes=info.dict()
|
||||||
model_attributes=info_dict,
|
)
|
||||||
clobber=True,
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
)
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
)
|
||||||
|
model_response = parse_obj_as(UpdateModelResponse, model_raw)
|
||||||
|
except KeyError as e:
|
||||||
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
return model_response
|
return model_response
|
||||||
|
|
||||||
@models_router.post(
|
@models_router.post(
|
||||||
"/import",
|
"/",
|
||||||
operation_id="import_model",
|
operation_id="import_model",
|
||||||
responses= {
|
responses= {
|
||||||
201: {"description" : "The model imported successfully"},
|
201: {"description" : "The model imported successfully"},
|
||||||
404: {"description" : "The model could not be found"},
|
404: {"description" : "The model could not be found"},
|
||||||
|
424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
|
||||||
|
409: {"description" : "There is already a model corresponding to this path or repo_id"},
|
||||||
},
|
},
|
||||||
status_code=201,
|
status_code=201,
|
||||||
response_model=ImportModelResponse
|
response_model=ImportModelResponse
|
||||||
)
|
)
|
||||||
async def import_model(
|
async def import_model(
|
||||||
name: str = Query(description="A model path, repo_id or URL to import"),
|
location: str = Body(description="A model path, repo_id or URL to import"),
|
||||||
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = Query(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
|
||||||
|
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
||||||
) -> ImportModelResponse:
|
) -> ImportModelResponse:
|
||||||
""" Add a model using its local path, repo_id, or remote URL """
|
""" Add a model using its local path, repo_id, or remote URL """
|
||||||
items_to_import = {name}
|
|
||||||
|
items_to_import = {location}
|
||||||
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
|
||||||
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
try:
|
||||||
items_to_import = items_to_import,
|
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
||||||
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
|
items_to_import = items_to_import,
|
||||||
)
|
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
|
||||||
if info := installed_models.get(name):
|
|
||||||
logger.info(f'Successfully imported {name}, got {info}')
|
|
||||||
return ImportModelResponse(
|
|
||||||
name = name,
|
|
||||||
info = info,
|
|
||||||
status = "success",
|
|
||||||
)
|
)
|
||||||
else:
|
info = installed_models.get(location)
|
||||||
logger.error(f'Model {name} not imported')
|
|
||||||
raise HTTPException(status_code=404, detail=f'Model {name} not found')
|
if not info:
|
||||||
|
logger.error("Import failed")
|
||||||
|
raise HTTPException(status_code=424)
|
||||||
|
|
||||||
|
logger.info(f'Successfully imported {location}, got {info}')
|
||||||
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
|
model_name=info.name,
|
||||||
|
base_model=info.base_model,
|
||||||
|
model_type=info.model_type
|
||||||
|
)
|
||||||
|
return parse_obj_as(ImportModelResponse, model_raw)
|
||||||
|
|
||||||
|
except KeyError as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
raise HTTPException(status_code=409, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
@models_router.delete(
|
@models_router.delete(
|
||||||
"/{model_name}",
|
"/{base_model}/{model_type}/{model_name}",
|
||||||
operation_id="del_model",
|
operation_id="del_model",
|
||||||
responses={
|
responses={
|
||||||
204: {
|
204: {
|
||||||
@ -156,144 +139,95 @@ async def import_model(
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def delete_model(model_name: str) -> None:
|
async def delete_model(
|
||||||
|
base_model: BaseModelType = Path(description="Base model"),
|
||||||
|
model_type: ModelType = Path(description="The type of model"),
|
||||||
|
model_name: str = Path(description="model name"),
|
||||||
|
) -> Response:
|
||||||
"""Delete Model"""
|
"""Delete Model"""
|
||||||
model_names = ApiDependencies.invoker.services.model_manager.model_names()
|
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
model_exists = model_name in model_names
|
|
||||||
|
|
||||||
# check if model exists
|
|
||||||
logger.info(f"Checking for model {model_name}...")
|
|
||||||
|
|
||||||
if model_exists:
|
|
||||||
logger.info(f"Deleting Model: {model_name}")
|
|
||||||
ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True)
|
|
||||||
logger.info(f"Model Deleted: {model_name}")
|
|
||||||
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
|
|
||||||
|
|
||||||
else:
|
try:
|
||||||
logger.error("Model not found")
|
ApiDependencies.invoker.services.model_manager.del_model(model_name,
|
||||||
|
base_model = base_model,
|
||||||
|
model_type = model_type
|
||||||
|
)
|
||||||
|
logger.info(f"Deleted model: {model_name}")
|
||||||
|
return Response(status_code=204)
|
||||||
|
except KeyError:
|
||||||
|
logger.error(f"Model not found: {model_name}")
|
||||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||||
|
|
||||||
|
|
||||||
# @socketio.on("convertToDiffusers")
|
@models_router.put(
|
||||||
# def convert_to_diffusers(model_to_convert: dict):
|
"/convert/{base_model}/{model_type}/{model_name}",
|
||||||
# try:
|
operation_id="convert_model",
|
||||||
# if model_info := self.generate.model_manager.model_info(
|
responses={
|
||||||
# model_name=model_to_convert["model_name"]
|
200: { "description": "Model converted successfully" },
|
||||||
# ):
|
400: {"description" : "Bad request" },
|
||||||
# if "weights" in model_info:
|
404: { "description": "Model not found" },
|
||||||
# ckpt_path = Path(model_info["weights"])
|
},
|
||||||
# original_config_file = Path(model_info["config"])
|
status_code = 200,
|
||||||
# model_name = model_to_convert["model_name"]
|
response_model = ConvertModelResponse,
|
||||||
# model_description = model_info["description"]
|
)
|
||||||
# else:
|
async def convert_model(
|
||||||
# self.socketio.emit(
|
base_model: BaseModelType = Path(description="Base model"),
|
||||||
# "error", {"message": "Model is not a valid checkpoint file"}
|
model_type: ModelType = Path(description="The type of model"),
|
||||||
# )
|
model_name: str = Path(description="model name"),
|
||||||
# else:
|
) -> ConvertModelResponse:
|
||||||
# self.socketio.emit(
|
"""Convert a checkpoint model into a diffusers model"""
|
||||||
# "error", {"message": "Could not retrieve model info."}
|
logger = ApiDependencies.invoker.services.logger
|
||||||
# )
|
try:
|
||||||
|
logger.info(f"Converting model: {model_name}")
|
||||||
# if not ckpt_path.is_absolute():
|
ApiDependencies.invoker.services.model_manager.convert_model(model_name,
|
||||||
# ckpt_path = Path(Globals.root, ckpt_path)
|
base_model = base_model,
|
||||||
|
model_type = model_type
|
||||||
# if original_config_file and not original_config_file.is_absolute():
|
)
|
||||||
# original_config_file = Path(Globals.root, original_config_file)
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
|
||||||
|
base_model = base_model,
|
||||||
# diffusers_path = Path(
|
model_type = model_type)
|
||||||
# ckpt_path.parent.absolute(), f"{model_name}_diffusers"
|
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||||
# )
|
except KeyError:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||||
# if model_to_convert["save_location"] == "root":
|
except ValueError as e:
|
||||||
# diffusers_path = Path(
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
# global_converted_ckpts_dir(), f"{model_name}_diffusers"
|
return response
|
||||||
# )
|
|
||||||
|
@models_router.put(
|
||||||
# if (
|
"/merge/{base_model}",
|
||||||
# model_to_convert["save_location"] == "custom"
|
operation_id="merge_models",
|
||||||
# and model_to_convert["custom_location"] is not None
|
responses={
|
||||||
# ):
|
200: { "description": "Model converted successfully" },
|
||||||
# diffusers_path = Path(
|
400: { "description": "Incompatible models" },
|
||||||
# model_to_convert["custom_location"], f"{model_name}_diffusers"
|
404: { "description": "One or more models not found" },
|
||||||
# )
|
},
|
||||||
|
status_code = 200,
|
||||||
# if diffusers_path.exists():
|
response_model = MergeModelResponse,
|
||||||
# shutil.rmtree(diffusers_path)
|
)
|
||||||
|
async def merge_models(
|
||||||
# self.generate.model_manager.convert_and_import(
|
base_model: BaseModelType = Path(description="Base model"),
|
||||||
# ckpt_path,
|
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
|
||||||
# diffusers_path,
|
merged_model_name: Optional[str] = Body(description="Name of destination model"),
|
||||||
# model_name=model_name,
|
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||||
# model_description=model_description,
|
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
|
||||||
# vae=None,
|
force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
|
||||||
# original_config_file=original_config_file,
|
) -> MergeModelResponse:
|
||||||
# commit_to_conf=opt.conf,
|
"""Convert a checkpoint model into a diffusers model"""
|
||||||
# )
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
try:
|
||||||
# new_model_list = self.generate.model_manager.list_models()
|
logger.info(f"Merging models: {model_names}")
|
||||||
# socketio.emit(
|
result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
|
||||||
# "modelConverted",
|
base_model,
|
||||||
# {
|
merged_model_name or "+".join(model_names),
|
||||||
# "new_model_name": model_name,
|
alpha,
|
||||||
# "model_list": new_model_list,
|
interp,
|
||||||
# "update": True,
|
force)
|
||||||
# },
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
|
||||||
# )
|
base_model = base_model,
|
||||||
# print(f">> Model Converted: {model_name}")
|
model_type = ModelType.Main,
|
||||||
# except Exception as e:
|
)
|
||||||
# self.handle_exceptions(e)
|
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||||
|
except KeyError:
|
||||||
# @socketio.on("mergeDiffusersModels")
|
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
|
||||||
# def merge_diffusers_models(model_merge_info: dict):
|
except ValueError as e:
|
||||||
# try:
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
# models_to_merge = model_merge_info["models_to_merge"]
|
return response
|
||||||
# model_ids_or_paths = [
|
|
||||||
# self.generate.model_manager.model_name_or_path(x)
|
|
||||||
# for x in models_to_merge
|
|
||||||
# ]
|
|
||||||
# merged_pipe = merge_diffusion_models(
|
|
||||||
# model_ids_or_paths,
|
|
||||||
# model_merge_info["alpha"],
|
|
||||||
# model_merge_info["interp"],
|
|
||||||
# model_merge_info["force"],
|
|
||||||
# )
|
|
||||||
|
|
||||||
# dump_path = global_models_dir() / "merged_models"
|
|
||||||
# if model_merge_info["model_merge_save_path"] is not None:
|
|
||||||
# dump_path = Path(model_merge_info["model_merge_save_path"])
|
|
||||||
|
|
||||||
# os.makedirs(dump_path, exist_ok=True)
|
|
||||||
# dump_path = dump_path / model_merge_info["merged_model_name"]
|
|
||||||
# merged_pipe.save_pretrained(dump_path, safe_serialization=1)
|
|
||||||
|
|
||||||
# merged_model_config = dict(
|
|
||||||
# model_name=model_merge_info["merged_model_name"],
|
|
||||||
# description=f'Merge of models {", ".join(models_to_merge)}',
|
|
||||||
# commit_to_conf=opt.conf,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# if vae := self.generate.model_manager.config[models_to_merge[0]].get(
|
|
||||||
# "vae", None
|
|
||||||
# ):
|
|
||||||
# print(f">> Using configured VAE assigned to {models_to_merge[0]}")
|
|
||||||
# merged_model_config.update(vae=vae)
|
|
||||||
|
|
||||||
# self.generate.model_manager.import_diffuser_model(
|
|
||||||
# dump_path, **merged_model_config
|
|
||||||
# )
|
|
||||||
# new_model_list = self.generate.model_manager.list_models()
|
|
||||||
|
|
||||||
# socketio.emit(
|
|
||||||
# "modelsMerged",
|
|
||||||
# {
|
|
||||||
# "merged_models": models_to_merge,
|
|
||||||
# "merged_model_name": model_merge_info["merged_model_name"],
|
|
||||||
# "model_list": new_model_list,
|
|
||||||
# "update": True,
|
|
||||||
# },
|
|
||||||
# )
|
|
||||||
# print(f">> Models Merged: {models_to_merge}")
|
|
||||||
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
|
|
||||||
# except Exception as e:
|
|
||||||
|
@ -2,22 +2,29 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import torch
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING
|
from pydantic import Field
|
||||||
from dataclasses import dataclass
|
from typing import Optional, Union, Callable, List, Tuple, TYPE_CHECKING
|
||||||
|
from types import ModuleType
|
||||||
|
|
||||||
from invokeai.backend.model_management.model_manager import (
|
from invokeai.backend.model_management import (
|
||||||
ModelManager,
|
ModelManager,
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
ModelType,
|
ModelType,
|
||||||
SubModelType,
|
SubModelType,
|
||||||
ModelInfo,
|
ModelInfo,
|
||||||
|
AddModelResult,
|
||||||
|
SchedulerPredictionType,
|
||||||
|
ModelMerger,
|
||||||
|
MergeInterpolationMethod,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
from invokeai.app.models.exceptions import CanceledException
|
from invokeai.app.models.exceptions import CanceledException
|
||||||
from .config import InvokeAIAppConfig
|
|
||||||
from ...backend.util import choose_precision, choose_torch_device
|
from ...backend.util import choose_precision, choose_torch_device
|
||||||
|
from .config import InvokeAIAppConfig
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
|
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
|
||||||
@ -30,7 +37,7 @@ class ModelManagerServiceBase(ABC):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: InvokeAIAppConfig,
|
config: InvokeAIAppConfig,
|
||||||
logger: types.ModuleType,
|
logger: ModuleType,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize with the path to the models.yaml config file.
|
Initialize with the path to the models.yaml config file.
|
||||||
@ -73,13 +80,7 @@ class ModelManagerServiceBase(ABC):
|
|||||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||||
"""
|
"""
|
||||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||||
"""
|
Uses the exact format as the omegaconf stanza.
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
|
||||||
"""
|
|
||||||
Returns a list of all the model names known.
|
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -101,7 +102,20 @@ class ModelManagerServiceBase(ABC):
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||||
|
"""
|
||||||
|
Return information about the model using the same format as list_models()
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||||
|
"""
|
||||||
|
Returns a list of all the model names known.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def add_model(
|
def add_model(
|
||||||
@ -111,7 +125,7 @@ class ModelManagerServiceBase(ABC):
|
|||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
model_attributes: dict,
|
model_attributes: dict,
|
||||||
clobber: bool = False
|
clobber: bool = False
|
||||||
) -> None:
|
) -> AddModelResult:
|
||||||
"""
|
"""
|
||||||
Update the named model with a dictionary of attributes. Will fail with an
|
Update the named model with a dictionary of attributes. Will fail with an
|
||||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||||
@ -121,6 +135,24 @@ class ModelManagerServiceBase(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update_model(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
model_attributes: dict,
|
||||||
|
) -> AddModelResult:
|
||||||
|
"""
|
||||||
|
Update the named model with a dictionary of attributes. Will fail with a
|
||||||
|
KeyErrorException if the name does not already exist.
|
||||||
|
|
||||||
|
On a successful update, the config will be changed in memory. Will fail
|
||||||
|
with an assertion error if provided attributes are incorrect or
|
||||||
|
the model name is missing. Call commit() to write changes to disk.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def del_model(
|
def del_model(
|
||||||
self,
|
self,
|
||||||
@ -135,11 +167,32 @@ class ModelManagerServiceBase(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def convert_model(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: Union[ModelType.Main,ModelType.Vae],
|
||||||
|
) -> AddModelResult:
|
||||||
|
"""
|
||||||
|
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||||
|
version and deleting the original checkpoint file if it is in the models
|
||||||
|
directory.
|
||||||
|
:param model_name: Name of the model to convert
|
||||||
|
:param base_model: Base model type
|
||||||
|
:param model_type: Type of model ['vae' or 'main']
|
||||||
|
|
||||||
|
This will raise a ValueError unless the model is not a checkpoint. It will
|
||||||
|
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||||
|
directory already in place.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def heuristic_import(self,
|
def heuristic_import(self,
|
||||||
items_to_import: Set[str],
|
items_to_import: set[str],
|
||||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
|
||||||
)->Dict[str, AddModelResult]:
|
)->dict[str, AddModelResult]:
|
||||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||||
successfully imported items.
|
successfully imported items.
|
||||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||||
@ -159,7 +212,27 @@ class ModelManagerServiceBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def commit(self, conf_file: Path = None) -> None:
|
def merge_models(
|
||||||
|
self,
|
||||||
|
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
|
||||||
|
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
|
||||||
|
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||||
|
alpha: Optional[float] = 0.5,
|
||||||
|
interp: Optional[MergeInterpolationMethod] = None,
|
||||||
|
force: Optional[bool] = False,
|
||||||
|
) -> AddModelResult:
|
||||||
|
"""
|
||||||
|
Merge two to three diffusrs pipeline models and save as a new model.
|
||||||
|
:param model_names: List of 2-3 models to merge
|
||||||
|
:param base_model: Base model to use for all models
|
||||||
|
:param merged_model_name: Name of destination merged model
|
||||||
|
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||||
|
:param interp: Interpolation method. None (default)
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def commit(self, conf_file: Optional[Path] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Write current configuration out to the indicated file.
|
Write current configuration out to the indicated file.
|
||||||
If no conf_file is provided, then replaces the
|
If no conf_file is provided, then replaces the
|
||||||
@ -173,7 +246,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: InvokeAIAppConfig,
|
config: InvokeAIAppConfig,
|
||||||
logger: types.ModuleType,
|
logger: ModuleType,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize with the path to the models.yaml config file.
|
Initialize with the path to the models.yaml config file.
|
||||||
@ -299,12 +372,19 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
base_model: Optional[BaseModelType] = None,
|
base_model: Optional[BaseModelType] = None,
|
||||||
model_type: Optional[ModelType] = None
|
model_type: Optional[ModelType] = None
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
# ) -> dict:
|
|
||||||
"""
|
"""
|
||||||
Return a list of models.
|
Return a list of models.
|
||||||
"""
|
"""
|
||||||
return self.mgr.list_models(base_model, model_type)
|
return self.mgr.list_models(base_model, model_type)
|
||||||
|
|
||||||
|
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||||
|
"""
|
||||||
|
Return information about the model using the same format as list_models()
|
||||||
|
"""
|
||||||
|
return self.mgr.list_model(model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type)
|
||||||
|
|
||||||
def add_model(
|
def add_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
@ -320,9 +400,28 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
with an assertion error if provided attributes are incorrect or
|
with an assertion error if provided attributes are incorrect or
|
||||||
the model name is missing. Call commit() to write changes to disk.
|
the model name is missing. Call commit() to write changes to disk.
|
||||||
"""
|
"""
|
||||||
|
self.logger.debug(f'add/update model {model_name}')
|
||||||
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
|
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
|
||||||
|
|
||||||
|
def update_model(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
model_attributes: dict,
|
||||||
|
) -> AddModelResult:
|
||||||
|
"""
|
||||||
|
Update the named model with a dictionary of attributes. Will fail with a
|
||||||
|
KeyError exception if the name does not already exist.
|
||||||
|
On a successful update, the config will be changed in memory. Will fail
|
||||||
|
with an assertion error if provided attributes are incorrect or
|
||||||
|
the model name is missing. Call commit() to write changes to disk.
|
||||||
|
"""
|
||||||
|
self.logger.debug(f'update model {model_name}')
|
||||||
|
if not self.model_exists(model_name, base_model, model_type):
|
||||||
|
raise KeyError(f"Unknown model {model_name}")
|
||||||
|
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
|
||||||
|
|
||||||
def del_model(
|
def del_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
@ -334,8 +433,29 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
then the underlying weight file or diffusers directory will be deleted
|
then the underlying weight file or diffusers directory will be deleted
|
||||||
as well. Call commit() to write to disk.
|
as well. Call commit() to write to disk.
|
||||||
"""
|
"""
|
||||||
|
self.logger.debug(f'delete model {model_name}')
|
||||||
self.mgr.del_model(model_name, base_model, model_type)
|
self.mgr.del_model(model_name, base_model, model_type)
|
||||||
|
|
||||||
|
def convert_model(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: Union[ModelType.Main,ModelType.Vae],
|
||||||
|
) -> AddModelResult:
|
||||||
|
"""
|
||||||
|
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||||
|
version and deleting the original checkpoint file if it is in the models
|
||||||
|
directory.
|
||||||
|
:param model_name: Name of the model to convert
|
||||||
|
:param base_model: Base model type
|
||||||
|
:param model_type: Type of model ['vae' or 'main']
|
||||||
|
|
||||||
|
This will raise a ValueError unless the model is not a checkpoint. It will
|
||||||
|
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||||
|
directory already in place.
|
||||||
|
"""
|
||||||
|
self.logger.debug(f'convert model {model_name}')
|
||||||
|
return self.mgr.convert_model(model_name, base_model, model_type)
|
||||||
|
|
||||||
def commit(self, conf_file: Optional[Path]=None):
|
def commit(self, conf_file: Optional[Path]=None):
|
||||||
"""
|
"""
|
||||||
@ -387,9 +507,9 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
return self.mgr.logger
|
return self.mgr.logger
|
||||||
|
|
||||||
def heuristic_import(self,
|
def heuristic_import(self,
|
||||||
items_to_import: Set[str],
|
items_to_import: set[str],
|
||||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
|
||||||
)->Dict[str, AddModelResult]:
|
)->dict[str, AddModelResult]:
|
||||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||||
successfully imported items.
|
successfully imported items.
|
||||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||||
@ -406,4 +526,35 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||||
that model.
|
that model.
|
||||||
'''
|
'''
|
||||||
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
|
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
|
||||||
|
|
||||||
|
def merge_models(
|
||||||
|
self,
|
||||||
|
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
|
||||||
|
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
|
||||||
|
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||||
|
alpha: Optional[float] = 0.5,
|
||||||
|
interp: Optional[MergeInterpolationMethod] = None,
|
||||||
|
force: Optional[bool] = False,
|
||||||
|
) -> AddModelResult:
|
||||||
|
"""
|
||||||
|
Merge two to three diffusrs pipeline models and save as a new model.
|
||||||
|
:param model_names: List of 2-3 models to merge
|
||||||
|
:param base_model: Base model to use for all models
|
||||||
|
:param merged_model_name: Name of destination merged model
|
||||||
|
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||||
|
:param interp: Interpolation method. None (default)
|
||||||
|
"""
|
||||||
|
merger = ModelMerger(self.mgr)
|
||||||
|
try:
|
||||||
|
result = merger.merge_diffusion_models_and_save(
|
||||||
|
model_names = model_names,
|
||||||
|
base_model = base_model,
|
||||||
|
merged_model_name = merged_model_name,
|
||||||
|
alpha = alpha,
|
||||||
|
interp = interp,
|
||||||
|
force = force,
|
||||||
|
)
|
||||||
|
except AssertionError as e:
|
||||||
|
raise ValueError(e)
|
||||||
|
return result
|
||||||
|
@ -166,14 +166,18 @@ class ModelInstall(object):
|
|||||||
# add requested models
|
# add requested models
|
||||||
for path in selections.install_models:
|
for path in selections.install_models:
|
||||||
logger.info(f'Installing {path} [{job}/{jobs}]')
|
logger.info(f'Installing {path} [{job}/{jobs}]')
|
||||||
self.heuristic_import(path)
|
try:
|
||||||
|
self.heuristic_import(path)
|
||||||
|
except (ValueError, KeyError) as e:
|
||||||
|
logger.error(str(e))
|
||||||
job += 1
|
job += 1
|
||||||
|
|
||||||
self.mgr.commit()
|
self.mgr.commit()
|
||||||
|
|
||||||
def heuristic_import(self,
|
def heuristic_import(self,
|
||||||
model_path_id_or_url: Union[str,Path],
|
model_path_id_or_url: Union[str,Path],
|
||||||
models_installed: Set[Path]=None)->Dict[str, AddModelResult]:
|
models_installed: Set[Path]=None,
|
||||||
|
)->Dict[str, AddModelResult]:
|
||||||
'''
|
'''
|
||||||
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
|
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
|
||||||
:param models_installed: Set of installed models, used for recursive invocation
|
:param models_installed: Set of installed models, used for recursive invocation
|
||||||
@ -187,61 +191,53 @@ class ModelInstall(object):
|
|||||||
self.current_id = model_path_id_or_url
|
self.current_id = model_path_id_or_url
|
||||||
path = Path(model_path_id_or_url)
|
path = Path(model_path_id_or_url)
|
||||||
|
|
||||||
try:
|
# checkpoint file, or similar
|
||||||
# checkpoint file, or similar
|
if path.is_file():
|
||||||
if path.is_file():
|
models_installed.update({str(path):self._install_path(path)})
|
||||||
models_installed.update(self._install_path(path))
|
|
||||||
|
|
||||||
# folders style or similar
|
# folders style or similar
|
||||||
elif path.is_dir() and any([(path/x).exists() for x in \
|
elif path.is_dir() and any([(path/x).exists() for x in \
|
||||||
{'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
|
{'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
|
||||||
]
|
]
|
||||||
):
|
):
|
||||||
models_installed.update(self._install_path(path))
|
models_installed.update(self._install_path(path))
|
||||||
|
|
||||||
# recursive scan
|
# recursive scan
|
||||||
elif path.is_dir():
|
elif path.is_dir():
|
||||||
for child in path.iterdir():
|
for child in path.iterdir():
|
||||||
self.heuristic_import(child, models_installed=models_installed)
|
self.heuristic_import(child, models_installed=models_installed)
|
||||||
|
|
||||||
# huggingface repo
|
# huggingface repo
|
||||||
elif len(str(path).split('/')) == 2:
|
elif len(str(model_path_id_or_url).split('/')) == 2:
|
||||||
models_installed.update(self._install_repo(str(path)))
|
models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
|
||||||
|
|
||||||
# a URL
|
# a URL
|
||||||
elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")):
|
elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")):
|
||||||
models_installed.update(self._install_url(model_path_id_or_url))
|
models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
|
raise KeyError(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
|
||||||
|
|
||||||
except ValueError as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
|
|
||||||
return models_installed
|
return models_installed
|
||||||
|
|
||||||
# install a model from a local path. The optional info parameter is there to prevent
|
# install a model from a local path. The optional info parameter is there to prevent
|
||||||
# the model from being probed twice in the event that it has already been probed.
|
# the model from being probed twice in the event that it has already been probed.
|
||||||
def _install_path(self, path: Path, info: ModelProbeInfo=None)->Dict[str, AddModelResult]:
|
def _install_path(self, path: Path, info: ModelProbeInfo=None)->AddModelResult:
|
||||||
try:
|
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
|
||||||
model_result = None
|
if not info:
|
||||||
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
|
logger.warning(f'Unable to parse format of {path}')
|
||||||
model_name = path.stem if path.is_file() else path.name
|
return None
|
||||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
model_name = path.stem if path.is_file() else path.name
|
||||||
raise ValueError(f'A model named "{model_name}" is already installed.')
|
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||||
attributes = self._make_attributes(path,info)
|
raise ValueError(f'A model named "{model_name}" is already installed.')
|
||||||
model_result = self.mgr.add_model(model_name = model_name,
|
attributes = self._make_attributes(path,info)
|
||||||
base_model = info.base_type,
|
return self.mgr.add_model(model_name = model_name,
|
||||||
model_type = info.model_type,
|
base_model = info.base_type,
|
||||||
model_attributes = attributes,
|
model_type = info.model_type,
|
||||||
)
|
model_attributes = attributes,
|
||||||
except Exception as e:
|
)
|
||||||
logger.warning(f'{str(e)} Skipping registration.')
|
|
||||||
return {}
|
|
||||||
return {str(path): model_result}
|
|
||||||
|
|
||||||
def _install_url(self, url: str)->dict:
|
def _install_url(self, url: str)->AddModelResult:
|
||||||
# copy to a staging area, probe, import and delete
|
|
||||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||||
location = download_with_resume(url,Path(staging))
|
location = download_with_resume(url,Path(staging))
|
||||||
if not location:
|
if not location:
|
||||||
@ -253,7 +249,7 @@ class ModelInstall(object):
|
|||||||
# staged version will be garbage-collected at this time
|
# staged version will be garbage-collected at this time
|
||||||
return self._install_path(Path(models_path), info)
|
return self._install_path(Path(models_path), info)
|
||||||
|
|
||||||
def _install_repo(self, repo_id: str)->dict:
|
def _install_repo(self, repo_id: str)->AddModelResult:
|
||||||
hinfo = HfApi().model_info(repo_id)
|
hinfo = HfApi().model_info(repo_id)
|
||||||
|
|
||||||
# we try to figure out how to download this most economically
|
# we try to figure out how to download this most economically
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
"""
|
"""
|
||||||
Initialization file for invokeai.backend.model_management
|
Initialization file for invokeai.backend.model_management
|
||||||
"""
|
"""
|
||||||
from .model_manager import ModelManager, ModelInfo, AddModelResult
|
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
|
||||||
from .model_cache import ModelCache
|
from .model_cache import ModelCache
|
||||||
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
|
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
|
||||||
|
from .model_merge import ModelMerger, MergeInterpolationMethod
|
||||||
|
|
||||||
|
@ -2,8 +2,8 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from typing import Optional, Dict, Tuple, Any, Union, List
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional, Tuple, Union, List
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from compel.embeddings_provider import BaseTextualInversionManager
|
from compel.embeddings_provider import BaseTextualInversionManager
|
||||||
|
@ -234,7 +234,7 @@ import textwrap
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types
|
from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types
|
||||||
from shutil import rmtree
|
from shutil import rmtree, move
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
@ -279,7 +279,7 @@ class InvalidModelError(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
class AddModelResult(BaseModel):
|
class AddModelResult(BaseModel):
|
||||||
name: str = Field(description="The name of the model after import")
|
name: str = Field(description="The name of the model after installation")
|
||||||
model_type: ModelType = Field(description="The type of model")
|
model_type: ModelType = Field(description="The type of model")
|
||||||
base_model: BaseModelType = Field(description="The base model")
|
base_model: BaseModelType = Field(description="The base model")
|
||||||
config: ModelConfigBase = Field(description="The configuration of the model")
|
config: ModelConfigBase = Field(description="The configuration of the model")
|
||||||
@ -491,17 +491,32 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
return [(self.parse_key(x)) for x in self.models.keys()]
|
return [(self.parse_key(x)) for x in self.models.keys()]
|
||||||
|
|
||||||
|
def list_model(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Returns a dict describing one installed model, using
|
||||||
|
the combined format of the list_models() method.
|
||||||
|
"""
|
||||||
|
models = self.list_models(base_model,model_type,model_name)
|
||||||
|
return models[0] if models else None
|
||||||
|
|
||||||
def list_models(
|
def list_models(
|
||||||
self,
|
self,
|
||||||
base_model: Optional[BaseModelType] = None,
|
base_model: Optional[BaseModelType] = None,
|
||||||
model_type: Optional[ModelType] = None,
|
model_type: Optional[ModelType] = None,
|
||||||
|
model_name: Optional[str] = None,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Return a list of models.
|
Return a list of models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold)
|
||||||
models = []
|
models = []
|
||||||
for model_key in sorted(self.models, key=str.casefold):
|
for model_key in model_keys:
|
||||||
model_config = self.models[model_key]
|
model_config = self.models[model_key]
|
||||||
|
|
||||||
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
||||||
@ -546,10 +561,7 @@ class ModelManager(object):
|
|||||||
model_cfg = self.models.pop(model_key, None)
|
model_cfg = self.models.pop(model_key, None)
|
||||||
|
|
||||||
if model_cfg is None:
|
if model_cfg is None:
|
||||||
self.logger.error(
|
raise KeyError(f"Unknown model {model_key}")
|
||||||
f"Unknown model {model_key}"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# note: it not garantie to release memory(model can has other references)
|
# note: it not garantie to release memory(model can has other references)
|
||||||
cache_ids = self.cache_keys.pop(model_key, [])
|
cache_ids = self.cache_keys.pop(model_key, [])
|
||||||
@ -615,6 +627,7 @@ class ModelManager(object):
|
|||||||
self.cache.uncache_model(cache_id)
|
self.cache.uncache_model(cache_id)
|
||||||
|
|
||||||
self.models[model_key] = model_config
|
self.models[model_key] = model_config
|
||||||
|
self.commit()
|
||||||
return AddModelResult(
|
return AddModelResult(
|
||||||
name = model_name,
|
name = model_name,
|
||||||
model_type = model_type,
|
model_type = model_type,
|
||||||
@ -622,6 +635,60 @@ class ModelManager(object):
|
|||||||
config = model_config,
|
config = model_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def convert_model (
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: Union[ModelType.Main,ModelType.Vae],
|
||||||
|
) -> AddModelResult:
|
||||||
|
'''
|
||||||
|
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||||
|
version and deleting the original checkpoint file if it is in the models
|
||||||
|
directory.
|
||||||
|
:param model_name: Name of the model to convert
|
||||||
|
:param base_model: Base model type
|
||||||
|
:param model_type: Type of model ['vae' or 'main']
|
||||||
|
|
||||||
|
This will raise a ValueError unless the model is a checkpoint.
|
||||||
|
'''
|
||||||
|
info = self.model_info(model_name, base_model, model_type)
|
||||||
|
if info["model_format"] != "checkpoint":
|
||||||
|
raise ValueError(f"not a checkpoint format model: {model_name}")
|
||||||
|
|
||||||
|
# We are taking advantage of a side effect of get_model() that converts check points
|
||||||
|
# into cached diffusers directories stored at `location`. It doesn't matter
|
||||||
|
# what submodeltype we request here, so we get the smallest.
|
||||||
|
submodel = {"submodel_type": SubModelType.Tokenizer} if model_type==ModelType.Main else {}
|
||||||
|
model = self.get_model(model_name,
|
||||||
|
base_model,
|
||||||
|
model_type,
|
||||||
|
**submodel,
|
||||||
|
)
|
||||||
|
checkpoint_path = self.app_config.root_path / info["path"]
|
||||||
|
old_diffusers_path = self.app_config.models_path / model.location
|
||||||
|
new_diffusers_path = self.app_config.models_path / base_model.value / model_type.value / model_name
|
||||||
|
if new_diffusers_path.exists():
|
||||||
|
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
move(old_diffusers_path,new_diffusers_path)
|
||||||
|
info["model_format"] = "diffusers"
|
||||||
|
info["path"] = str(new_diffusers_path.relative_to(self.app_config.root_path))
|
||||||
|
info.pop('config')
|
||||||
|
|
||||||
|
result = self.add_model(model_name, base_model, model_type,
|
||||||
|
model_attributes = info,
|
||||||
|
clobber=True)
|
||||||
|
except:
|
||||||
|
# something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
|
||||||
|
rmtree(new_diffusers_path)
|
||||||
|
raise
|
||||||
|
|
||||||
|
if checkpoint_path.exists() and checkpoint_path.is_relative_to(self.app_config.models_path):
|
||||||
|
checkpoint_path.unlink()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
def search_models(self, search_folder):
|
def search_models(self, search_folder):
|
||||||
self.logger.info(f"Finding Models In: {search_folder}")
|
self.logger.info(f"Finding Models In: {search_folder}")
|
||||||
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
|
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
|
||||||
@ -821,6 +888,10 @@ class ModelManager(object):
|
|||||||
The result is a set of successfully installed models. Each element
|
The result is a set of successfully installed models. Each element
|
||||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||||
that model.
|
that model.
|
||||||
|
|
||||||
|
May return the following exceptions:
|
||||||
|
- KeyError - one or more of the items to import is not a valid path, repo_id or URL
|
||||||
|
- ValueError - a corresponding model already exists
|
||||||
'''
|
'''
|
||||||
# avoid circular import here
|
# avoid circular import here
|
||||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||||
@ -830,11 +901,7 @@ class ModelManager(object):
|
|||||||
prediction_type_helper = prediction_type_helper,
|
prediction_type_helper = prediction_type_helper,
|
||||||
model_manager = self)
|
model_manager = self)
|
||||||
for thing in items_to_import:
|
for thing in items_to_import:
|
||||||
try:
|
installed = installer.heuristic_import(thing)
|
||||||
installed = installer.heuristic_import(thing)
|
successfully_installed.update(installed)
|
||||||
successfully_installed.update(installed)
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.warning(f'{thing} could not be imported: {str(e)}')
|
|
||||||
|
|
||||||
self.commit()
|
self.commit()
|
||||||
return successfully_installed
|
return successfully_installed
|
||||||
|
131
invokeai/backend/model_management/model_merge.py
Normal file
131
invokeai/backend/model_management/model_merge.py
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
"""
|
||||||
|
invokeai.backend.model_management.model_merge exports:
|
||||||
|
merge_diffusion_models() -- combine multiple models by location and return a pipeline object
|
||||||
|
merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to models.yaml
|
||||||
|
|
||||||
|
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
||||||
|
"""
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
from diffusers import DiffusionPipeline
|
||||||
|
from diffusers import logging as dlogging
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
|
from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
|
||||||
|
|
||||||
|
class MergeInterpolationMethod(str, Enum):
|
||||||
|
WeightedSum = "weighted_sum"
|
||||||
|
Sigmoid = "sigmoid"
|
||||||
|
InvSigmoid = "inv_sigmoid"
|
||||||
|
AddDifference = "add_difference"
|
||||||
|
|
||||||
|
class ModelMerger(object):
|
||||||
|
def __init__(self, manager: ModelManager):
|
||||||
|
self.manager = manager
|
||||||
|
|
||||||
|
def merge_diffusion_models(
|
||||||
|
self,
|
||||||
|
model_paths: List[Path],
|
||||||
|
alpha: float = 0.5,
|
||||||
|
interp: MergeInterpolationMethod = None,
|
||||||
|
force: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> DiffusionPipeline:
|
||||||
|
"""
|
||||||
|
:param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids
|
||||||
|
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||||
|
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||||
|
:param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||||
|
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
|
||||||
|
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||||
|
|
||||||
|
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||||
|
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||||
|
"""
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
verbosity = dlogging.get_verbosity()
|
||||||
|
dlogging.set_verbosity_error()
|
||||||
|
|
||||||
|
pipe = DiffusionPipeline.from_pretrained(
|
||||||
|
model_paths[0],
|
||||||
|
custom_pipeline="checkpoint_merger",
|
||||||
|
)
|
||||||
|
merged_pipe = pipe.merge(
|
||||||
|
pretrained_model_name_or_path_list=model_paths,
|
||||||
|
alpha=alpha,
|
||||||
|
interp=interp.value if interp else None, #diffusers API treats None as "weighted sum"
|
||||||
|
force=force,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
dlogging.set_verbosity(verbosity)
|
||||||
|
return merged_pipe
|
||||||
|
|
||||||
|
|
||||||
|
def merge_diffusion_models_and_save (
|
||||||
|
self,
|
||||||
|
model_names: List[str],
|
||||||
|
base_model: Union[BaseModelType,str],
|
||||||
|
merged_model_name: str,
|
||||||
|
alpha: float = 0.5,
|
||||||
|
interp: MergeInterpolationMethod = None,
|
||||||
|
force: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> AddModelResult:
|
||||||
|
"""
|
||||||
|
:param models: up to three models, designated by their InvokeAI models.yaml model name
|
||||||
|
:param base_model: base model (must be the same for all merged models!)
|
||||||
|
:param merged_model_name: name for new model
|
||||||
|
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||||
|
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||||
|
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||||
|
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
|
||||||
|
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||||
|
|
||||||
|
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||||
|
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||||
|
"""
|
||||||
|
model_paths = list()
|
||||||
|
config = self.manager.app_config
|
||||||
|
base_model = BaseModelType(base_model)
|
||||||
|
vae = None
|
||||||
|
|
||||||
|
for mod in model_names:
|
||||||
|
info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main)
|
||||||
|
assert info, f"model {mod}, base_model {base_model}, is unknown"
|
||||||
|
assert info["model_format"] == "diffusers", f"{mod} is not a diffusers model. It must be optimized before merging"
|
||||||
|
assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged"
|
||||||
|
assert len(model_names) <= 2 or \
|
||||||
|
interp==MergeInterpolationMethod.AddDifference, "When merging three models, only the 'add_difference' merge method is supported"
|
||||||
|
# pick up the first model's vae
|
||||||
|
if mod == model_names[0]:
|
||||||
|
vae = info.get("vae")
|
||||||
|
model_paths.extend([config.root_path / info["path"]])
|
||||||
|
|
||||||
|
merge_method = None if interp == 'weighted_sum' else MergeInterpolationMethod(interp)
|
||||||
|
logger.debug(f'interp = {interp}, merge_method={merge_method}')
|
||||||
|
merged_pipe = self.merge_diffusion_models(
|
||||||
|
model_paths, alpha, merge_method, force, **kwargs
|
||||||
|
)
|
||||||
|
dump_path = config.models_path / base_model.value / ModelType.Main.value
|
||||||
|
dump_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
dump_path = dump_path / merged_model_name
|
||||||
|
|
||||||
|
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
|
||||||
|
attributes = dict(
|
||||||
|
path = str(dump_path),
|
||||||
|
description = f"Merge of models {', '.join(model_names)}",
|
||||||
|
model_format = "diffusers",
|
||||||
|
variant = ModelVariantType.Normal.value,
|
||||||
|
vae = vae,
|
||||||
|
)
|
||||||
|
return self.manager.add_model(merged_model_name,
|
||||||
|
base_model = base_model,
|
||||||
|
model_type = ModelType.Main,
|
||||||
|
model_attributes = attributes,
|
||||||
|
clobber = True
|
||||||
|
)
|
@ -116,7 +116,7 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
version=BaseModelType.StableDiffusion1,
|
version=BaseModelType.StableDiffusion1,
|
||||||
model_config=config,
|
model_config=config,
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return model_path
|
return model_path
|
||||||
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Initialization file for invokeai.frontend.merge
|
Initialization file for invokeai.frontend.merge
|
||||||
"""
|
"""
|
||||||
from .merge_diffusers import main as invokeai_merge_diffusers, merge_diffusion_models
|
from .merge_diffusers import main as invokeai_merge_diffusers
|
||||||
|
|
||||||
|
@ -6,9 +6,7 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
|||||||
"""
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
import curses
|
import curses
|
||||||
import os
|
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
@ -20,99 +18,15 @@ from npyscreen import widget
|
|||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from ...backend.model_management import ModelManager
|
from invokeai.backend.model_management import (
|
||||||
from ...frontend.install.widgets import FloatTitleSlider
|
ModelMerger, MergeInterpolationMethod,
|
||||||
|
ModelManager, ModelType, BaseModelType,
|
||||||
|
)
|
||||||
|
from invokeai.frontend.install.widgets import FloatTitleSlider, TextBox, SingleSelectColumns
|
||||||
|
|
||||||
DEST_MERGED_MODEL_DIR = "merged_models"
|
|
||||||
config = InvokeAIAppConfig.get_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
def merge_diffusion_models(
|
|
||||||
model_ids_or_paths: List[Union[str, Path]],
|
|
||||||
alpha: float = 0.5,
|
|
||||||
interp: str = None,
|
|
||||||
force: bool = False,
|
|
||||||
**kwargs,
|
|
||||||
) -> DiffusionPipeline:
|
|
||||||
"""
|
|
||||||
model_ids_or_paths - up to three models, designated by their local paths or HuggingFace repo_ids
|
|
||||||
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
|
||||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
|
||||||
interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
|
|
||||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
|
|
||||||
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
|
||||||
|
|
||||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
|
||||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
|
||||||
"""
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.simplefilter("ignore")
|
|
||||||
verbosity = dlogging.get_verbosity()
|
|
||||||
dlogging.set_verbosity_error()
|
|
||||||
|
|
||||||
pipe = DiffusionPipeline.from_pretrained(
|
|
||||||
model_ids_or_paths[0],
|
|
||||||
cache_dir=kwargs.get("cache_dir", config.cache_dir),
|
|
||||||
custom_pipeline="checkpoint_merger",
|
|
||||||
)
|
|
||||||
merged_pipe = pipe.merge(
|
|
||||||
pretrained_model_name_or_path_list=model_ids_or_paths,
|
|
||||||
alpha=alpha,
|
|
||||||
interp=interp,
|
|
||||||
force=force,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
dlogging.set_verbosity(verbosity)
|
|
||||||
return merged_pipe
|
|
||||||
|
|
||||||
|
|
||||||
def merge_diffusion_models_and_commit(
|
|
||||||
models: List["str"],
|
|
||||||
merged_model_name: str,
|
|
||||||
alpha: float = 0.5,
|
|
||||||
interp: str = None,
|
|
||||||
force: bool = False,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
models - up to three models, designated by their InvokeAI models.yaml model name
|
|
||||||
merged_model_name = name for new model
|
|
||||||
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
|
||||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
|
||||||
interp - The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
|
|
||||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
|
|
||||||
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
|
||||||
|
|
||||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
|
||||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
|
||||||
"""
|
|
||||||
config_file = config.model_conf_path
|
|
||||||
model_manager = ModelManager(OmegaConf.load(config_file))
|
|
||||||
for mod in models:
|
|
||||||
assert mod in model_manager.model_names(), f'** Unknown model "{mod}"'
|
|
||||||
assert (
|
|
||||||
model_manager.model_info(mod).get("format", None) == "diffusers"
|
|
||||||
), f"** {mod} is not a diffusers model. It must be optimized before merging."
|
|
||||||
model_ids_or_paths = [model_manager.model_name_or_path(x) for x in models]
|
|
||||||
|
|
||||||
merged_pipe = merge_diffusion_models(
|
|
||||||
model_ids_or_paths, alpha, interp, force, **kwargs
|
|
||||||
)
|
|
||||||
dump_path = config.models_dir / DEST_MERGED_MODEL_DIR
|
|
||||||
|
|
||||||
os.makedirs(dump_path, exist_ok=True)
|
|
||||||
dump_path = dump_path / merged_model_name
|
|
||||||
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
|
|
||||||
import_args = dict(
|
|
||||||
model_name=merged_model_name, description=f'Merge of models {", ".join(models)}'
|
|
||||||
)
|
|
||||||
if vae := model_manager.config[models[0]].get("vae", None):
|
|
||||||
logger.info(f"Using configured VAE assigned to {models[0]}")
|
|
||||||
import_args.update(vae=vae)
|
|
||||||
model_manager.import_diffuser_model(dump_path, **import_args)
|
|
||||||
model_manager.commit(config_file)
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_args() -> Namespace:
|
def _parse_args() -> Namespace:
|
||||||
parser = argparse.ArgumentParser(description="InvokeAI model merging")
|
parser = argparse.ArgumentParser(description="InvokeAI model merging")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -131,10 +45,17 @@ def _parse_args() -> Namespace:
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--models",
|
"--models",
|
||||||
|
dest="model_names",
|
||||||
type=str,
|
type=str,
|
||||||
nargs="+",
|
nargs="+",
|
||||||
help="Two to three model names to be merged",
|
help="Two to three model names to be merged",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--base_model",
|
||||||
|
type=str,
|
||||||
|
choices=[x.value for x in BaseModelType],
|
||||||
|
help="The base model shared by the models to be merged",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--merged_model_name",
|
"--merged_model_name",
|
||||||
"--destination",
|
"--destination",
|
||||||
@ -192,6 +113,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
window_height, window_width = curses.initscr().getmaxyx()
|
window_height, window_width = curses.initscr().getmaxyx()
|
||||||
|
|
||||||
self.model_names = self.get_model_names()
|
self.model_names = self.get_model_names()
|
||||||
|
self.current_base = 0
|
||||||
max_width = max([len(x) for x in self.model_names])
|
max_width = max([len(x) for x in self.model_names])
|
||||||
max_width += 6
|
max_width += 6
|
||||||
horizontal_layout = max_width * 3 < window_width
|
horizontal_layout = max_width * 3 < window_width
|
||||||
@ -208,12 +130,26 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
value="Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
|
value="Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
|
||||||
editable=False,
|
editable=False,
|
||||||
)
|
)
|
||||||
|
self.nextrely += 1
|
||||||
|
self.base_select = self.add_widget_intelligent(
|
||||||
|
SingleSelectColumns,
|
||||||
|
values=[
|
||||||
|
'Models Built on SD-1.x',
|
||||||
|
'Models Built on SD-2.x',
|
||||||
|
],
|
||||||
|
value=[self.current_base],
|
||||||
|
columns = 4,
|
||||||
|
max_height = 2,
|
||||||
|
relx=8,
|
||||||
|
scroll_exit = True,
|
||||||
|
)
|
||||||
|
self.base_select.on_changed = self._populate_models
|
||||||
self.add_widget_intelligent(
|
self.add_widget_intelligent(
|
||||||
npyscreen.FixedText,
|
npyscreen.FixedText,
|
||||||
value="MODEL 1",
|
value="MODEL 1",
|
||||||
color="GOOD",
|
color="GOOD",
|
||||||
editable=False,
|
editable=False,
|
||||||
rely=4 if horizontal_layout else None,
|
rely=6 if horizontal_layout else None,
|
||||||
)
|
)
|
||||||
self.model1 = self.add_widget_intelligent(
|
self.model1 = self.add_widget_intelligent(
|
||||||
npyscreen.SelectOne,
|
npyscreen.SelectOne,
|
||||||
@ -222,7 +158,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
max_height=len(self.model_names),
|
max_height=len(self.model_names),
|
||||||
max_width=max_width,
|
max_width=max_width,
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
rely=5,
|
rely=7,
|
||||||
)
|
)
|
||||||
self.add_widget_intelligent(
|
self.add_widget_intelligent(
|
||||||
npyscreen.FixedText,
|
npyscreen.FixedText,
|
||||||
@ -230,7 +166,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
color="GOOD",
|
color="GOOD",
|
||||||
editable=False,
|
editable=False,
|
||||||
relx=max_width + 3 if horizontal_layout else None,
|
relx=max_width + 3 if horizontal_layout else None,
|
||||||
rely=4 if horizontal_layout else None,
|
rely=6 if horizontal_layout else None,
|
||||||
)
|
)
|
||||||
self.model2 = self.add_widget_intelligent(
|
self.model2 = self.add_widget_intelligent(
|
||||||
npyscreen.SelectOne,
|
npyscreen.SelectOne,
|
||||||
@ -240,7 +176,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
max_height=len(self.model_names),
|
max_height=len(self.model_names),
|
||||||
max_width=max_width,
|
max_width=max_width,
|
||||||
relx=max_width + 3 if horizontal_layout else None,
|
relx=max_width + 3 if horizontal_layout else None,
|
||||||
rely=5 if horizontal_layout else None,
|
rely=7 if horizontal_layout else None,
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
self.add_widget_intelligent(
|
self.add_widget_intelligent(
|
||||||
@ -249,7 +185,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
color="GOOD",
|
color="GOOD",
|
||||||
editable=False,
|
editable=False,
|
||||||
relx=max_width * 2 + 3 if horizontal_layout else None,
|
relx=max_width * 2 + 3 if horizontal_layout else None,
|
||||||
rely=4 if horizontal_layout else None,
|
rely=6 if horizontal_layout else None,
|
||||||
)
|
)
|
||||||
models_plus_none = self.model_names.copy()
|
models_plus_none = self.model_names.copy()
|
||||||
models_plus_none.insert(0, "None")
|
models_plus_none.insert(0, "None")
|
||||||
@ -262,24 +198,26 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
max_width=max_width,
|
max_width=max_width,
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
relx=max_width * 2 + 3 if horizontal_layout else None,
|
relx=max_width * 2 + 3 if horizontal_layout else None,
|
||||||
rely=5 if horizontal_layout else None,
|
rely=7 if horizontal_layout else None,
|
||||||
)
|
)
|
||||||
for m in [self.model1, self.model2, self.model3]:
|
for m in [self.model1, self.model2, self.model3]:
|
||||||
m.when_value_edited = self.models_changed
|
m.when_value_edited = self.models_changed
|
||||||
self.merged_model_name = self.add_widget_intelligent(
|
self.merged_model_name = self.add_widget_intelligent(
|
||||||
npyscreen.TitleText,
|
TextBox,
|
||||||
name="Name for merged model:",
|
name="Name for merged model:",
|
||||||
labelColor="CONTROL",
|
labelColor="CONTROL",
|
||||||
|
max_height=3,
|
||||||
value="",
|
value="",
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
self.force = self.add_widget_intelligent(
|
self.force = self.add_widget_intelligent(
|
||||||
npyscreen.Checkbox,
|
npyscreen.Checkbox,
|
||||||
name="Force merge of incompatible models",
|
name="Force merge of models created by different diffusers library versions",
|
||||||
labelColor="CONTROL",
|
labelColor="CONTROL",
|
||||||
value=False,
|
value=True,
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
|
self.nextrely += 1
|
||||||
self.merge_method = self.add_widget_intelligent(
|
self.merge_method = self.add_widget_intelligent(
|
||||||
npyscreen.TitleSelectOne,
|
npyscreen.TitleSelectOne,
|
||||||
name="Merge Method:",
|
name="Merge Method:",
|
||||||
@ -341,7 +279,8 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
interp = self.interpolations[self.merge_method.value[0]]
|
interp = self.interpolations[self.merge_method.value[0]]
|
||||||
|
|
||||||
args = dict(
|
args = dict(
|
||||||
models=models,
|
model_names=models,
|
||||||
|
base_model=tuple(BaseModelType)[self.base_select.value[0]],
|
||||||
alpha=self.alpha.value,
|
alpha=self.alpha.value,
|
||||||
interp=interp,
|
interp=interp,
|
||||||
force=self.force.value,
|
force=self.force.value,
|
||||||
@ -379,21 +318,30 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def get_model_names(self) -> List[str]:
|
def get_model_names(self, base_model: BaseModelType=None) -> List[str]:
|
||||||
model_names = [
|
model_names = [
|
||||||
name
|
info["name"]
|
||||||
for name in self.model_manager.model_names()
|
for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)
|
||||||
if self.model_manager.model_info(name).get("format") == "diffusers"
|
if info["model_format"] == "diffusers"
|
||||||
]
|
]
|
||||||
return sorted(model_names)
|
return sorted(model_names)
|
||||||
|
|
||||||
|
def _populate_models(self,value=None):
|
||||||
|
base_model = tuple(BaseModelType)[value[0]]
|
||||||
|
self.model_names = self.get_model_names(base_model)
|
||||||
|
|
||||||
|
models_plus_none = self.model_names.copy()
|
||||||
|
models_plus_none.insert(0, "None")
|
||||||
|
self.model1.values = self.model_names
|
||||||
|
self.model2.values = self.model_names
|
||||||
|
self.model3.values = models_plus_none
|
||||||
|
|
||||||
|
self.display()
|
||||||
|
|
||||||
class Mergeapp(npyscreen.NPSAppManaged):
|
class Mergeapp(npyscreen.NPSAppManaged):
|
||||||
def __init__(self):
|
def __init__(self, model_manager:ModelManager):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
conf = OmegaConf.load(config.model_conf_path)
|
self.model_manager = model_manager
|
||||||
self.model_manager = ModelManager(
|
|
||||||
conf, "cpu", "float16"
|
|
||||||
) # precision doesn't really matter here
|
|
||||||
|
|
||||||
def onStart(self):
|
def onStart(self):
|
||||||
npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
|
npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
|
||||||
@ -401,44 +349,41 @@ class Mergeapp(npyscreen.NPSAppManaged):
|
|||||||
|
|
||||||
|
|
||||||
def run_gui(args: Namespace):
|
def run_gui(args: Namespace):
|
||||||
mergeapp = Mergeapp()
|
model_manager = ModelManager(config.model_conf_path)
|
||||||
|
mergeapp = Mergeapp(model_manager)
|
||||||
mergeapp.run()
|
mergeapp.run()
|
||||||
|
|
||||||
args = mergeapp.merge_arguments
|
args = mergeapp.merge_arguments
|
||||||
merge_diffusion_models_and_commit(**args)
|
merger = ModelMerger(model_manager)
|
||||||
|
merger.merge_diffusion_models_and_save(**args)
|
||||||
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
|
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
|
||||||
|
|
||||||
|
|
||||||
def run_cli(args: Namespace):
|
def run_cli(args: Namespace):
|
||||||
assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
|
assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
|
||||||
assert (
|
assert (
|
||||||
args.models and len(args.models) >= 1 and len(args.models) <= 3
|
args.model_names and len(args.model_names) >= 1 and len(args.model_names) <= 3
|
||||||
), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."
|
), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."
|
||||||
|
|
||||||
if not args.merged_model_name:
|
if not args.merged_model_name:
|
||||||
args.merged_model_name = "+".join(args.models)
|
args.merged_model_name = "+".join(args.model_names)
|
||||||
logger.info(
|
logger.info(
|
||||||
f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
|
f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
|
||||||
)
|
)
|
||||||
|
|
||||||
model_manager = ModelManager(OmegaConf.load(config.model_conf_path))
|
model_manager = ModelManager(config.model_conf_path)
|
||||||
assert (
|
assert (
|
||||||
args.clobber or args.merged_model_name not in model_manager.model_names()
|
not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber
|
||||||
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
||||||
|
|
||||||
merge_diffusion_models_and_commit(**vars(args))
|
merger = ModelMerger(model_manager)
|
||||||
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
|
merger.merge_diffusion_models_and_save(**vars(args))
|
||||||
|
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = _parse_args()
|
args = _parse_args()
|
||||||
config.root = args.root_dir
|
config.parse_args(['--root',str(args.root_dir)])
|
||||||
|
|
||||||
cache_dir = config.cache_dir
|
|
||||||
os.environ[
|
|
||||||
"HF_HOME"
|
|
||||||
] = cache_dir # because not clear the merge pipeline is honoring cache_dir
|
|
||||||
args.cache_dir = cache_dir
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if args.front_end:
|
if args.front_end:
|
||||||
|
Loading…
Reference in New Issue
Block a user