mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into patch-1
This commit is contained in:
commit
80575344fc
@ -1,75 +1,30 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein
|
||||
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from fastapi import Query, Body
|
||||
from fastapi.routing import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field, parse_obj_as
|
||||
from ..dependencies import ApiDependencies
|
||||
from typing import Literal, List, Optional, Union
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, parse_obj_as
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
from invokeai.backend import BaseModelType, ModelType
|
||||
from invokeai.backend.model_management import AddModelResult
|
||||
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
|
||||
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
from invokeai.backend.model_management.models import (
|
||||
OPENAPI_MODEL_CONFIGS,
|
||||
SchedulerPredictionType,
|
||||
)
|
||||
from invokeai.backend.model_management import MergeInterpolationMethod
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
class VaeRepo(BaseModel):
|
||||
repo_id: str = Field(description="The repo ID to use for this VAE")
|
||||
path: Optional[str] = Field(description="The path to the VAE")
|
||||
subfolder: Optional[str] = Field(description="The subfolder to use for this VAE")
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
description: Optional[str] = Field(description="A description of the model")
|
||||
model_name: str = Field(description="The name of the model")
|
||||
model_type: str = Field(description="The type of the model")
|
||||
|
||||
class DiffusersModelInfo(ModelInfo):
|
||||
format: Literal['folder'] = 'folder'
|
||||
|
||||
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
|
||||
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
|
||||
path: Optional[str] = Field(description="The path to the model")
|
||||
|
||||
class CkptModelInfo(ModelInfo):
|
||||
format: Literal['ckpt'] = 'ckpt'
|
||||
|
||||
config: str = Field(description="The path to the model config")
|
||||
weights: str = Field(description="The path to the model weights")
|
||||
vae: str = Field(description="The path to the model VAE")
|
||||
width: Optional[int] = Field(description="The width of the model")
|
||||
height: Optional[int] = Field(description="The height of the model")
|
||||
|
||||
class SafetensorsModelInfo(CkptModelInfo):
|
||||
format: Literal['safetensors'] = 'safetensors'
|
||||
|
||||
class CreateModelRequest(BaseModel):
|
||||
name: str = Field(description="The name of the model")
|
||||
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
|
||||
|
||||
class CreateModelResponse(BaseModel):
|
||||
name: str = Field(description="The name of the new model")
|
||||
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
|
||||
status: str = Field(description="The status of the API response")
|
||||
|
||||
class ImportModelResponse(BaseModel):
|
||||
name: str = Field(description="The name of the imported model")
|
||||
# base_model: str = Field(description="The base model")
|
||||
# model_type: str = Field(description="The model type")
|
||||
info: AddModelResult = Field(description="The model info")
|
||||
status: str = Field(description="The status of the API response")
|
||||
|
||||
class ConversionRequest(BaseModel):
|
||||
name: str = Field(description="The name of the new model")
|
||||
info: CkptModelInfo = Field(description="The converted model info")
|
||||
save_location: str = Field(description="The path to save the converted model weights")
|
||||
|
||||
class ConvertedModelResponse(BaseModel):
|
||||
name: str = Field(description="The name of the new model")
|
||||
info: DiffusersModelInfo = Field(description="The converted model info")
|
||||
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
models: list[MODEL_CONFIGS]
|
||||
|
||||
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
||||
|
||||
@models_router.get(
|
||||
"/",
|
||||
@ -77,75 +32,103 @@ class ModelsList(BaseModel):
|
||||
responses={200: {"model": ModelsList }},
|
||||
)
|
||||
async def list_models(
|
||||
base_model: Optional[BaseModelType] = Query(
|
||||
default=None, description="Base model"
|
||||
),
|
||||
model_type: Optional[ModelType] = Query(
|
||||
default=None, description="The type of model to get"
|
||||
),
|
||||
base_model: Optional[BaseModelType] = Query(default=None, description="Base model"),
|
||||
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||
) -> ModelsList:
|
||||
"""Gets a list of models"""
|
||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)
|
||||
models = parse_obj_as(ModelsList, { "models": models_raw })
|
||||
return models
|
||||
|
||||
@models_router.post(
|
||||
"/",
|
||||
@models_router.patch(
|
||||
"/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="update_model",
|
||||
responses={200: {"status": "success"}},
|
||||
responses={200: {"description" : "The model was updated successfully"},
|
||||
404: {"description" : "The model could not be found"},
|
||||
400: {"description" : "Bad request"}
|
||||
},
|
||||
status_code = 200,
|
||||
response_model = UpdateModelResponse,
|
||||
)
|
||||
async def update_model(
|
||||
model_request: CreateModelRequest
|
||||
) -> CreateModelResponse:
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||
) -> UpdateModelResponse:
|
||||
""" Add Model """
|
||||
model_request_info = model_request.info
|
||||
info_dict = model_request_info.dict()
|
||||
model_response = CreateModelResponse(name=model_request.name, info=model_request.info, status="success")
|
||||
|
||||
ApiDependencies.invoker.services.model_manager.add_model(
|
||||
model_name=model_request.name,
|
||||
model_attributes=info_dict,
|
||||
clobber=True,
|
||||
)
|
||||
try:
|
||||
ApiDependencies.invoker.services.model_manager.update_model(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
model_attributes=info.dict()
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
model_response = parse_obj_as(UpdateModelResponse, model_raw)
|
||||
except KeyError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
return model_response
|
||||
|
||||
@models_router.post(
|
||||
"/import",
|
||||
"/",
|
||||
operation_id="import_model",
|
||||
responses= {
|
||||
201: {"description" : "The model imported successfully"},
|
||||
404: {"description" : "The model could not be found"},
|
||||
424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
|
||||
409: {"description" : "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ImportModelResponse
|
||||
)
|
||||
async def import_model(
|
||||
name: str = Query(description="A model path, repo_id or URL to import"),
|
||||
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = Query(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
||||
location: str = Body(description="A model path, repo_id or URL to import"),
|
||||
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
|
||||
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
||||
) -> ImportModelResponse:
|
||||
""" Add a model using its local path, repo_id, or remote URL """
|
||||
items_to_import = {name}
|
||||
|
||||
items_to_import = {location}
|
||||
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
||||
items_to_import = items_to_import,
|
||||
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
|
||||
)
|
||||
if info := installed_models.get(name):
|
||||
logger.info(f'Successfully imported {name}, got {info}')
|
||||
return ImportModelResponse(
|
||||
name = name,
|
||||
info = info,
|
||||
status = "success",
|
||||
|
||||
try:
|
||||
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
||||
items_to_import = items_to_import,
|
||||
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
|
||||
)
|
||||
else:
|
||||
logger.error(f'Model {name} not imported')
|
||||
raise HTTPException(status_code=404, detail=f'Model {name} not found')
|
||||
info = installed_models.get(location)
|
||||
|
||||
if not info:
|
||||
logger.error("Import failed")
|
||||
raise HTTPException(status_code=424)
|
||||
|
||||
logger.info(f'Successfully imported {location}, got {info}')
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=info.name,
|
||||
base_model=info.base_model,
|
||||
model_type=info.model_type
|
||||
)
|
||||
return parse_obj_as(ImportModelResponse, model_raw)
|
||||
|
||||
except KeyError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
|
||||
@models_router.delete(
|
||||
"/{model_name}",
|
||||
"/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="del_model",
|
||||
responses={
|
||||
204: {
|
||||
@ -156,144 +139,95 @@ async def import_model(
|
||||
}
|
||||
},
|
||||
)
|
||||
async def delete_model(model_name: str) -> None:
|
||||
async def delete_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
) -> Response:
|
||||
"""Delete Model"""
|
||||
model_names = ApiDependencies.invoker.services.model_manager.model_names()
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
model_exists = model_name in model_names
|
||||
|
||||
# check if model exists
|
||||
logger.info(f"Checking for model {model_name}...")
|
||||
|
||||
if model_exists:
|
||||
logger.info(f"Deleting Model: {model_name}")
|
||||
ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True)
|
||||
logger.info(f"Model Deleted: {model_name}")
|
||||
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
|
||||
|
||||
else:
|
||||
logger.error("Model not found")
|
||||
try:
|
||||
ApiDependencies.invoker.services.model_manager.del_model(model_name,
|
||||
base_model = base_model,
|
||||
model_type = model_type
|
||||
)
|
||||
logger.info(f"Deleted model: {model_name}")
|
||||
return Response(status_code=204)
|
||||
except KeyError:
|
||||
logger.error(f"Model not found: {model_name}")
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||
|
||||
|
||||
# @socketio.on("convertToDiffusers")
|
||||
# def convert_to_diffusers(model_to_convert: dict):
|
||||
# try:
|
||||
# if model_info := self.generate.model_manager.model_info(
|
||||
# model_name=model_to_convert["model_name"]
|
||||
# ):
|
||||
# if "weights" in model_info:
|
||||
# ckpt_path = Path(model_info["weights"])
|
||||
# original_config_file = Path(model_info["config"])
|
||||
# model_name = model_to_convert["model_name"]
|
||||
# model_description = model_info["description"]
|
||||
# else:
|
||||
# self.socketio.emit(
|
||||
# "error", {"message": "Model is not a valid checkpoint file"}
|
||||
# )
|
||||
# else:
|
||||
# self.socketio.emit(
|
||||
# "error", {"message": "Could not retrieve model info."}
|
||||
# )
|
||||
|
||||
# if not ckpt_path.is_absolute():
|
||||
# ckpt_path = Path(Globals.root, ckpt_path)
|
||||
|
||||
# if original_config_file and not original_config_file.is_absolute():
|
||||
# original_config_file = Path(Globals.root, original_config_file)
|
||||
|
||||
# diffusers_path = Path(
|
||||
# ckpt_path.parent.absolute(), f"{model_name}_diffusers"
|
||||
# )
|
||||
|
||||
# if model_to_convert["save_location"] == "root":
|
||||
# diffusers_path = Path(
|
||||
# global_converted_ckpts_dir(), f"{model_name}_diffusers"
|
||||
# )
|
||||
|
||||
# if (
|
||||
# model_to_convert["save_location"] == "custom"
|
||||
# and model_to_convert["custom_location"] is not None
|
||||
# ):
|
||||
# diffusers_path = Path(
|
||||
# model_to_convert["custom_location"], f"{model_name}_diffusers"
|
||||
# )
|
||||
|
||||
# if diffusers_path.exists():
|
||||
# shutil.rmtree(diffusers_path)
|
||||
|
||||
# self.generate.model_manager.convert_and_import(
|
||||
# ckpt_path,
|
||||
# diffusers_path,
|
||||
# model_name=model_name,
|
||||
# model_description=model_description,
|
||||
# vae=None,
|
||||
# original_config_file=original_config_file,
|
||||
# commit_to_conf=opt.conf,
|
||||
# )
|
||||
|
||||
# new_model_list = self.generate.model_manager.list_models()
|
||||
# socketio.emit(
|
||||
# "modelConverted",
|
||||
# {
|
||||
# "new_model_name": model_name,
|
||||
# "model_list": new_model_list,
|
||||
# "update": True,
|
||||
# },
|
||||
# )
|
||||
# print(f">> Model Converted: {model_name}")
|
||||
# except Exception as e:
|
||||
# self.handle_exceptions(e)
|
||||
|
||||
# @socketio.on("mergeDiffusersModels")
|
||||
# def merge_diffusers_models(model_merge_info: dict):
|
||||
# try:
|
||||
# models_to_merge = model_merge_info["models_to_merge"]
|
||||
# model_ids_or_paths = [
|
||||
# self.generate.model_manager.model_name_or_path(x)
|
||||
# for x in models_to_merge
|
||||
# ]
|
||||
# merged_pipe = merge_diffusion_models(
|
||||
# model_ids_or_paths,
|
||||
# model_merge_info["alpha"],
|
||||
# model_merge_info["interp"],
|
||||
# model_merge_info["force"],
|
||||
# )
|
||||
|
||||
# dump_path = global_models_dir() / "merged_models"
|
||||
# if model_merge_info["model_merge_save_path"] is not None:
|
||||
# dump_path = Path(model_merge_info["model_merge_save_path"])
|
||||
|
||||
# os.makedirs(dump_path, exist_ok=True)
|
||||
# dump_path = dump_path / model_merge_info["merged_model_name"]
|
||||
# merged_pipe.save_pretrained(dump_path, safe_serialization=1)
|
||||
|
||||
# merged_model_config = dict(
|
||||
# model_name=model_merge_info["merged_model_name"],
|
||||
# description=f'Merge of models {", ".join(models_to_merge)}',
|
||||
# commit_to_conf=opt.conf,
|
||||
# )
|
||||
|
||||
# if vae := self.generate.model_manager.config[models_to_merge[0]].get(
|
||||
# "vae", None
|
||||
# ):
|
||||
# print(f">> Using configured VAE assigned to {models_to_merge[0]}")
|
||||
# merged_model_config.update(vae=vae)
|
||||
|
||||
# self.generate.model_manager.import_diffuser_model(
|
||||
# dump_path, **merged_model_config
|
||||
# )
|
||||
# new_model_list = self.generate.model_manager.list_models()
|
||||
|
||||
# socketio.emit(
|
||||
# "modelsMerged",
|
||||
# {
|
||||
# "merged_models": models_to_merge,
|
||||
# "merged_model_name": model_merge_info["merged_model_name"],
|
||||
# "model_list": new_model_list,
|
||||
# "update": True,
|
||||
# },
|
||||
# )
|
||||
# print(f">> Models Merged: {models_to_merge}")
|
||||
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
|
||||
# except Exception as e:
|
||||
@models_router.put(
|
||||
"/convert/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="convert_model",
|
||||
responses={
|
||||
200: { "description": "Model converted successfully" },
|
||||
400: {"description" : "Bad request" },
|
||||
404: { "description": "Model not found" },
|
||||
},
|
||||
status_code = 200,
|
||||
response_model = ConvertModelResponse,
|
||||
)
|
||||
async def convert_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
) -> ConvertModelResponse:
|
||||
"""Convert a checkpoint model into a diffusers model"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Converting model: {model_name}")
|
||||
ApiDependencies.invoker.services.model_manager.convert_model(model_name,
|
||||
base_model = base_model,
|
||||
model_type = model_type
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
|
||||
base_model = base_model,
|
||||
model_type = model_type)
|
||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||
except KeyError:
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
|
||||
@models_router.put(
|
||||
"/merge/{base_model}",
|
||||
operation_id="merge_models",
|
||||
responses={
|
||||
200: { "description": "Model converted successfully" },
|
||||
400: { "description": "Incompatible models" },
|
||||
404: { "description": "One or more models not found" },
|
||||
},
|
||||
status_code = 200,
|
||||
response_model = MergeModelResponse,
|
||||
)
|
||||
async def merge_models(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
|
||||
merged_model_name: Optional[str] = Body(description = "Name of destination model"),
|
||||
alpha: Optional[float] = Body(description = "Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||
interp: Union[MergeInterpolationMethod, None] = Body(description = "Interpolation method"),
|
||||
force: Optional[bool] = Body(description = "Force merging of models created with different versions of diffusers", default=False),
|
||||
) -> MergeModelResponse:
|
||||
"""Convert a checkpoint model into a diffusers model"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Merging models: {model_names}")
|
||||
result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
|
||||
base_model,
|
||||
merged_model_name or "+".join(model_names),
|
||||
alpha,
|
||||
interp,
|
||||
force)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
|
||||
base_model = base_model,
|
||||
model_type = ModelType.Main,
|
||||
)
|
||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||
except KeyError:
|
||||
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
|
@ -2,22 +2,29 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING
|
||||
from dataclasses import dataclass
|
||||
from pydantic import Field
|
||||
from typing import Optional, Union, Callable, List, Tuple, TYPE_CHECKING
|
||||
from types import ModuleType
|
||||
|
||||
from invokeai.backend.model_management.model_manager import (
|
||||
from invokeai.backend.model_management import (
|
||||
ModelManager,
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
ModelInfo,
|
||||
AddModelResult,
|
||||
SchedulerPredictionType,
|
||||
ModelMerger,
|
||||
MergeInterpolationMethod,
|
||||
)
|
||||
|
||||
|
||||
import torch
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
from .config import InvokeAIAppConfig
|
||||
from ...backend.util import choose_precision, choose_torch_device
|
||||
from .config import InvokeAIAppConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
|
||||
@ -30,7 +37,7 @@ class ModelManagerServiceBase(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
logger: types.ModuleType,
|
||||
logger: ModuleType,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file.
|
||||
@ -73,13 +80,7 @@ class ModelManagerServiceBase(ABC):
|
||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Returns a list of all the model names known.
|
||||
Uses the exact format as the omegaconf stanza.
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -101,7 +102,20 @@ class ModelManagerServiceBase(ABC):
|
||||
}
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
Return information about the model using the same format as list_models()
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Returns a list of all the model names known.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_model(
|
||||
@ -111,7 +125,7 @@ class ModelManagerServiceBase(ABC):
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False
|
||||
) -> None:
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
@ -121,6 +135,24 @@ class ModelManagerServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with a
|
||||
KeyErrorException if the name does not already exist.
|
||||
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def del_model(
|
||||
self,
|
||||
@ -135,11 +167,32 @@ class ModelManagerServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def convert_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Union[ModelType.Main,ModelType.Vae],
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
version and deleting the original checkpoint file if it is in the models
|
||||
directory.
|
||||
:param model_name: Name of the model to convert
|
||||
:param base_model: Base model type
|
||||
:param model_type: Type of model ['vae' or 'main']
|
||||
|
||||
This will raise a ValueError unless the model is not a checkpoint. It will
|
||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||
directory already in place.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def heuristic_import(self,
|
||||
items_to_import: Set[str],
|
||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
||||
)->Dict[str, AddModelResult]:
|
||||
items_to_import: set[str],
|
||||
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
|
||||
)->dict[str, AddModelResult]:
|
||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
@ -159,7 +212,27 @@ class ModelManagerServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def commit(self, conf_file: Path = None) -> None:
|
||||
def merge_models(
|
||||
self,
|
||||
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
|
||||
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
:param model_names: List of 2-3 models to merge
|
||||
:param base_model: Base model to use for all models
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def commit(self, conf_file: Optional[Path] = None) -> None:
|
||||
"""
|
||||
Write current configuration out to the indicated file.
|
||||
If no conf_file is provided, then replaces the
|
||||
@ -173,7 +246,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
logger: types.ModuleType,
|
||||
logger: ModuleType,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file.
|
||||
@ -299,12 +372,19 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None
|
||||
) -> list[dict]:
|
||||
# ) -> dict:
|
||||
"""
|
||||
Return a list of models.
|
||||
"""
|
||||
return self.mgr.list_models(base_model, model_type)
|
||||
|
||||
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
Return information about the model using the same format as list_models()
|
||||
"""
|
||||
return self.mgr.list_model(model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type)
|
||||
|
||||
def add_model(
|
||||
self,
|
||||
model_name: str,
|
||||
@ -320,9 +400,28 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
self.logger.debug(f'add/update model {model_name}')
|
||||
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
|
||||
|
||||
|
||||
def update_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with a
|
||||
KeyError exception if the name does not already exist.
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
self.logger.debug(f'update model {model_name}')
|
||||
if not self.model_exists(model_name, base_model, model_type):
|
||||
raise KeyError(f"Unknown model {model_name}")
|
||||
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
|
||||
|
||||
def del_model(
|
||||
self,
|
||||
model_name: str,
|
||||
@ -334,8 +433,29 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
then the underlying weight file or diffusers directory will be deleted
|
||||
as well. Call commit() to write to disk.
|
||||
"""
|
||||
self.logger.debug(f'delete model {model_name}')
|
||||
self.mgr.del_model(model_name, base_model, model_type)
|
||||
|
||||
def convert_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Union[ModelType.Main,ModelType.Vae],
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
version and deleting the original checkpoint file if it is in the models
|
||||
directory.
|
||||
:param model_name: Name of the model to convert
|
||||
:param base_model: Base model type
|
||||
:param model_type: Type of model ['vae' or 'main']
|
||||
|
||||
This will raise a ValueError unless the model is not a checkpoint. It will
|
||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||
directory already in place.
|
||||
"""
|
||||
self.logger.debug(f'convert model {model_name}')
|
||||
return self.mgr.convert_model(model_name, base_model, model_type)
|
||||
|
||||
def commit(self, conf_file: Optional[Path]=None):
|
||||
"""
|
||||
@ -387,9 +507,9 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
return self.mgr.logger
|
||||
|
||||
def heuristic_import(self,
|
||||
items_to_import: Set[str],
|
||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
||||
)->Dict[str, AddModelResult]:
|
||||
items_to_import: set[str],
|
||||
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
|
||||
)->dict[str, AddModelResult]:
|
||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
@ -406,4 +526,31 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||
that model.
|
||||
'''
|
||||
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
|
||||
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
|
||||
|
||||
def merge_models(
|
||||
self,
|
||||
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
|
||||
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
:param model_names: List of 2-3 models to merge
|
||||
:param base_model: Base model to use for all models
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
"""
|
||||
merger = ModelMerger(self.mgr)
|
||||
return merger.merge_diffusion_models_and_save(
|
||||
model_names = model_names,
|
||||
base_model = base_model,
|
||||
merged_model_name = merged_model_name,
|
||||
alpha = alpha,
|
||||
interp = interp,
|
||||
force = force,
|
||||
)
|
||||
|
@ -166,14 +166,18 @@ class ModelInstall(object):
|
||||
# add requested models
|
||||
for path in selections.install_models:
|
||||
logger.info(f'Installing {path} [{job}/{jobs}]')
|
||||
self.heuristic_import(path)
|
||||
try:
|
||||
self.heuristic_import(path)
|
||||
except (ValueError, KeyError) as e:
|
||||
logger.error(str(e))
|
||||
job += 1
|
||||
|
||||
self.mgr.commit()
|
||||
|
||||
def heuristic_import(self,
|
||||
model_path_id_or_url: Union[str,Path],
|
||||
models_installed: Set[Path]=None)->Dict[str, AddModelResult]:
|
||||
model_path_id_or_url: Union[str,Path],
|
||||
models_installed: Set[Path]=None,
|
||||
)->Dict[str, AddModelResult]:
|
||||
'''
|
||||
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
|
||||
:param models_installed: Set of installed models, used for recursive invocation
|
||||
@ -187,61 +191,53 @@ class ModelInstall(object):
|
||||
self.current_id = model_path_id_or_url
|
||||
path = Path(model_path_id_or_url)
|
||||
|
||||
try:
|
||||
# checkpoint file, or similar
|
||||
if path.is_file():
|
||||
models_installed.update(self._install_path(path))
|
||||
# checkpoint file, or similar
|
||||
if path.is_file():
|
||||
models_installed.update({str(path):self._install_path(path)})
|
||||
|
||||
# folders style or similar
|
||||
elif path.is_dir() and any([(path/x).exists() for x in \
|
||||
{'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
|
||||
]
|
||||
):
|
||||
models_installed.update(self._install_path(path))
|
||||
# folders style or similar
|
||||
elif path.is_dir() and any([(path/x).exists() for x in \
|
||||
{'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
|
||||
]
|
||||
):
|
||||
models_installed.update(self._install_path(path))
|
||||
|
||||
# recursive scan
|
||||
elif path.is_dir():
|
||||
for child in path.iterdir():
|
||||
self.heuristic_import(child, models_installed=models_installed)
|
||||
# recursive scan
|
||||
elif path.is_dir():
|
||||
for child in path.iterdir():
|
||||
self.heuristic_import(child, models_installed=models_installed)
|
||||
|
||||
# huggingface repo
|
||||
elif len(str(path).split('/')) == 2:
|
||||
models_installed.update(self._install_repo(str(path)))
|
||||
# huggingface repo
|
||||
elif len(str(model_path_id_or_url).split('/')) == 2:
|
||||
models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
|
||||
|
||||
# a URL
|
||||
elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")):
|
||||
models_installed.update(self._install_url(model_path_id_or_url))
|
||||
# a URL
|
||||
elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")):
|
||||
models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
|
||||
|
||||
else:
|
||||
logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
else:
|
||||
raise KeyError(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
|
||||
|
||||
return models_installed
|
||||
|
||||
# install a model from a local path. The optional info parameter is there to prevent
|
||||
# the model from being probed twice in the event that it has already been probed.
|
||||
def _install_path(self, path: Path, info: ModelProbeInfo=None)->Dict[str, AddModelResult]:
|
||||
try:
|
||||
model_result = None
|
||||
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
|
||||
model_name = path.stem if path.is_file() else path.name
|
||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||
raise ValueError(f'A model named "{model_name}" is already installed.')
|
||||
attributes = self._make_attributes(path,info)
|
||||
model_result = self.mgr.add_model(model_name = model_name,
|
||||
base_model = info.base_type,
|
||||
model_type = info.model_type,
|
||||
model_attributes = attributes,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f'{str(e)} Skipping registration.')
|
||||
return {}
|
||||
return {str(path): model_result}
|
||||
def _install_path(self, path: Path, info: ModelProbeInfo=None)->AddModelResult:
|
||||
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
|
||||
if not info:
|
||||
logger.warning(f'Unable to parse format of {path}')
|
||||
return None
|
||||
model_name = path.stem if path.is_file() else path.name
|
||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||
raise ValueError(f'A model named "{model_name}" is already installed.')
|
||||
attributes = self._make_attributes(path,info)
|
||||
return self.mgr.add_model(model_name = model_name,
|
||||
base_model = info.base_type,
|
||||
model_type = info.model_type,
|
||||
model_attributes = attributes,
|
||||
)
|
||||
|
||||
def _install_url(self, url: str)->dict:
|
||||
# copy to a staging area, probe, import and delete
|
||||
def _install_url(self, url: str)->AddModelResult:
|
||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||
location = download_with_resume(url,Path(staging))
|
||||
if not location:
|
||||
@ -253,7 +249,7 @@ class ModelInstall(object):
|
||||
# staged version will be garbage-collected at this time
|
||||
return self._install_path(Path(models_path), info)
|
||||
|
||||
def _install_repo(self, repo_id: str)->dict:
|
||||
def _install_repo(self, repo_id: str)->AddModelResult:
|
||||
hinfo = HfApi().model_info(repo_id)
|
||||
|
||||
# we try to figure out how to download this most economically
|
||||
|
@ -1,7 +1,8 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend.model_management
|
||||
"""
|
||||
from .model_manager import ModelManager, ModelInfo, AddModelResult
|
||||
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
|
||||
from .model_cache import ModelCache
|
||||
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
|
||||
from .model_merge import ModelMerger, MergeInterpolationMethod
|
||||
|
||||
|
@ -2,8 +2,8 @@ from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Dict, Tuple, Any, Union, List
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Tuple, Union, List
|
||||
|
||||
import torch
|
||||
from compel.embeddings_provider import BaseTextualInversionManager
|
||||
|
@ -234,7 +234,7 @@ import textwrap
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types
|
||||
from shutil import rmtree
|
||||
from shutil import rmtree, move
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
@ -279,7 +279,7 @@ class InvalidModelError(Exception):
|
||||
pass
|
||||
|
||||
class AddModelResult(BaseModel):
|
||||
name: str = Field(description="The name of the model after import")
|
||||
name: str = Field(description="The name of the model after installation")
|
||||
model_type: ModelType = Field(description="The type of model")
|
||||
base_model: BaseModelType = Field(description="The base model")
|
||||
config: ModelConfigBase = Field(description="The configuration of the model")
|
||||
@ -491,17 +491,32 @@ class ModelManager(object):
|
||||
"""
|
||||
return [(self.parse_key(x)) for x in self.models.keys()]
|
||||
|
||||
def list_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
) -> dict:
|
||||
"""
|
||||
Returns a dict describing one installed model, using
|
||||
the combined format of the list_models() method.
|
||||
"""
|
||||
models = self.list_models(base_model,model_type,model_name)
|
||||
return models[0] if models else None
|
||||
|
||||
def list_models(
|
||||
self,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
model_name: Optional[str] = None,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Return a list of models.
|
||||
"""
|
||||
|
||||
model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold)
|
||||
models = []
|
||||
for model_key in sorted(self.models, key=str.casefold):
|
||||
for model_key in model_keys:
|
||||
model_config = self.models[model_key]
|
||||
|
||||
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
||||
@ -546,10 +561,7 @@ class ModelManager(object):
|
||||
model_cfg = self.models.pop(model_key, None)
|
||||
|
||||
if model_cfg is None:
|
||||
self.logger.error(
|
||||
f"Unknown model {model_key}"
|
||||
)
|
||||
return
|
||||
raise KeyError(f"Unknown model {model_key}")
|
||||
|
||||
# note: it not garantie to release memory(model can has other references)
|
||||
cache_ids = self.cache_keys.pop(model_key, [])
|
||||
@ -615,6 +627,7 @@ class ModelManager(object):
|
||||
self.cache.uncache_model(cache_id)
|
||||
|
||||
self.models[model_key] = model_config
|
||||
self.commit()
|
||||
return AddModelResult(
|
||||
name = model_name,
|
||||
model_type = model_type,
|
||||
@ -622,6 +635,60 @@ class ModelManager(object):
|
||||
config = model_config,
|
||||
)
|
||||
|
||||
def convert_model (
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Union[ModelType.Main,ModelType.Vae],
|
||||
) -> AddModelResult:
|
||||
'''
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
version and deleting the original checkpoint file if it is in the models
|
||||
directory.
|
||||
:param model_name: Name of the model to convert
|
||||
:param base_model: Base model type
|
||||
:param model_type: Type of model ['vae' or 'main']
|
||||
|
||||
This will raise a ValueError unless the model is a checkpoint.
|
||||
'''
|
||||
info = self.model_info(model_name, base_model, model_type)
|
||||
if info["model_format"] != "checkpoint":
|
||||
raise ValueError(f"not a checkpoint format model: {model_name}")
|
||||
|
||||
# We are taking advantage of a side effect of get_model() that converts check points
|
||||
# into cached diffusers directories stored at `location`. It doesn't matter
|
||||
# what submodeltype we request here, so we get the smallest.
|
||||
submodel = {"submodel_type": SubModelType.Tokenizer} if model_type==ModelType.Main else {}
|
||||
model = self.get_model(model_name,
|
||||
base_model,
|
||||
model_type,
|
||||
**submodel,
|
||||
)
|
||||
checkpoint_path = self.app_config.root_path / info["path"]
|
||||
old_diffusers_path = self.app_config.models_path / model.location
|
||||
new_diffusers_path = self.app_config.models_path / base_model.value / model_type.value / model_name
|
||||
if new_diffusers_path.exists():
|
||||
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
|
||||
|
||||
try:
|
||||
move(old_diffusers_path,new_diffusers_path)
|
||||
info["model_format"] = "diffusers"
|
||||
info["path"] = str(new_diffusers_path.relative_to(self.app_config.root_path))
|
||||
info.pop('config')
|
||||
|
||||
result = self.add_model(model_name, base_model, model_type,
|
||||
model_attributes = info,
|
||||
clobber=True)
|
||||
except:
|
||||
# something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
|
||||
rmtree(new_diffusers_path)
|
||||
raise
|
||||
|
||||
if checkpoint_path.exists() and checkpoint_path.is_relative_to(self.app_config.models_path):
|
||||
checkpoint_path.unlink()
|
||||
|
||||
return result
|
||||
|
||||
def search_models(self, search_folder):
|
||||
self.logger.info(f"Finding Models In: {search_folder}")
|
||||
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
|
||||
@ -821,6 +888,10 @@ class ModelManager(object):
|
||||
The result is a set of successfully installed models. Each element
|
||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||
that model.
|
||||
|
||||
May return the following exceptions:
|
||||
- KeyError - one or more of the items to import is not a valid path, repo_id or URL
|
||||
- ValueError - a corresponding model already exists
|
||||
'''
|
||||
# avoid circular import here
|
||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||
@ -830,11 +901,7 @@ class ModelManager(object):
|
||||
prediction_type_helper = prediction_type_helper,
|
||||
model_manager = self)
|
||||
for thing in items_to_import:
|
||||
try:
|
||||
installed = installer.heuristic_import(thing)
|
||||
successfully_installed.update(installed)
|
||||
except Exception as e:
|
||||
self.logger.warning(f'{thing} could not be imported: {str(e)}')
|
||||
|
||||
installed = installer.heuristic_import(thing)
|
||||
successfully_installed.update(installed)
|
||||
self.commit()
|
||||
return successfully_installed
|
||||
|
129
invokeai/backend/model_management/model_merge.py
Normal file
129
invokeai/backend/model_management/model_merge.py
Normal file
@ -0,0 +1,129 @@
|
||||
"""
|
||||
invokeai.backend.model_management.model_merge exports:
|
||||
merge_diffusion_models() -- combine multiple models by location and return a pipeline object
|
||||
merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to models.yaml
|
||||
|
||||
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers import logging as dlogging
|
||||
from typing import List, Union
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
|
||||
|
||||
class MergeInterpolationMethod(str, Enum):
|
||||
Sigmoid = "sigmoid"
|
||||
InvSigmoid = "inv_sigmoid"
|
||||
AddDifference = "add_difference"
|
||||
WeightedSum = "weighted_sum"
|
||||
|
||||
class ModelMerger(object):
|
||||
def __init__(self, manager: ModelManager):
|
||||
self.manager = manager
|
||||
|
||||
def merge_diffusion_models(
|
||||
self,
|
||||
model_paths: List[Path],
|
||||
alpha: float = 0.5,
|
||||
interp: MergeInterpolationMethod = None,
|
||||
force: bool = False,
|
||||
**kwargs,
|
||||
) -> DiffusionPipeline:
|
||||
"""
|
||||
:param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids
|
||||
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||
:param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
|
||||
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
verbosity = dlogging.get_verbosity()
|
||||
dlogging.set_verbosity_error()
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
model_paths[0],
|
||||
custom_pipeline="checkpoint_merger",
|
||||
)
|
||||
merged_pipe = pipe.merge(
|
||||
pretrained_model_name_or_path_list=model_paths,
|
||||
alpha=alpha,
|
||||
interp=interp.value if interp else None, #diffusers API treats None as "weighted sum"
|
||||
force=force,
|
||||
**kwargs,
|
||||
)
|
||||
dlogging.set_verbosity(verbosity)
|
||||
return merged_pipe
|
||||
|
||||
|
||||
def merge_diffusion_models_and_save (
|
||||
self,
|
||||
model_names: List[str],
|
||||
base_model: Union[BaseModelType,str],
|
||||
merged_model_name: str,
|
||||
alpha: float = 0.5,
|
||||
interp: MergeInterpolationMethod = None,
|
||||
force: bool = False,
|
||||
**kwargs,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
:param models: up to three models, designated by their InvokeAI models.yaml model name
|
||||
:param base_model: base model (must be the same for all merged models!)
|
||||
:param merged_model_name: name for new model
|
||||
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
|
||||
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
model_paths = list()
|
||||
config = self.manager.app_config
|
||||
base_model = BaseModelType(base_model)
|
||||
vae = None
|
||||
|
||||
for mod in model_names:
|
||||
info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main)
|
||||
assert info, f"model {mod}, base_model {base_model}, is unknown"
|
||||
assert info["model_format"] == "diffusers", f"{mod} is not a diffusers model. It must be optimized before merging"
|
||||
assert info["variant"] == "normal", (f"{mod} is a {info['variant']} model, which cannot currently be merged")
|
||||
# pick up the first model's vae
|
||||
if mod == model_names[0]:
|
||||
vae = info.get("vae")
|
||||
model_paths.extend([config.root_path / info["path"]])
|
||||
|
||||
merge_method = None if interp == 'weighted_sum' else MergeInterpolationMethod(interp)
|
||||
merged_pipe = self.merge_diffusion_models(
|
||||
model_paths, alpha, merge_method, force, **kwargs
|
||||
)
|
||||
dump_path = config.models_path / base_model.value / ModelType.Main.value
|
||||
dump_path.mkdir(parents=True, exist_ok=True)
|
||||
dump_path = dump_path / merged_model_name
|
||||
|
||||
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
|
||||
attributes = dict(
|
||||
path = str(dump_path),
|
||||
description = f"Merge of models {', '.join(model_names)}",
|
||||
model_format = "diffusers",
|
||||
variant = ModelVariantType.Normal.value,
|
||||
vae = vae,
|
||||
)
|
||||
return self.manager.add_model(merged_model_name,
|
||||
base_model = base_model,
|
||||
model_type = ModelType.Main,
|
||||
model_attributes = attributes,
|
||||
clobber = True
|
||||
)
|
@ -116,7 +116,7 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
version=BaseModelType.StableDiffusion1,
|
||||
model_config=config,
|
||||
output_path=output_path,
|
||||
)
|
||||
)
|
||||
else:
|
||||
return model_path
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""
|
||||
Initialization file for invokeai.frontend.merge
|
||||
"""
|
||||
from .merge_diffusers import main as invokeai_merge_diffusers, merge_diffusion_models
|
||||
from .merge_diffusers import main as invokeai_merge_diffusers
|
||||
|
||||
|
@ -6,9 +6,7 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
||||
"""
|
||||
import argparse
|
||||
import curses
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
@ -20,99 +18,15 @@ from npyscreen import widget
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.services.config import InvokeAIAppConfig
|
||||
from ...backend.model_management import ModelManager
|
||||
from ...frontend.install.widgets import FloatTitleSlider
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_management import (
|
||||
ModelMerger, MergeInterpolationMethod,
|
||||
ModelManager, ModelType, BaseModelType,
|
||||
)
|
||||
from invokeai.frontend.install.widgets import FloatTitleSlider, TextBox, SingleSelectColumns
|
||||
|
||||
DEST_MERGED_MODEL_DIR = "merged_models"
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
def merge_diffusion_models(
|
||||
model_ids_or_paths: List[Union[str, Path]],
|
||||
alpha: float = 0.5,
|
||||
interp: str = None,
|
||||
force: bool = False,
|
||||
**kwargs,
|
||||
) -> DiffusionPipeline:
|
||||
"""
|
||||
model_ids_or_paths - up to three models, designated by their local paths or HuggingFace repo_ids
|
||||
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||
interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
|
||||
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
verbosity = dlogging.get_verbosity()
|
||||
dlogging.set_verbosity_error()
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
model_ids_or_paths[0],
|
||||
cache_dir=kwargs.get("cache_dir", config.cache_dir),
|
||||
custom_pipeline="checkpoint_merger",
|
||||
)
|
||||
merged_pipe = pipe.merge(
|
||||
pretrained_model_name_or_path_list=model_ids_or_paths,
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
**kwargs,
|
||||
)
|
||||
dlogging.set_verbosity(verbosity)
|
||||
return merged_pipe
|
||||
|
||||
|
||||
def merge_diffusion_models_and_commit(
|
||||
models: List["str"],
|
||||
merged_model_name: str,
|
||||
alpha: float = 0.5,
|
||||
interp: str = None,
|
||||
force: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
models - up to three models, designated by their InvokeAI models.yaml model name
|
||||
merged_model_name = name for new model
|
||||
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||
interp - The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
|
||||
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
config_file = config.model_conf_path
|
||||
model_manager = ModelManager(OmegaConf.load(config_file))
|
||||
for mod in models:
|
||||
assert mod in model_manager.model_names(), f'** Unknown model "{mod}"'
|
||||
assert (
|
||||
model_manager.model_info(mod).get("format", None) == "diffusers"
|
||||
), f"** {mod} is not a diffusers model. It must be optimized before merging."
|
||||
model_ids_or_paths = [model_manager.model_name_or_path(x) for x in models]
|
||||
|
||||
merged_pipe = merge_diffusion_models(
|
||||
model_ids_or_paths, alpha, interp, force, **kwargs
|
||||
)
|
||||
dump_path = config.models_dir / DEST_MERGED_MODEL_DIR
|
||||
|
||||
os.makedirs(dump_path, exist_ok=True)
|
||||
dump_path = dump_path / merged_model_name
|
||||
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
|
||||
import_args = dict(
|
||||
model_name=merged_model_name, description=f'Merge of models {", ".join(models)}'
|
||||
)
|
||||
if vae := model_manager.config[models[0]].get("vae", None):
|
||||
logger.info(f"Using configured VAE assigned to {models[0]}")
|
||||
import_args.update(vae=vae)
|
||||
model_manager.import_diffuser_model(dump_path, **import_args)
|
||||
model_manager.commit(config_file)
|
||||
|
||||
|
||||
def _parse_args() -> Namespace:
|
||||
parser = argparse.ArgumentParser(description="InvokeAI model merging")
|
||||
parser.add_argument(
|
||||
@ -131,10 +45,17 @@ def _parse_args() -> Namespace:
|
||||
)
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
dest="model_names",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="Two to three model names to be merged",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base_model",
|
||||
type=str,
|
||||
choices=[x.value for x in BaseModelType],
|
||||
help="The base model shared by the models to be merged",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--merged_model_name",
|
||||
"--destination",
|
||||
@ -192,6 +113,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
window_height, window_width = curses.initscr().getmaxyx()
|
||||
|
||||
self.model_names = self.get_model_names()
|
||||
self.current_base = 0
|
||||
max_width = max([len(x) for x in self.model_names])
|
||||
max_width += 6
|
||||
horizontal_layout = max_width * 3 < window_width
|
||||
@ -208,12 +130,26 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
value="Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
|
||||
editable=False,
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.base_select = self.add_widget_intelligent(
|
||||
SingleSelectColumns,
|
||||
values=[
|
||||
'Models Built on SD-1.x',
|
||||
'Models Built on SD-2.x',
|
||||
],
|
||||
value=[self.current_base],
|
||||
columns = 4,
|
||||
max_height = 2,
|
||||
relx=8,
|
||||
scroll_exit = True,
|
||||
)
|
||||
self.base_select.on_changed = self._populate_models
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="MODEL 1",
|
||||
color="GOOD",
|
||||
editable=False,
|
||||
rely=4 if horizontal_layout else None,
|
||||
rely=6 if horizontal_layout else None,
|
||||
)
|
||||
self.model1 = self.add_widget_intelligent(
|
||||
npyscreen.SelectOne,
|
||||
@ -222,7 +158,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
max_height=len(self.model_names),
|
||||
max_width=max_width,
|
||||
scroll_exit=True,
|
||||
rely=5,
|
||||
rely=7,
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
@ -230,7 +166,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
color="GOOD",
|
||||
editable=False,
|
||||
relx=max_width + 3 if horizontal_layout else None,
|
||||
rely=4 if horizontal_layout else None,
|
||||
rely=6 if horizontal_layout else None,
|
||||
)
|
||||
self.model2 = self.add_widget_intelligent(
|
||||
npyscreen.SelectOne,
|
||||
@ -240,7 +176,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
max_height=len(self.model_names),
|
||||
max_width=max_width,
|
||||
relx=max_width + 3 if horizontal_layout else None,
|
||||
rely=5 if horizontal_layout else None,
|
||||
rely=7 if horizontal_layout else None,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
@ -249,7 +185,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
color="GOOD",
|
||||
editable=False,
|
||||
relx=max_width * 2 + 3 if horizontal_layout else None,
|
||||
rely=4 if horizontal_layout else None,
|
||||
rely=6 if horizontal_layout else None,
|
||||
)
|
||||
models_plus_none = self.model_names.copy()
|
||||
models_plus_none.insert(0, "None")
|
||||
@ -262,24 +198,26 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
max_width=max_width,
|
||||
scroll_exit=True,
|
||||
relx=max_width * 2 + 3 if horizontal_layout else None,
|
||||
rely=5 if horizontal_layout else None,
|
||||
rely=7 if horizontal_layout else None,
|
||||
)
|
||||
for m in [self.model1, self.model2, self.model3]:
|
||||
m.when_value_edited = self.models_changed
|
||||
self.merged_model_name = self.add_widget_intelligent(
|
||||
npyscreen.TitleText,
|
||||
TextBox,
|
||||
name="Name for merged model:",
|
||||
labelColor="CONTROL",
|
||||
max_height=3,
|
||||
value="",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.force = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Force merge of incompatible models",
|
||||
name="Force merge of models created by different diffusers library versions",
|
||||
labelColor="CONTROL",
|
||||
value=False,
|
||||
value=True,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.merge_method = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name="Merge Method:",
|
||||
@ -341,7 +279,8 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
interp = self.interpolations[self.merge_method.value[0]]
|
||||
|
||||
args = dict(
|
||||
models=models,
|
||||
model_names=models,
|
||||
base_model=tuple(BaseModelType)[self.base_select.value[0]],
|
||||
alpha=self.alpha.value,
|
||||
interp=interp,
|
||||
force=self.force.value,
|
||||
@ -379,21 +318,30 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
else:
|
||||
return True
|
||||
|
||||
def get_model_names(self) -> List[str]:
|
||||
def get_model_names(self, base_model: BaseModelType=None) -> List[str]:
|
||||
model_names = [
|
||||
name
|
||||
for name in self.model_manager.model_names()
|
||||
if self.model_manager.model_info(name).get("format") == "diffusers"
|
||||
info["name"]
|
||||
for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)
|
||||
if info["model_format"] == "diffusers"
|
||||
]
|
||||
return sorted(model_names)
|
||||
|
||||
def _populate_models(self,value=None):
|
||||
base_model = tuple(BaseModelType)[value[0]]
|
||||
self.model_names = self.get_model_names(base_model)
|
||||
|
||||
models_plus_none = self.model_names.copy()
|
||||
models_plus_none.insert(0, "None")
|
||||
self.model1.values = self.model_names
|
||||
self.model2.values = self.model_names
|
||||
self.model3.values = models_plus_none
|
||||
|
||||
self.display()
|
||||
|
||||
class Mergeapp(npyscreen.NPSAppManaged):
|
||||
def __init__(self):
|
||||
def __init__(self, model_manager:ModelManager):
|
||||
super().__init__()
|
||||
conf = OmegaConf.load(config.model_conf_path)
|
||||
self.model_manager = ModelManager(
|
||||
conf, "cpu", "float16"
|
||||
) # precision doesn't really matter here
|
||||
self.model_manager = model_manager
|
||||
|
||||
def onStart(self):
|
||||
npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
|
||||
@ -401,44 +349,41 @@ class Mergeapp(npyscreen.NPSAppManaged):
|
||||
|
||||
|
||||
def run_gui(args: Namespace):
|
||||
mergeapp = Mergeapp()
|
||||
model_manager = ModelManager(config.model_conf_path)
|
||||
mergeapp = Mergeapp(model_manager)
|
||||
mergeapp.run()
|
||||
|
||||
args = mergeapp.merge_arguments
|
||||
merge_diffusion_models_and_commit(**args)
|
||||
merger = ModelMerger(model_manager)
|
||||
merger.merge_diffusion_models_and_save(**args)
|
||||
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
|
||||
|
||||
|
||||
def run_cli(args: Namespace):
|
||||
assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
|
||||
assert (
|
||||
args.models and len(args.models) >= 1 and len(args.models) <= 3
|
||||
args.model_names and len(args.model_names) >= 1 and len(args.model_names) <= 3
|
||||
), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."
|
||||
|
||||
if not args.merged_model_name:
|
||||
args.merged_model_name = "+".join(args.models)
|
||||
args.merged_model_name = "+".join(args.model_names)
|
||||
logger.info(
|
||||
f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
|
||||
)
|
||||
|
||||
model_manager = ModelManager(OmegaConf.load(config.model_conf_path))
|
||||
assert (
|
||||
args.clobber or args.merged_model_name not in model_manager.model_names()
|
||||
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
||||
model_manager = ModelManager(config.model_conf_path)
|
||||
assert (
|
||||
not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber
|
||||
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
||||
|
||||
merge_diffusion_models_and_commit(**vars(args))
|
||||
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
|
||||
merger = ModelMerger(model_manager)
|
||||
merger.merge_diffusion_models_and_save(**vars(args))
|
||||
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
|
||||
|
||||
|
||||
def main():
|
||||
args = _parse_args()
|
||||
config.root = args.root_dir
|
||||
|
||||
cache_dir = config.cache_dir
|
||||
os.environ[
|
||||
"HF_HOME"
|
||||
] = cache_dir # because not clear the merge pipeline is honoring cache_dir
|
||||
args.cache_dir = cache_dir
|
||||
config.parse_args(['--root',str(args.root_dir)])
|
||||
|
||||
try:
|
||||
if args.front_end:
|
||||
|
@ -1,6 +1,5 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { appSocketConnected, socketConnected } from 'services/events/actions';
|
||||
import { receivedPageOfImages } from 'services/api/thunks/image';
|
||||
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
||||
import { startAppListening } from '../..';
|
||||
|
||||
@ -14,19 +13,10 @@ export const addSocketConnectedEventListener = () => {
|
||||
|
||||
moduleLog.debug({ timestamp }, 'Connected');
|
||||
|
||||
const { nodes, config, gallery } = getState();
|
||||
const { nodes, config } = getState();
|
||||
|
||||
const { disabledTabs } = config;
|
||||
|
||||
if (!gallery.ids.length) {
|
||||
dispatch(
|
||||
receivedPageOfImages({
|
||||
categories: ['general'],
|
||||
is_intermediate: false,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
if (!nodes.schema && !disabledTabs.includes('nodes')) {
|
||||
dispatch(receivedOpenAPISchema());
|
||||
}
|
||||
|
@ -6,10 +6,15 @@ import { validateSeedWeights } from 'common/util/seedWeightPairs';
|
||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import {
|
||||
modelsApi,
|
||||
useGetMainModelsQuery,
|
||||
} from '../../services/api/endpoints/models';
|
||||
|
||||
const readinessSelector = createSelector(
|
||||
[stateSelector, activeTabNameSelector],
|
||||
({ generation, system, batch }, activeTabName) => {
|
||||
(state, activeTabName) => {
|
||||
const { generation, system, batch } = state;
|
||||
const { shouldGenerateVariations, seedWeights, initialImage, seed } =
|
||||
generation;
|
||||
|
||||
@ -32,6 +37,13 @@ const readinessSelector = createSelector(
|
||||
reasonsWhyNotReady.push('No initial image selected');
|
||||
}
|
||||
|
||||
const { isSuccess: mainModelsSuccessfullyLoaded } =
|
||||
modelsApi.endpoints.getMainModels.select()(state);
|
||||
if (!mainModelsSuccessfullyLoaded) {
|
||||
isReady = false;
|
||||
reasonsWhyNotReady.push('Models are not loaded');
|
||||
}
|
||||
|
||||
// TODO: job queue
|
||||
// Cannot generate if already processing an image
|
||||
if (isProcessing) {
|
||||
|
@ -182,6 +182,15 @@ const ImageGalleryContent = () => {
|
||||
return () => osInstance()?.destroy();
|
||||
}, [scroller, initialize, osInstance]);
|
||||
|
||||
useEffect(() => {
|
||||
dispatch(
|
||||
receivedPageOfImages({
|
||||
categories: ['general'],
|
||||
is_intermediate: false,
|
||||
})
|
||||
);
|
||||
}, [dispatch]);
|
||||
|
||||
const handleClickImagesCategory = useCallback(() => {
|
||||
dispatch(imageCategoriesChanged(IMAGE_CATEGORIES));
|
||||
dispatch(setGalleryView('images'));
|
||||
|
Loading…
Reference in New Issue
Block a user