Merge branch 'main' into patch-1

Lincoln Stein authored 2023-07-06 15:11:40 -04:00; committed by GitHub
commit 80575344fc
13 changed files with 700 additions and 469 deletions

View File

@ -1,75 +1,30 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein
from typing import Literal, Optional, Union
from fastapi import Query, Body
from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies
from typing import Literal, List, Optional, Union
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
from pydantic import BaseModel, parse_obj_as
from starlette.exceptions import HTTPException
from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management import AddModelResult
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
from invokeai.backend.model_management.models import (
OPENAPI_MODEL_CONFIGS,
SchedulerPredictionType,
)
from invokeai.backend.model_management import MergeInterpolationMethod
from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"])
class VaeRepo(BaseModel):
repo_id: str = Field(description="The repo ID to use for this VAE")
path: Optional[str] = Field(description="The path to the VAE")
subfolder: Optional[str] = Field(description="The subfolder to use for this VAE")
class ModelInfo(BaseModel):
description: Optional[str] = Field(description="A description of the model")
model_name: str = Field(description="The name of the model")
model_type: str = Field(description="The type of the model")
class DiffusersModelInfo(ModelInfo):
format: Literal['folder'] = 'folder'
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
path: Optional[str] = Field(description="The path to the model")
class CkptModelInfo(ModelInfo):
format: Literal['ckpt'] = 'ckpt'
config: str = Field(description="The path to the model config")
weights: str = Field(description="The path to the model weights")
vae: str = Field(description="The path to the model VAE")
width: Optional[int] = Field(description="The width of the model")
height: Optional[int] = Field(description="The height of the model")
class SafetensorsModelInfo(CkptModelInfo):
format: Literal['safetensors'] = 'safetensors'
class CreateModelRequest(BaseModel):
name: str = Field(description="The name of the model")
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
class CreateModelResponse(BaseModel):
name: str = Field(description="The name of the new model")
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
status: str = Field(description="The status of the API response")
class ImportModelResponse(BaseModel):
name: str = Field(description="The name of the imported model")
# base_model: str = Field(description="The base model")
# model_type: str = Field(description="The model type")
info: AddModelResult = Field(description="The model info")
status: str = Field(description="The status of the API response")
class ConversionRequest(BaseModel):
name: str = Field(description="The name of the new model")
info: CkptModelInfo = Field(description="The converted model info")
save_location: str = Field(description="The path to save the converted model weights")
class ConvertedModelResponse(BaseModel):
name: str = Field(description="The name of the new model")
info: DiffusersModelInfo = Field(description="The converted model info")
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
class ModelsList(BaseModel):
models: list[MODEL_CONFIGS]
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
@models_router.get(
"/",
@ -77,75 +32,103 @@ class ModelsList(BaseModel):
responses={200: {"model": ModelsList }},
)
async def list_models(
base_model: Optional[BaseModelType] = Query(
default=None, description="Base model"
),
model_type: Optional[ModelType] = Query(
default=None, description="The type of model to get"
),
base_model: Optional[BaseModelType] = Query(default=None, description="Base model"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
) -> ModelsList:
"""Gets a list of models"""
models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)
models = parse_obj_as(ModelsList, { "models": models_raw })
return models
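For illustration, a minimal client-side call to this endpoint could look like the sketch below (the host, port, /api mount prefix, and the "sd-1"/"main" enum strings are assumptions; none of them appear in this diff):

import requests

resp = requests.get(
    "http://127.0.0.1:9090/api/v1/models/",
    params={"base_model": "sd-1", "model_type": "main"},  # both filters are optional
)
resp.raise_for_status()
for model in resp.json()["models"]:
    # each entry is one of the OPENAPI_MODEL_CONFIGS stanzas
    print(model.get("name"), model.get("model_format"))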
@models_router.post(
"/",
@models_router.patch(
"/{base_model}/{model_type}/{model_name}",
operation_id="update_model",
responses={200: {"status": "success"}},
responses={200: {"description" : "The model was updated successfully"},
404: {"description" : "The model could not be found"},
400: {"description" : "Bad request"}
},
status_code = 200,
response_model = UpdateModelResponse,
)
async def update_model(
model_request: CreateModelRequest
) -> CreateModelResponse:
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> UpdateModelResponse:
""" Add Model """
model_request_info = model_request.info
info_dict = model_request_info.dict()
model_response = CreateModelResponse(name=model_request.name, info=model_request.info, status="success")
ApiDependencies.invoker.services.model_manager.add_model(
model_name=model_request.name,
model_attributes=info_dict,
clobber=True,
)
try:
ApiDependencies.invoker.services.model_manager.update_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
model_attributes=info.dict()
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
model_response = parse_obj_as(UpdateModelResponse, model_raw)
except KeyError as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return model_response
@models_router.post(
"/import",
"/",
operation_id="import_model",
responses= {
201: {"description" : "The model imported successfully"},
404: {"description" : "The model could not be found"},
424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description" : "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
response_model=ImportModelResponse
)
async def import_model(
name: str = Query(description="A model path, repo_id or URL to import"),
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = Query(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
location: str = Body(description="A model path, repo_id or URL to import"),
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
) -> ImportModelResponse:
""" Add a model using its local path, repo_id, or remote URL """
items_to_import = {name}
items_to_import = {location}
prediction_types = { x.value: x for x in SchedulerPredictionType }
logger = ApiDependencies.invoker.services.logger
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
items_to_import = items_to_import,
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
)
if info := installed_models.get(name):
logger.info(f'Successfully imported {name}, got {info}')
return ImportModelResponse(
name = name,
info = info,
status = "success",
try:
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
items_to_import = items_to_import,
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
)
else:
logger.error(f'Model {name} not imported')
raise HTTPException(status_code=404, detail=f'Model {name} not found')
info = installed_models.get(location)
if not info:
logger.error("Import failed")
raise HTTPException(status_code=424)
logger.info(f'Successfully imported {location}, got {info}')
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=info.name,
base_model=info.base_model,
model_type=info.model_type
)
return parse_obj_as(ImportModelResponse, model_raw)
except KeyError as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
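A matching client sketch for the import endpoint (same assumed host and mount prefix; the repo_id is a placeholder). Since location and prediction_type are both Body parameters, FastAPI expects them as keys of one JSON object:

import requests

resp = requests.post(
    "http://127.0.0.1:9090/api/v1/models/",
    json={
        "location": "stabilityai/stable-diffusion-2-1",  # a path, repo_id, or URL
        "prediction_type": "v_prediction",
    },
)
if resp.status_code == 201:
    print("imported:", resp.json())
elif resp.status_code == 409:
    print("a model with this path or repo_id is already installed")
else:
    resp.raise_for_status()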
@models_router.delete(
"/{model_name}",
"/{base_model}/{model_type}/{model_name}",
operation_id="del_model",
responses={
204: {
@ -156,144 +139,95 @@ async def import_model(
}
},
)
async def delete_model(model_name: str) -> None:
async def delete_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
) -> Response:
"""Delete Model"""
model_names = ApiDependencies.invoker.services.model_manager.model_names()
logger = ApiDependencies.invoker.services.logger
model_exists = model_name in model_names
# check if model exists
logger.info(f"Checking for model {model_name}...")
if model_exists:
logger.info(f"Deleting Model: {model_name}")
ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True)
logger.info(f"Model Deleted: {model_name}")
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
else:
logger.error("Model not found")
try:
ApiDependencies.invoker.services.model_manager.del_model(model_name,
base_model = base_model,
model_type = model_type
)
logger.info(f"Deleted model: {model_name}")
return Response(status_code=204)
except KeyError:
logger.error(f"Model not found: {model_name}")
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
# @socketio.on("convertToDiffusers")
# def convert_to_diffusers(model_to_convert: dict):
# try:
# if model_info := self.generate.model_manager.model_info(
# model_name=model_to_convert["model_name"]
# ):
# if "weights" in model_info:
# ckpt_path = Path(model_info["weights"])
# original_config_file = Path(model_info["config"])
# model_name = model_to_convert["model_name"]
# model_description = model_info["description"]
# else:
# self.socketio.emit(
# "error", {"message": "Model is not a valid checkpoint file"}
# )
# else:
# self.socketio.emit(
# "error", {"message": "Could not retrieve model info."}
# )
# if not ckpt_path.is_absolute():
# ckpt_path = Path(Globals.root, ckpt_path)
# if original_config_file and not original_config_file.is_absolute():
# original_config_file = Path(Globals.root, original_config_file)
# diffusers_path = Path(
# ckpt_path.parent.absolute(), f"{model_name}_diffusers"
# )
# if model_to_convert["save_location"] == "root":
# diffusers_path = Path(
# global_converted_ckpts_dir(), f"{model_name}_diffusers"
# )
# if (
# model_to_convert["save_location"] == "custom"
# and model_to_convert["custom_location"] is not None
# ):
# diffusers_path = Path(
# model_to_convert["custom_location"], f"{model_name}_diffusers"
# )
# if diffusers_path.exists():
# shutil.rmtree(diffusers_path)
# self.generate.model_manager.convert_and_import(
# ckpt_path,
# diffusers_path,
# model_name=model_name,
# model_description=model_description,
# vae=None,
# original_config_file=original_config_file,
# commit_to_conf=opt.conf,
# )
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelConverted",
# {
# "new_model_name": model_name,
# "model_list": new_model_list,
# "update": True,
# },
# )
# print(f">> Model Converted: {model_name}")
# except Exception as e:
# self.handle_exceptions(e)
# @socketio.on("mergeDiffusersModels")
# def merge_diffusers_models(model_merge_info: dict):
# try:
# models_to_merge = model_merge_info["models_to_merge"]
# model_ids_or_paths = [
# self.generate.model_manager.model_name_or_path(x)
# for x in models_to_merge
# ]
# merged_pipe = merge_diffusion_models(
# model_ids_or_paths,
# model_merge_info["alpha"],
# model_merge_info["interp"],
# model_merge_info["force"],
# )
# dump_path = global_models_dir() / "merged_models"
# if model_merge_info["model_merge_save_path"] is not None:
# dump_path = Path(model_merge_info["model_merge_save_path"])
# os.makedirs(dump_path, exist_ok=True)
# dump_path = dump_path / model_merge_info["merged_model_name"]
# merged_pipe.save_pretrained(dump_path, safe_serialization=1)
# merged_model_config = dict(
# model_name=model_merge_info["merged_model_name"],
# description=f'Merge of models {", ".join(models_to_merge)}',
# commit_to_conf=opt.conf,
# )
# if vae := self.generate.model_manager.config[models_to_merge[0]].get(
# "vae", None
# ):
# print(f">> Using configured VAE assigned to {models_to_merge[0]}")
# merged_model_config.update(vae=vae)
# self.generate.model_manager.import_diffuser_model(
# dump_path, **merged_model_config
# )
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelsMerged",
# {
# "merged_models": models_to_merge,
# "merged_model_name": model_merge_info["merged_model_name"],
# "model_list": new_model_list,
# "update": True,
# },
# )
# print(f">> Models Merged: {models_to_merge}")
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
# except Exception as e:
@models_router.put(
"/convert/{base_model}/{model_type}/{model_name}",
operation_id="convert_model",
responses={
200: { "description": "Model converted successfully" },
400: {"description" : "Bad request" },
404: { "description": "Model not found" },
},
status_code = 200,
response_model = ConvertModelResponse,
)
async def convert_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
) -> ConvertModelResponse:
"""Convert a checkpoint model into a diffusers model"""
logger = ApiDependencies.invoker.services.logger
try:
logger.info(f"Converting model: {model_name}")
ApiDependencies.invoker.services.model_manager.convert_model(model_name,
base_model = base_model,
model_type = model_type
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
base_model = base_model,
model_type = model_type)
response = parse_obj_as(ConvertModelResponse, model_raw)
except KeyError:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return response
@models_router.put(
"/merge/{base_model}",
operation_id="merge_models",
responses={
200: { "description": "Model converted successfully" },
400: { "description": "Incompatible models" },
404: { "description": "One or more models not found" },
},
status_code = 200,
response_model = MergeModelResponse,
)
async def merge_models(
base_model: BaseModelType = Path(description="Base model"),
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
merged_model_name: Optional[str] = Body(description = "Name of destination model"),
alpha: Optional[float] = Body(description = "Alpha weighting strength to apply to 2d and 3d models", default=0.5),
interp: Union[MergeInterpolationMethod, None] = Body(description = "Interpolation method"),
force: Optional[bool] = Body(description = "Force merging of models created with different versions of diffusers", default=False),
) -> MergeModelResponse:
"""Convert a checkpoint model into a diffusers model"""
logger = ApiDependencies.invoker.services.logger
try:
logger.info(f"Merging models: {model_names}")
result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
base_model,
merged_model_name or "+".join(model_names),
alpha,
interp,
force)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
base_model = base_model,
model_type = ModelType.Main,
)
response = parse_obj_as(MergeModelResponse, model_raw)
except KeyError:
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return response
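A client sketch for the merge endpoint above; base_model travels in the path while the rest goes in the body (the model names and the "sd-1" path segment are placeholders):

import requests

resp = requests.put(
    "http://127.0.0.1:9090/api/v1/models/merge/sd-1",
    json={
        "model_names": ["modelA", "modelB"],
        "merged_model_name": "modelA+modelB",
        "alpha": 0.5,
        "interp": "weighted_sum",  # a MergeInterpolationMethod value
        "force": False,
    },
)
resp.raise_for_status()
print(resp.json())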

View File

@ -2,22 +2,29 @@
from __future__ import annotations
import torch
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING
from dataclasses import dataclass
from pydantic import Field
from typing import Optional, Union, Callable, List, Tuple, TYPE_CHECKING
from types import ModuleType
from invokeai.backend.model_management.model_manager import (
from invokeai.backend.model_management import (
ModelManager,
BaseModelType,
ModelType,
SubModelType,
ModelInfo,
AddModelResult,
SchedulerPredictionType,
ModelMerger,
MergeInterpolationMethod,
)
import torch
from invokeai.app.models.exceptions import CanceledException
from .config import InvokeAIAppConfig
from ...backend.util import choose_precision, choose_torch_device
from .config import InvokeAIAppConfig
if TYPE_CHECKING:
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
@ -30,7 +37,7 @@ class ModelManagerServiceBase(ABC):
def __init__(
self,
config: InvokeAIAppConfig,
logger: types.ModuleType,
logger: ModuleType,
):
"""
Initialize with the path to the models.yaml config file.
@ -73,13 +80,7 @@ class ModelManagerServiceBase(ABC):
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
Given a model name returns a dict-like (OmegaConf) object describing it.
"""
pass
@abstractmethod
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
"""
Returns a list of all the model names known.
Uses the exact format as the omegaconf stanza.
"""
pass
@ -101,7 +102,20 @@ class ModelManagerServiceBase(ABC):
}
"""
pass
@abstractmethod
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
Return information about the model using the same format as list_models()
"""
pass
@abstractmethod
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
"""
Returns a list of all the model names known.
"""
pass
@abstractmethod
def add_model(
@ -111,7 +125,7 @@ class ModelManagerServiceBase(ABC):
model_type: ModelType,
model_attributes: dict,
clobber: bool = False
) -> None:
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
@ -121,6 +135,24 @@ class ModelManagerServiceBase(ABC):
"""
pass
@abstractmethod
def update_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with a
KeyError exception if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
pass
@abstractmethod
def del_model(
self,
@ -135,11 +167,32 @@ class ModelManagerServiceBase(ABC):
"""
pass
@abstractmethod
def convert_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: Union[ModelType.Main,ModelType.Vae],
) -> AddModelResult:
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
version and deleting the original checkpoint file if it is in the models
directory.
:param model_name: Name of the model to convert
:param base_model: Base model type
:param model_type: Type of model ['vae' or 'main']
This will raise a ValueError unless the model is a checkpoint. It will
also raise a ValueError in the event that there is a similarly-named diffusers
directory already in place.
"""
pass
@abstractmethod
def heuristic_import(self,
items_to_import: Set[str],
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
)->Dict[str, AddModelResult]:
items_to_import: set[str],
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
)->dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
@ -159,7 +212,27 @@ class ModelManagerServiceBase(ABC):
pass
@abstractmethod
def commit(self, conf_file: Path = None) -> None:
def merge_models(
self,
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
) -> AddModelResult:
"""
Merge two to three diffusers pipeline models and save as a new model.
:param model_names: List of 2-3 models to merge
:param base_model: Base model to use for all models
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (the default) selects weighted sum.
"""
pass
@abstractmethod
def commit(self, conf_file: Optional[Path] = None) -> None:
"""
Write current configuration out to the indicated file.
If no conf_file is provided, then replaces the
@ -173,7 +246,7 @@ class ModelManagerService(ModelManagerServiceBase):
def __init__(
self,
config: InvokeAIAppConfig,
logger: types.ModuleType,
logger: ModuleType,
):
"""
Initialize with the path to the models.yaml config file.
@ -299,12 +372,19 @@ class ModelManagerService(ModelManagerServiceBase):
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None
) -> list[dict]:
# ) -> dict:
"""
Return a list of models.
"""
return self.mgr.list_models(base_model, model_type)
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
Return information about the model using the same format as list_models()
"""
return self.mgr.list_model(model_name=model_name,
base_model=base_model,
model_type=model_type)
def add_model(
self,
model_name: str,
@ -320,9 +400,28 @@ class ModelManagerService(ModelManagerServiceBase):
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
self.logger.debug(f'add/update model {model_name}')
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
def update_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with a
KeyError exception if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
self.logger.debug(f'update model {model_name}')
if not self.model_exists(model_name, base_model, model_type):
raise KeyError(f"Unknown model {model_name}")
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
def del_model(
self,
model_name: str,
@ -334,8 +433,29 @@ class ModelManagerService(ModelManagerServiceBase):
then the underlying weight file or diffusers directory will be deleted
as well. Call commit() to write to disk.
"""
self.logger.debug(f'delete model {model_name}')
self.mgr.del_model(model_name, base_model, model_type)
def convert_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: Union[ModelType.Main,ModelType.Vae],
) -> AddModelResult:
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
version and deleting the original checkpoint file if it is in the models
directory.
:param model_name: Name of the model to convert
:param base_model: Base model type
:param model_type: Type of model ['vae' or 'main']
This will raise a ValueError unless the model is a checkpoint. It will
also raise a ValueError in the event that there is a similarly-named diffusers
directory already in place.
"""
self.logger.debug(f'convert model {model_name}')
return self.mgr.convert_model(model_name, base_model, model_type)
def commit(self, conf_file: Optional[Path]=None):
"""
@ -387,9 +507,9 @@ class ModelManagerService(ModelManagerServiceBase):
return self.mgr.logger
def heuristic_import(self,
items_to_import: Set[str],
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
)->Dict[str, AddModelResult]:
items_to_import: set[str],
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
)->dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
@ -406,4 +526,31 @@ class ModelManagerService(ModelManagerServiceBase):
of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model.
'''
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
def merge_models(
self,
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
) -> AddModelResult:
"""
Merge two to three diffusers pipeline models and save as a new model.
:param model_names: List of 2-3 models to merge
:param base_model: Base model to use for all models
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (the default) selects weighted sum.
"""
merger = ModelMerger(self.mgr)
return merger.merge_diffusion_models_and_save(
model_names = model_names,
base_model = base_model,
merged_model_name = merged_model_name,
alpha = alpha,
interp = interp,
force = force,
)
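Taken together, the methods added to this service can be driven directly from Python; a hedged sketch (the module path of ModelManagerService is not shown in this diff and is assumed, as is the model name):

import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_manager_service import ModelManagerService  # assumed path
from invokeai.backend.model_management import BaseModelType, ModelType

config = InvokeAIAppConfig.get_config()
mm = ModelManagerService(config, logger)

# fetch one model's stanza in the list_models() format
info = mm.list_model("my-model", BaseModelType.StableDiffusion1, ModelType.Main)

# push it back through update_model(), which raises KeyError for unknown models
mm.update_model("my-model", BaseModelType.StableDiffusion1, ModelType.Main, info)
mm.commit()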

View File

@ -166,14 +166,18 @@ class ModelInstall(object):
# add requested models
for path in selections.install_models:
logger.info(f'Installing {path} [{job}/{jobs}]')
self.heuristic_import(path)
try:
self.heuristic_import(path)
except (ValueError, KeyError) as e:
logger.error(str(e))
job += 1
self.mgr.commit()
def heuristic_import(self,
model_path_id_or_url: Union[str,Path],
models_installed: Set[Path]=None)->Dict[str, AddModelResult]:
model_path_id_or_url: Union[str,Path],
models_installed: Set[Path]=None,
)->Dict[str, AddModelResult]:
'''
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
:param models_installed: Set of installed models, used for recursive invocation
@ -187,61 +191,53 @@ class ModelInstall(object):
self.current_id = model_path_id_or_url
path = Path(model_path_id_or_url)
try:
# checkpoint file, or similar
if path.is_file():
models_installed.update(self._install_path(path))
# checkpoint file, or similar
if path.is_file():
models_installed.update({str(path):self._install_path(path)})
# folders style or similar
elif path.is_dir() and any([(path/x).exists() for x in \
{'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
]
):
models_installed.update(self._install_path(path))
# folders style or similar
elif path.is_dir() and any([(path/x).exists() for x in \
{'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
]
):
models_installed.update(self._install_path(path))
# recursive scan
elif path.is_dir():
for child in path.iterdir():
self.heuristic_import(child, models_installed=models_installed)
# recursive scan
elif path.is_dir():
for child in path.iterdir():
self.heuristic_import(child, models_installed=models_installed)
# huggingface repo
elif len(str(path).split('/')) == 2:
models_installed.update(self._install_repo(str(path)))
# huggingface repo
elif len(str(model_path_id_or_url).split('/')) == 2:
models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
# a URL
elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")):
models_installed.update(self._install_url(model_path_id_or_url))
# a URL
elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")):
models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
else:
logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
except ValueError as e:
logger.error(str(e))
else:
raise KeyError(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
return models_installed
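The dispatch order implemented above — single file, recognized diffusers-style folder, recursive directory scan, two-part repo_id, then URL — can be restated as a standalone sketch (an illustration of the control flow only, not the installer's actual code):

from pathlib import Path

def classify_import_source(source: str) -> str:
    """Mirror the dispatch order of heuristic_import(), for illustration."""
    path = Path(source)
    if path.is_file():
        return "single file (checkpoint or similar)"
    if path.is_dir() and any(
        (path / marker).exists()
        for marker in ("config.json", "model_index.json",
                       "learned_embeds.bin", "pytorch_lora_weights.bin")
    ):
        return "diffusers-style folder"
    if path.is_dir():
        return "directory to scan recursively"
    if len(source.split("/")) == 2:
        return "huggingface repo_id"
    if source.startswith(("http:", "https:", "ftp:")):
        return "url"
    raise KeyError(f"{source} is not recognized as a local path, repo ID or URL")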
# install a model from a local path. The optional info parameter is there to prevent
# the model from being probed twice in the event that it has already been probed.
def _install_path(self, path: Path, info: ModelProbeInfo=None)->Dict[str, AddModelResult]:
try:
model_result = None
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
model_name = path.stem if path.is_file() else path.name
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
raise ValueError(f'A model named "{model_name}" is already installed.')
attributes = self._make_attributes(path,info)
model_result = self.mgr.add_model(model_name = model_name,
base_model = info.base_type,
model_type = info.model_type,
model_attributes = attributes,
)
except Exception as e:
logger.warning(f'{str(e)} Skipping registration.')
return {}
return {str(path): model_result}
def _install_path(self, path: Path, info: ModelProbeInfo=None)->AddModelResult:
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
if not info:
logger.warning(f'Unable to parse format of {path}')
return None
model_name = path.stem if path.is_file() else path.name
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
raise ValueError(f'A model named "{model_name}" is already installed.')
attributes = self._make_attributes(path,info)
return self.mgr.add_model(model_name = model_name,
base_model = info.base_type,
model_type = info.model_type,
model_attributes = attributes,
)
def _install_url(self, url: str)->dict:
# copy to a staging area, probe, import and delete
def _install_url(self, url: str)->AddModelResult:
with TemporaryDirectory(dir=self.config.models_path) as staging:
location = download_with_resume(url,Path(staging))
if not location:
@ -253,7 +249,7 @@ class ModelInstall(object):
# staged version will be garbage-collected at this time
return self._install_path(Path(models_path), info)
def _install_repo(self, repo_id: str)->dict:
def _install_repo(self, repo_id: str)->AddModelResult:
hinfo = HfApi().model_info(repo_id)
# we try to figure out how to download this most economically

View File

@ -1,7 +1,8 @@
"""
Initialization file for invokeai.backend.model_management
"""
from .model_manager import ModelManager, ModelInfo, AddModelResult
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
from .model_cache import ModelCache
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
from .model_merge import ModelMerger, MergeInterpolationMethod

View File

@ -2,8 +2,8 @@ from __future__ import annotations
import copy
from contextlib import contextmanager
from typing import Optional, Dict, Tuple, Any, Union, List
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union, List
import torch
from compel.embeddings_provider import BaseTextualInversionManager

View File

@ -234,7 +234,7 @@ import textwrap
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types
from shutil import rmtree
from shutil import rmtree, move
import torch
from omegaconf import OmegaConf
@ -279,7 +279,7 @@ class InvalidModelError(Exception):
pass
class AddModelResult(BaseModel):
name: str = Field(description="The name of the model after import")
name: str = Field(description="The name of the model after installation")
model_type: ModelType = Field(description="The type of model")
base_model: BaseModelType = Field(description="The base model")
config: ModelConfigBase = Field(description="The configuration of the model")
@ -491,17 +491,32 @@ class ModelManager(object):
"""
return [(self.parse_key(x)) for x in self.models.keys()]
def list_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> dict:
"""
Returns a dict describing one installed model, using
the combined format of the list_models() method.
"""
models = self.list_models(base_model,model_type,model_name)
return models[0] if models else None
def list_models(
self,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
model_name: Optional[str] = None,
) -> list[dict]:
"""
Return a list of models.
"""
model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold)
models = []
for model_key in sorted(self.models, key=str.casefold):
for model_key in model_keys:
model_config = self.models[model_key]
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
@ -546,10 +561,7 @@ class ModelManager(object):
model_cfg = self.models.pop(model_key, None)
if model_cfg is None:
self.logger.error(
f"Unknown model {model_key}"
)
return
raise KeyError(f"Unknown model {model_key}")
# note: this is not guaranteed to release memory (the model may have other references)
cache_ids = self.cache_keys.pop(model_key, [])
@ -615,6 +627,7 @@ class ModelManager(object):
self.cache.uncache_model(cache_id)
self.models[model_key] = model_config
self.commit()
return AddModelResult(
name = model_name,
model_type = model_type,
@ -622,6 +635,60 @@ class ModelManager(object):
config = model_config,
)
def convert_model (
self,
model_name: str,
base_model: BaseModelType,
model_type: Union[ModelType.Main,ModelType.Vae],
) -> AddModelResult:
'''
Convert a checkpoint file into a diffusers folder, deleting the cached
version and deleting the original checkpoint file if it is in the models
directory.
:param model_name: Name of the model to convert
:param base_model: Base model type
:param model_type: Type of model ['vae' or 'main']
This will raise a ValueError unless the model is a checkpoint.
'''
info = self.model_info(model_name, base_model, model_type)
if info["model_format"] != "checkpoint":
raise ValueError(f"not a checkpoint format model: {model_name}")
# We are taking advantage of a side effect of get_model() that converts checkpoints
# into cached diffusers directories stored at `location`. It doesn't matter
# what submodeltype we request here, so we get the smallest.
submodel = {"submodel_type": SubModelType.Tokenizer} if model_type==ModelType.Main else {}
model = self.get_model(model_name,
base_model,
model_type,
**submodel,
)
checkpoint_path = self.app_config.root_path / info["path"]
old_diffusers_path = self.app_config.models_path / model.location
new_diffusers_path = self.app_config.models_path / base_model.value / model_type.value / model_name
if new_diffusers_path.exists():
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
try:
move(old_diffusers_path,new_diffusers_path)
info["model_format"] = "diffusers"
info["path"] = str(new_diffusers_path.relative_to(self.app_config.root_path))
info.pop('config')
result = self.add_model(model_name, base_model, model_type,
model_attributes = info,
clobber=True)
except:
# something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
rmtree(new_diffusers_path)
raise
if checkpoint_path.exists() and checkpoint_path.is_relative_to(self.app_config.models_path):
checkpoint_path.unlink()
return result
def search_models(self, search_folder):
self.logger.info(f"Finding Models In: {search_folder}")
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
@ -821,6 +888,10 @@ class ModelManager(object):
The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model.
May raise the following exceptions:
- KeyError - one or more of the items to import is not a valid path, repo_id or URL
- ValueError - a corresponding model already exists
'''
# avoid circular import here
from invokeai.backend.install.model_install_backend import ModelInstall
@ -830,11 +901,7 @@ class ModelManager(object):
prediction_type_helper = prediction_type_helper,
model_manager = self)
for thing in items_to_import:
try:
installed = installer.heuristic_import(thing)
successfully_installed.update(installed)
except Exception as e:
self.logger.warning(f'{thing} could not be imported: {str(e)}')
installed = installer.heuristic_import(thing)
successfully_installed.update(installed)
self.commit()
return successfully_installed
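A usage sketch for this manager-level entry point (the models.yaml path, file path, and repo_id are placeholders, and the SchedulerPredictionType member name is an assumption):

from omegaconf import OmegaConf
from invokeai.backend.model_management import ModelManager, SchedulerPredictionType

mgr = ModelManager(OmegaConf.load("models.yaml"))
results = mgr.heuristic_import(
    {"/tmp/some-model.safetensors", "stabilityai/stable-diffusion-2-1"},
    prediction_type_helper=lambda path: SchedulerPredictionType.VPrediction,  # assumed member
)
for location, result in results.items():
    print(location, "->", result.name, result.model_type)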

View File

@ -0,0 +1,129 @@
"""
invokeai.backend.model_management.model_merge exports:
merge_diffusion_models() -- combine multiple models by location and return a pipeline object
merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to models.yaml
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
"""
import warnings
from enum import Enum
from pathlib import Path
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
from typing import List, Union
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
class MergeInterpolationMethod(str, Enum):
Sigmoid = "sigmoid"
InvSigmoid = "inv_sigmoid"
AddDifference = "add_difference"
WeightedSum = "weighted_sum"
class ModelMerger(object):
def __init__(self, manager: ModelManager):
self.manager = manager
def merge_diffusion_models(
self,
model_paths: List[Path],
alpha: float = 0.5,
interp: MergeInterpolationMethod = None,
force: bool = False,
**kwargs,
) -> DiffusionPipeline:
"""
:param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
:param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
with warnings.catch_warnings():
warnings.simplefilter("ignore")
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()
pipe = DiffusionPipeline.from_pretrained(
model_paths[0],
custom_pipeline="checkpoint_merger",
)
merged_pipe = pipe.merge(
pretrained_model_name_or_path_list=model_paths,
alpha=alpha,
interp=interp.value if interp else None, #diffusers API treats None as "weighted sum"
force=force,
**kwargs,
)
dlogging.set_verbosity(verbosity)
return merged_pipe
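The alpha semantics are easiest to see in the default weighted-sum case, where every merged tensor is a plain interpolation; a toy restatement of the arithmetic (not the diffusers implementation):

# weighted_sum of two models: theta = (1 - alpha) * theta_A + alpha * theta_B,
# so alpha = 0.8 keeps only 20% of the first model, as the docstring above notes
theta_a, theta_b, alpha = 1.0, 3.0, 0.8  # toy scalar "weights"
theta_merged = (1 - alpha) * theta_a + alpha * theta_b
assert abs(theta_merged - 2.6) < 1e-9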
def merge_diffusion_models_and_save (
self,
model_names: List[str],
base_model: Union[BaseModelType,str],
merged_model_name: str,
alpha: float = 0.5,
interp: MergeInterpolationMethod = None,
force: bool = False,
**kwargs,
) -> AddModelResult:
"""
:param model_names: up to three models, designated by their InvokeAI models.yaml model name
:param base_model: base model (must be the same for all merged models!)
:param merged_model_name: name for new model
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
model_paths = list()
config = self.manager.app_config
base_model = BaseModelType(base_model)
vae = None
for mod in model_names:
info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main)
assert info, f"model {mod}, base_model {base_model}, is unknown"
assert info["model_format"] == "diffusers", f"{mod} is not a diffusers model. It must be optimized before merging"
assert info["variant"] == "normal", (f"{mod} is a {info['variant']} model, which cannot currently be merged")
# pick up the first model's vae
if mod == model_names[0]:
vae = info.get("vae")
model_paths.extend([config.root_path / info["path"]])
merge_method = None if interp == 'weighted_sum' else MergeInterpolationMethod(interp)
merged_pipe = self.merge_diffusion_models(
model_paths, alpha, merge_method, force, **kwargs
)
dump_path = config.models_path / base_model.value / ModelType.Main.value
dump_path.mkdir(parents=True, exist_ok=True)
dump_path = dump_path / merged_model_name
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
attributes = dict(
path = str(dump_path),
description = f"Merge of models {', '.join(model_names)}",
model_format = "diffusers",
variant = ModelVariantType.Normal.value,
vae = vae,
)
return self.manager.add_model(merged_model_name,
base_model = base_model,
model_type = ModelType.Main,
model_attributes = attributes,
clobber = True
)
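An end-to-end sketch of the new class (the model names are placeholders; the ModelManager construction mirrors run_gui() later in this commit):

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import (
    BaseModelType, MergeInterpolationMethod, ModelManager, ModelMerger,
)

config = InvokeAIAppConfig.get_config()
merger = ModelMerger(ModelManager(config.model_conf_path))
result = merger.merge_diffusion_models_and_save(
    model_names=["modelA", "modelB"],
    base_model=BaseModelType.StableDiffusion1,
    merged_model_name="modelA+modelB",
    alpha=0.5,
    interp=MergeInterpolationMethod.WeightedSum,  # handled as the weighted-sum default
    force=False,
)
print(result.name, result.base_model, result.model_type)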

View File

@ -116,7 +116,7 @@ class StableDiffusion1Model(DiffusersModel):
version=BaseModelType.StableDiffusion1,
model_config=config,
output_path=output_path,
)
)
else:
return model_path

View File

@ -1,4 +1,5 @@
"""
Initialization file for invokeai.frontend.merge
"""
from .merge_diffusers import main as invokeai_merge_diffusers, merge_diffusion_models
from .merge_diffusers import main as invokeai_merge_diffusers

View File

@ -6,9 +6,7 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
"""
import argparse
import curses
import os
import sys
import warnings
from argparse import Namespace
from pathlib import Path
from typing import List, Union
@ -20,99 +18,15 @@ from npyscreen import widget
from omegaconf import OmegaConf
import invokeai.backend.util.logging as logger
from invokeai.services.config import InvokeAIAppConfig
from ...backend.model_management import ModelManager
from ...frontend.install.widgets import FloatTitleSlider
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import (
ModelMerger, MergeInterpolationMethod,
ModelManager, ModelType, BaseModelType,
)
from invokeai.frontend.install.widgets import FloatTitleSlider, TextBox, SingleSelectColumns
DEST_MERGED_MODEL_DIR = "merged_models"
config = InvokeAIAppConfig.get_config()
def merge_diffusion_models(
model_ids_or_paths: List[Union[str, Path]],
alpha: float = 0.5,
interp: str = None,
force: bool = False,
**kwargs,
) -> DiffusionPipeline:
"""
model_ids_or_paths - up to three models, designated by their local paths or HuggingFace repo_ids
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
with warnings.catch_warnings():
warnings.simplefilter("ignore")
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()
pipe = DiffusionPipeline.from_pretrained(
model_ids_or_paths[0],
cache_dir=kwargs.get("cache_dir", config.cache_dir),
custom_pipeline="checkpoint_merger",
)
merged_pipe = pipe.merge(
pretrained_model_name_or_path_list=model_ids_or_paths,
alpha=alpha,
interp=interp,
force=force,
**kwargs,
)
dlogging.set_verbosity(verbosity)
return merged_pipe
def merge_diffusion_models_and_commit(
models: List["str"],
merged_model_name: str,
alpha: float = 0.5,
interp: str = None,
force: bool = False,
**kwargs,
):
"""
models - up to three models, designated by their InvokeAI models.yaml model name
merged_model_name = name for new model
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
interp - The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
config_file = config.model_conf_path
model_manager = ModelManager(OmegaConf.load(config_file))
for mod in models:
assert mod in model_manager.model_names(), f'** Unknown model "{mod}"'
assert (
model_manager.model_info(mod).get("format", None) == "diffusers"
), f"** {mod} is not a diffusers model. It must be optimized before merging."
model_ids_or_paths = [model_manager.model_name_or_path(x) for x in models]
merged_pipe = merge_diffusion_models(
model_ids_or_paths, alpha, interp, force, **kwargs
)
dump_path = config.models_dir / DEST_MERGED_MODEL_DIR
os.makedirs(dump_path, exist_ok=True)
dump_path = dump_path / merged_model_name
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
import_args = dict(
model_name=merged_model_name, description=f'Merge of models {", ".join(models)}'
)
if vae := model_manager.config[models[0]].get("vae", None):
logger.info(f"Using configured VAE assigned to {models[0]}")
import_args.update(vae=vae)
model_manager.import_diffuser_model(dump_path, **import_args)
model_manager.commit(config_file)
def _parse_args() -> Namespace:
parser = argparse.ArgumentParser(description="InvokeAI model merging")
parser.add_argument(
@ -131,10 +45,17 @@ def _parse_args() -> Namespace:
)
parser.add_argument(
"--models",
dest="model_names",
type=str,
nargs="+",
help="Two to three model names to be merged",
)
parser.add_argument(
"--base_model",
type=str,
choices=[x.value for x in BaseModelType],
help="The base model shared by the models to be merged",
)
parser.add_argument(
"--merged_model_name",
"--destination",
@ -192,6 +113,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
window_height, window_width = curses.initscr().getmaxyx()
self.model_names = self.get_model_names()
self.current_base = 0
max_width = max([len(x) for x in self.model_names])
max_width += 6
horizontal_layout = max_width * 3 < window_width
@ -208,12 +130,26 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
value="Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
editable=False,
)
self.nextrely += 1
self.base_select = self.add_widget_intelligent(
SingleSelectColumns,
values=[
'Models Built on SD-1.x',
'Models Built on SD-2.x',
],
value=[self.current_base],
columns = 4,
max_height = 2,
relx=8,
scroll_exit = True,
)
self.base_select.on_changed = self._populate_models
self.add_widget_intelligent(
npyscreen.FixedText,
value="MODEL 1",
color="GOOD",
editable=False,
rely=4 if horizontal_layout else None,
rely=6 if horizontal_layout else None,
)
self.model1 = self.add_widget_intelligent(
npyscreen.SelectOne,
@ -222,7 +158,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
max_height=len(self.model_names),
max_width=max_width,
scroll_exit=True,
rely=5,
rely=7,
)
self.add_widget_intelligent(
npyscreen.FixedText,
@ -230,7 +166,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
color="GOOD",
editable=False,
relx=max_width + 3 if horizontal_layout else None,
rely=4 if horizontal_layout else None,
rely=6 if horizontal_layout else None,
)
self.model2 = self.add_widget_intelligent(
npyscreen.SelectOne,
@ -240,7 +176,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
max_height=len(self.model_names),
max_width=max_width,
relx=max_width + 3 if horizontal_layout else None,
rely=5 if horizontal_layout else None,
rely=7 if horizontal_layout else None,
scroll_exit=True,
)
self.add_widget_intelligent(
@ -249,7 +185,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
color="GOOD",
editable=False,
relx=max_width * 2 + 3 if horizontal_layout else None,
rely=4 if horizontal_layout else None,
rely=6 if horizontal_layout else None,
)
models_plus_none = self.model_names.copy()
models_plus_none.insert(0, "None")
@ -262,24 +198,26 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
max_width=max_width,
scroll_exit=True,
relx=max_width * 2 + 3 if horizontal_layout else None,
rely=5 if horizontal_layout else None,
rely=7 if horizontal_layout else None,
)
for m in [self.model1, self.model2, self.model3]:
m.when_value_edited = self.models_changed
self.merged_model_name = self.add_widget_intelligent(
npyscreen.TitleText,
TextBox,
name="Name for merged model:",
labelColor="CONTROL",
max_height=3,
value="",
scroll_exit=True,
)
self.force = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Force merge of incompatible models",
name="Force merge of models created by different diffusers library versions",
labelColor="CONTROL",
value=False,
value=True,
scroll_exit=True,
)
self.nextrely += 1
self.merge_method = self.add_widget_intelligent(
npyscreen.TitleSelectOne,
name="Merge Method:",
@ -341,7 +279,8 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
interp = self.interpolations[self.merge_method.value[0]]
args = dict(
models=models,
model_names=models,
base_model=tuple(BaseModelType)[self.base_select.value[0]],
alpha=self.alpha.value,
interp=interp,
force=self.force.value,
@ -379,21 +318,30 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
else:
return True
def get_model_names(self) -> List[str]:
def get_model_names(self, base_model: BaseModelType=None) -> List[str]:
model_names = [
name
for name in self.model_manager.model_names()
if self.model_manager.model_info(name).get("format") == "diffusers"
info["name"]
for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)
if info["model_format"] == "diffusers"
]
return sorted(model_names)
def _populate_models(self,value=None):
base_model = tuple(BaseModelType)[value[0]]
self.model_names = self.get_model_names(base_model)
models_plus_none = self.model_names.copy()
models_plus_none.insert(0, "None")
self.model1.values = self.model_names
self.model2.values = self.model_names
self.model3.values = models_plus_none
self.display()
class Mergeapp(npyscreen.NPSAppManaged):
def __init__(self):
def __init__(self, model_manager:ModelManager):
super().__init__()
conf = OmegaConf.load(config.model_conf_path)
self.model_manager = ModelManager(
conf, "cpu", "float16"
) # precision doesn't really matter here
self.model_manager = model_manager
def onStart(self):
npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
@ -401,44 +349,41 @@ class Mergeapp(npyscreen.NPSAppManaged):
def run_gui(args: Namespace):
mergeapp = Mergeapp()
model_manager = ModelManager(config.model_conf_path)
mergeapp = Mergeapp(model_manager)
mergeapp.run()
args = mergeapp.merge_arguments
merge_diffusion_models_and_commit(**args)
merger = ModelMerger(model_manager)
merger.merge_diffusion_models_and_save(**args)
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
def run_cli(args: Namespace):
assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
assert (
args.models and len(args.models) >= 1 and len(args.models) <= 3
args.model_names and len(args.model_names) >= 1 and len(args.model_names) <= 3
), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."
if not args.merged_model_name:
args.merged_model_name = "+".join(args.models)
args.merged_model_name = "+".join(args.model_names)
logger.info(
f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
)
model_manager = ModelManager(OmegaConf.load(config.model_conf_path))
assert (
args.clobber or args.merged_model_name not in model_manager.model_names()
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
model_manager = ModelManager(config.model_conf_path)
assert (
not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
merge_diffusion_models_and_commit(**vars(args))
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
merger = ModelMerger(model_manager)
merger.merge_diffusion_models_and_save(**vars(args))
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
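For reference, a hypothetical command-line invocation of this path (the invokeai-merge script name, the --alpha and --clobber flag spellings, and the "sd-1" enum value are inferred from the parsed attributes rather than shown in this hunk):

invokeai-merge --models modelA modelB --base_model sd-1 \
    --merged_model_name modelA+modelB --alpha 0.5 --clobber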
def main():
args = _parse_args()
config.root = args.root_dir
cache_dir = config.cache_dir
os.environ[
"HF_HOME"
] = cache_dir # because not clear the merge pipeline is honoring cache_dir
args.cache_dir = cache_dir
config.parse_args(['--root',str(args.root_dir)])
try:
if args.front_end:

View File

@ -1,6 +1,5 @@
import { log } from 'app/logging/useLogger';
import { appSocketConnected, socketConnected } from 'services/events/actions';
import { receivedPageOfImages } from 'services/api/thunks/image';
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { startAppListening } from '../..';
@ -14,19 +13,10 @@ export const addSocketConnectedEventListener = () => {
moduleLog.debug({ timestamp }, 'Connected');
const { nodes, config, gallery } = getState();
const { nodes, config } = getState();
const { disabledTabs } = config;
if (!gallery.ids.length) {
dispatch(
receivedPageOfImages({
categories: ['general'],
is_intermediate: false,
})
);
}
if (!nodes.schema && !disabledTabs.includes('nodes')) {
dispatch(receivedOpenAPISchema());
}

View File

@ -6,10 +6,15 @@ import { validateSeedWeights } from 'common/util/seedWeightPairs';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { systemSelector } from 'features/system/store/systemSelectors';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import {
modelsApi,
useGetMainModelsQuery,
} from '../../services/api/endpoints/models';
const readinessSelector = createSelector(
[stateSelector, activeTabNameSelector],
({ generation, system, batch }, activeTabName) => {
(state, activeTabName) => {
const { generation, system, batch } = state;
const { shouldGenerateVariations, seedWeights, initialImage, seed } =
generation;
@ -32,6 +37,13 @@ const readinessSelector = createSelector(
reasonsWhyNotReady.push('No initial image selected');
}
const { isSuccess: mainModelsSuccessfullyLoaded } =
modelsApi.endpoints.getMainModels.select()(state);
if (!mainModelsSuccessfullyLoaded) {
isReady = false;
reasonsWhyNotReady.push('Models are not loaded');
}
// TODO: job queue
// Cannot generate if already processing an image
if (isProcessing) {

View File

@ -182,6 +182,15 @@ const ImageGalleryContent = () => {
return () => osInstance()?.destroy();
}, [scroller, initialize, osInstance]);
useEffect(() => {
dispatch(
receivedPageOfImages({
categories: ['general'],
is_intermediate: false,
})
);
}, [dispatch]);
const handleClickImagesCategory = useCallback(() => {
dispatch(imageCategoriesChanged(IMAGE_CATEGORIES));
dispatch(setGalleryView('images'));