mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
convert add_model(), del_model(), list_models() etc to use bifurcated names
This commit is contained in:
parent
bc96727cbe
commit
72967bf118
@ -1,12 +1,9 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
|
||||
|
||||
import shutil
|
||||
import asyncio
|
||||
from typing import Annotated, Any, List, Literal, Optional, Union
|
||||
from typing import Annotated, Literal, Optional, Union
|
||||
|
||||
from fastapi.routing import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field, parse_obj_as
|
||||
from pathlib import Path
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
@ -19,6 +16,15 @@ class VaeRepo(BaseModel):
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
description: Optional[str] = Field(description="A description of the model")
|
||||
model_name: str = Field(description="The name of the model")
|
||||
model_type: str = Field(description="The type of the model")
|
||||
|
||||
class DiffusersModelInfo(ModelInfo):
|
||||
format: Literal['folder'] = 'folder'
|
||||
|
||||
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
|
||||
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
|
||||
path: Optional[str] = Field(description="The path to the model")
|
||||
|
||||
class CkptModelInfo(ModelInfo):
|
||||
format: Literal['ckpt'] = 'ckpt'
|
||||
@ -29,12 +35,8 @@ class CkptModelInfo(ModelInfo):
|
||||
width: Optional[int] = Field(description="The width of the model")
|
||||
height: Optional[int] = Field(description="The height of the model")
|
||||
|
||||
class DiffusersModelInfo(ModelInfo):
|
||||
format: Literal['diffusers'] = 'diffusers'
|
||||
|
||||
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
|
||||
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
|
||||
path: Optional[str] = Field(description="The path to the model")
|
||||
class SafetensorsModelInfo(CkptModelInfo):
|
||||
format: Literal['safetensors'] = 'safetensors'
|
||||
|
||||
class CreateModelRequest(BaseModel):
|
||||
name: str = Field(description="The name of the model")
|
||||
@ -56,7 +58,7 @@ class ConvertedModelResponse(BaseModel):
|
||||
info: DiffusersModelInfo = Field(description="The converted model info")
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]]
|
||||
models: dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]
|
||||
|
||||
|
||||
@models_router.get(
|
||||
@ -121,7 +123,7 @@ async def delete_model(model_name: str) -> None:
|
||||
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
|
||||
|
||||
else:
|
||||
logger.error(f"Model not found")
|
||||
logger.error("Model not found")
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||
|
||||
|
||||
|
@ -60,14 +60,14 @@ class ModelLoaderInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.valid_model(
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.diffusers,
|
||||
):
|
||||
raise Exception(f"Unkown model name: {self.model_name}!")
|
||||
|
||||
"""
|
||||
if not context.services.model_manager.valid_model(
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.diffusers,
|
||||
submodel=SDModelType.tokenizer,
|
||||
@ -76,7 +76,7 @@ class ModelLoaderInvocation(BaseInvocation):
|
||||
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.valid_model(
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.diffusers,
|
||||
submodel=SDModelType.text_encoder,
|
||||
@ -85,7 +85,7 @@ class ModelLoaderInvocation(BaseInvocation):
|
||||
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.valid_model(
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.diffusers,
|
||||
submodel=SDModelType.unet,
|
||||
|
@ -59,35 +59,35 @@ class ModelManagerServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def valid_model(self, model_name: str) -> bool:
|
||||
"""
|
||||
Given a model name, returns True if it is a valid
|
||||
identifier.
|
||||
"""
|
||||
def model_exists(
|
||||
self,
|
||||
model_name: str,
|
||||
model_type: SDModelType
|
||||
) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def default_model(self) -> Union[str,None]:
|
||||
def default_model(self) -> Union[Tuple(str, SDModelType),None],
|
||||
"""
|
||||
Returns the name of the default model, or None
|
||||
Returns the name and typeof the default model, or None
|
||||
if none is defined.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_default_model(self, model_name:str):
|
||||
def set_default_model(self, model_name: str, model_type: SDModelType):
|
||||
"""Sets the default model to the indicated name."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_info(self, model_name: str)->dict:
|
||||
def model_info(self, model_name: str, model_type: SDModelType)->dict:
|
||||
"""
|
||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_names(self)->list[str]:
|
||||
def model_names(self)->List[Tuple(str, SDModelType)]:
|
||||
"""
|
||||
Returns a list of all the model names known.
|
||||
"""
|
||||
@ -97,18 +97,25 @@ class ModelManagerServiceBase(ABC):
|
||||
def list_models(self)->dict:
|
||||
"""
|
||||
Return a dict of models in the format:
|
||||
{ model_name1: {'status': ('active'|'cached'|'not loaded'),
|
||||
{ model_key1: {'status': 'active'|'cached'|'not loaded',
|
||||
'model_name' : name,
|
||||
'model_type' : SDModelType,
|
||||
'description': description,
|
||||
'format': ('ckpt'|'diffusers'|'vae'|'text_encoder'|'tokenizer'|'lora'...),
|
||||
'format': 'folder'|'safetensors'|'ckpt'
|
||||
},
|
||||
model_name2: { etc }
|
||||
model_key2: { etc }
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def add_model(
|
||||
self, model_name: str, model_attributes: dict, clobber: bool = False)->None:
|
||||
self,
|
||||
model_name: str,
|
||||
model_type: SDModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False
|
||||
)->None:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
@ -121,7 +128,7 @@ class ModelManagerServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def del_model(self,
|
||||
model_name: str,
|
||||
model_type: SDModelType=SDModelType.diffusers,
|
||||
model_type: SDModelType,
|
||||
delete_files: bool = False):
|
||||
"""
|
||||
Delete the named model from configuration. If delete_files is true,
|
||||
@ -332,33 +339,37 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
|
||||
return model_info
|
||||
|
||||
def valid_model(self, model_name: str, model_type: SDModelType=SDModelType.diffusers) -> bool:
|
||||
def model_exists(
|
||||
self,
|
||||
model_name: str,
|
||||
model_type: SDModelType
|
||||
) -> bool:
|
||||
"""
|
||||
Given a model name, returns True if it is a valid
|
||||
identifier.
|
||||
"""
|
||||
return self.mgr.valid_model(
|
||||
return self.mgr.model_exists(
|
||||
model_name,
|
||||
model_type)
|
||||
|
||||
def default_model(self) -> Union[str,None]:
|
||||
def default_model(self) -> Union[Tuple(str, SDModelType),None],
|
||||
"""
|
||||
Returns the name of the default model, or None
|
||||
if none is defined.
|
||||
"""
|
||||
return self.mgr.default_model()
|
||||
|
||||
def set_default_model(self, model_name:str):
|
||||
def set_default_model(self, model_name:str, model_type: SDModelType):
|
||||
"""Sets the default model to the indicated name."""
|
||||
self.mgr.set_default_model(model_name)
|
||||
|
||||
def model_info(self, model_name: str)->dict:
|
||||
def model_info(self, model_name: str, model_type: SDModelType)->dict:
|
||||
"""
|
||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||
"""
|
||||
return self.mgr.model_info(model_name)
|
||||
|
||||
def model_names(self)->list[str]:
|
||||
def model_names(self)->List[Tuple(str, SDModelType)]:
|
||||
"""
|
||||
Returns a list of all the model names known.
|
||||
"""
|
||||
@ -367,16 +378,21 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
def list_models(self)->dict:
|
||||
"""
|
||||
Return a dict of models in the format:
|
||||
{ model_name1: {'status': ('active'|'cached'|'not loaded'),
|
||||
{ model_key: {'status': 'active'|'cached'|'not loaded',
|
||||
'model_name' : name,
|
||||
'model_type' : SDModelType,
|
||||
'description': description,
|
||||
'format': ('ckpt'|'diffusers'|'vae'|'text_encoder'|'tokenizer'|'lora'...),
|
||||
'format': 'folder'|'safetensors'|'ckpt'
|
||||
},
|
||||
model_name2: { etc }
|
||||
"""
|
||||
return self.mgr.list_models()
|
||||
|
||||
def add_model(
|
||||
self, model_name: str, model_attributes: dict, clobber: bool = False)->None:
|
||||
self,
|
||||
model_name: str,
|
||||
model_type: SDModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False)->None:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
@ -384,7 +400,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
return self.mgr.add_model(model_name, model_attributes, dict, clobber)
|
||||
return self.mgr.add_model(model_name, model_type, model_attributes, dict, clobber)
|
||||
|
||||
|
||||
def del_model(self,
|
||||
@ -439,7 +455,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
Creates an entry for the indicated textual inversion embedding file.
|
||||
Call commit() to write out the configuration to models.yaml
|
||||
"""
|
||||
self.mgr(path, model_name, description)
|
||||
self.mgr.import_embedding(path, model_name, description)
|
||||
|
||||
def heuristic_import(
|
||||
self,
|
||||
|
@ -227,7 +227,7 @@ class Generate:
|
||||
# don't accept invalid models
|
||||
fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
|
||||
model = model or fallback
|
||||
if not self.model_manager.valid_model(model):
|
||||
if not self.model_manager.model_exists(model):
|
||||
logger.warning(
|
||||
f'"{model}" is not a known model name; falling back to {fallback}.'
|
||||
)
|
||||
@ -877,7 +877,7 @@ class Generate:
|
||||
|
||||
# the model cache does the loading and offloading
|
||||
cache = self.model_manager
|
||||
if not cache.valid_model(model_name):
|
||||
if not cache.model_exists(model_name):
|
||||
raise KeyError(
|
||||
f'** "{model_name}" is not a known model name. Cannot change.'
|
||||
)
|
||||
|
@ -9,6 +9,7 @@ return a SDModelInfo object that contains the following attributes:
|
||||
model into GPU VRAM and returns the model for use.
|
||||
See below for usage.
|
||||
* name -- symbolic name of the model
|
||||
* type -- SDModelType of the model
|
||||
* hash -- unique hash for the model
|
||||
* location -- path or repo_id of the model
|
||||
* revision -- revision of the model if coming from a repo id,
|
||||
@ -26,7 +27,7 @@ Typical usage:
|
||||
max_cache_size=8
|
||||
) # gigabytes
|
||||
|
||||
model_info = manager.get_model('stable-diffusion-1.5')
|
||||
model_info = manager.get_model('stable-diffusion-1.5', SDModelType.diffusers)
|
||||
with model_info.context as my_model:
|
||||
my_model.latents_from_embeddings(...)
|
||||
|
||||
@ -54,15 +55,20 @@ MODELS.YAML
|
||||
|
||||
The general format of a models.yaml section is:
|
||||
|
||||
name-of-model:
|
||||
format: diffusers|ckpt|vae|text_encoder|tokenizer...
|
||||
type-of-model/name-of-model:
|
||||
format: folder|ckpt|safetensors
|
||||
repo_id: owner/repo
|
||||
path: /path/to/local/file/or/directory
|
||||
subfolder: subfolder-name
|
||||
|
||||
The format is one of {diffusers, ckpt, vae, text_encoder, tokenizer,
|
||||
unet, scheduler, safety_checker, feature_extractor}, and correspond to
|
||||
items in the SDModelType enum defined in model_cache.py
|
||||
The type of model is given in the stanza key, and is one of
|
||||
{diffusers, ckpt, vae, text_encoder, tokenizer, unet, scheduler,
|
||||
safety_checker, feature_extractor, lora, textual_inversion}, and
|
||||
correspond to items in the SDModelType enum defined in model_cache.py
|
||||
|
||||
The format indicates whether the model is organized as a folder with
|
||||
model subdirectories, or is contained in a single checkpoint or
|
||||
safetensors file.
|
||||
|
||||
One, but not both, of repo_id and path are provided. repo_id is the
|
||||
HuggingFace repository ID of the model, and path points to the file or
|
||||
@ -74,13 +80,13 @@ the main model. These are usually named after the model type, such as
|
||||
|
||||
This example summarizes the two ways of getting a non-diffuser model:
|
||||
|
||||
clip-test-1:
|
||||
format: text_encoder
|
||||
text_encoder/clip-test-1:
|
||||
format: folder
|
||||
repo_id: openai/clip-vit-large-patch14
|
||||
description: Returns standalone CLIPTextModel
|
||||
|
||||
clip-test-2:
|
||||
format: text_encoder
|
||||
text_encoder/clip-test-2:
|
||||
format: folder
|
||||
repo_id: stabilityai/stable-diffusion-2
|
||||
subfolder: text_encoder
|
||||
description: Returns the text_encoder in the subfolder of the diffusers model (just the encoder in RAM)
|
||||
@ -101,12 +107,12 @@ You may wish to use the same name for a related family of models. To
|
||||
do this, disambiguate the stanza key with the model and and format
|
||||
separated by "/". Example:
|
||||
|
||||
clip-large/tokenizer:
|
||||
tokenizer/clip-large:
|
||||
format: tokenizer
|
||||
repo_id: openai/clip-vit-large-patch14
|
||||
description: Returns standalone tokenizer
|
||||
|
||||
clip-large/text_encoder:
|
||||
text_encoder/clip-large:
|
||||
format: text_encoder
|
||||
repo_id: openai/clip-vit-large-patch14
|
||||
description: Returns standalone text encoder
|
||||
@ -128,31 +134,41 @@ from __future__ import annotations
|
||||
import os
|
||||
import re
|
||||
import textwrap
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from packaging import version
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
from typing import Union, Callable, types
|
||||
from contextlib import suppress
|
||||
from typing import Callable, Optional, List, Tuple, Union, types
|
||||
|
||||
import safetensors
|
||||
import safetensors.torch
|
||||
import torch
|
||||
import invokeai.backend.util.logging as logger
|
||||
from huggingface_hub import scan_cache_dir
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
|
||||
from invokeai.backend.globals import Globals, global_cache_dir, global_resolve_path
|
||||
from .model_cache import ModelCache, ModelLocker, SDModelType, ModelStatus, SilenceWarnings
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import (Globals, global_cache_dir,
|
||||
global_resolve_path)
|
||||
from invokeai.backend.util import download_with_resume
|
||||
|
||||
from ..util import CUDA_DEVICE
|
||||
from .model_cache import (ModelCache, ModelLocker, ModelStatus, SDModelType,
|
||||
SilenceWarnings)
|
||||
|
||||
# We are only starting to number the config file with release 3.
|
||||
# The config file version doesn't have to start at release version, but it will help
|
||||
# reduce confusion.
|
||||
CONFIG_FILE_VERSION='3.0.0'
|
||||
|
||||
# wanted to use pydantic here, but Generator objects not supported
|
||||
@dataclass
|
||||
class SDModelInfo():
|
||||
context: ModelLocker
|
||||
name: str
|
||||
type: SDModelType
|
||||
hash: str
|
||||
location: Union[Path,str]
|
||||
precision: torch.dtype
|
||||
@ -208,14 +224,17 @@ class ModelManager(object):
|
||||
type and precision are set up for a CUDA system running at half precision.
|
||||
"""
|
||||
if isinstance(config, DictConfig):
|
||||
self.config = config
|
||||
self.config_path = None
|
||||
self.config = config
|
||||
elif isinstance(config,(str,Path)):
|
||||
self.config_path = config
|
||||
self.config = OmegaConf.load(self.config_path)
|
||||
else:
|
||||
raise ValueError('config argument must be an OmegaConf object, a Path or a string')
|
||||
|
||||
# check config version number and update on disk/RAM if necessary
|
||||
self._update_config_file_version()
|
||||
|
||||
self.cache = ModelCache(
|
||||
max_cache_size=max_cache_size,
|
||||
execution_device = device_type,
|
||||
@ -226,8 +245,7 @@ class ModelManager(object):
|
||||
self.cache_keys = dict()
|
||||
self.logger = logger
|
||||
|
||||
# TODO: rename to smth like - is_model_exists
|
||||
def valid_model(
|
||||
def model_exists(
|
||||
self,
|
||||
model_name: str,
|
||||
model_type: SDModelType = SDModelType.diffusers,
|
||||
@ -246,14 +264,14 @@ class ModelManager(object):
|
||||
model_type_str, model_name = model_key.split('/', 1)
|
||||
if model_type_str not in SDModelType.__members__:
|
||||
# TODO:
|
||||
raise Exception(f"Unkown model type: {model_type_str}")
|
||||
raise Exception(f"Unknown model type: {model_type_str}")
|
||||
|
||||
return (model_name, SDModelType[model_type_str])
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
model_name: str,
|
||||
model_type: SDModelType=None,
|
||||
model_type: SDModelType=SDModelType.diffusers,
|
||||
submodel: SDModelType=None,
|
||||
) -> SDModelInfo:
|
||||
"""Given a model named identified in models.yaml, return
|
||||
@ -311,7 +329,7 @@ class ModelManager(object):
|
||||
if model_type == SDModelType.diffusers:
|
||||
# intercept stanzas that point to checkpoint weights and replace them
|
||||
# with the equivalent diffusers model
|
||||
if 'weights' in mconfig:
|
||||
if mconfig.format in ["ckpt", "diffusers"]:
|
||||
location = self.convert_ckpt_and_cache(mconfig)
|
||||
else:
|
||||
location = global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id')
|
||||
@ -350,6 +368,7 @@ class ModelManager(object):
|
||||
return SDModelInfo(
|
||||
context = model_context,
|
||||
name = model_name,
|
||||
type = submodel or model_type,
|
||||
hash = hash,
|
||||
location = location,
|
||||
revision = revision,
|
||||
@ -358,43 +377,50 @@ class ModelManager(object):
|
||||
_cache = self.cache
|
||||
)
|
||||
|
||||
def default_model(self) -> Union[str,None]:
|
||||
def default_model(self) -> Union[Tuple(str, SDModelType),None]:
|
||||
"""
|
||||
Returns the name of the default model, or None
|
||||
if none is defined.
|
||||
"""
|
||||
for model_name in self.config:
|
||||
if self.config[model_name].get("default"):
|
||||
return model_name
|
||||
return list(self.config.keys())[0] # first one
|
||||
for model_name, model_type in self.model_names():
|
||||
model_key = self.create_key(model_name, model_type)
|
||||
if self.config[model_key].get("default"):
|
||||
return (model_name, model_type)
|
||||
return self.model_names()[0][0]
|
||||
|
||||
def set_default_model(self, model_name: str) -> None:
|
||||
def set_default_model(self, model_name: str, model_type: SDModelType=SDModelType.diffusers) -> None:
|
||||
"""
|
||||
Set the default model. The change will not take
|
||||
effect until you call model_manager.commit()
|
||||
"""
|
||||
assert model_name in self.model_names(), f"unknown model '{model_name}'"
|
||||
assert self.model_exists(model_name, model_type), f"unknown model '{model_name}'"
|
||||
|
||||
config = self.config
|
||||
for model in config:
|
||||
config[model].pop("default", None)
|
||||
config[model_name]["default"] = True
|
||||
for model_name, model_type in self.model_names():
|
||||
key = self.create_key(model_name, model_type)
|
||||
config[key].pop("default", None)
|
||||
config[self.create_key(model_name, model_type)]["default"] = True
|
||||
|
||||
def model_info(self, model_name: str) -> dict:
|
||||
def model_info(
|
||||
self,
|
||||
model_name: str,
|
||||
model_type: SDModelType=SDModelType.diffusers
|
||||
) -> dict:
|
||||
"""
|
||||
Given a model name returns the OmegaConf (dict-like) object describing it.
|
||||
"""
|
||||
if model_name not in self.config:
|
||||
if not self.exists(model_name, model_type):
|
||||
return None
|
||||
return self.config[model_name]
|
||||
return self.config[self.create_key(model_name,model_type)]
|
||||
|
||||
def model_names(self) -> list[str]:
|
||||
def model_names(self) -> List[Tuple(str, SDModelType)]:
|
||||
"""
|
||||
Return a list consisting of all the names of models defined in models.yaml
|
||||
Return a list of (str, SDModelType) corresponding to all models
|
||||
known to the configuration.
|
||||
"""
|
||||
return list(self.config.keys())
|
||||
return [(self.parse_key(x)) for x in self.config.keys() if isinstance(self.config[x],DictConfig)]
|
||||
|
||||
def is_legacy(self, model_name: str) -> bool:
|
||||
def is_legacy(self, model_name: str, model_type: SDModelType.diffusers) -> bool:
|
||||
"""
|
||||
Return true if this is a legacy (.ckpt) model
|
||||
"""
|
||||
@ -402,7 +428,7 @@ class ModelManager(object):
|
||||
# there are no legacy ckpts!
|
||||
if Globals.ckpt_convert:
|
||||
return False
|
||||
info = self.model_info(model_name)
|
||||
info = self.model_info(model_name, model_type)
|
||||
if "weights" in info and info["weights"].endswith((".ckpt", ".safetensors")):
|
||||
return True
|
||||
return False
|
||||
@ -422,15 +448,17 @@ class ModelManager(object):
|
||||
# don't include VAEs in listing (legacy style)
|
||||
if "config" in stanza and "/VAE/" in stanza["config"]:
|
||||
continue
|
||||
if model_key=='config_file_version':
|
||||
continue
|
||||
|
||||
model_name, model_type = self.parse_key(model_key)
|
||||
models[model_name] = dict()
|
||||
models[model_key] = dict()
|
||||
|
||||
# TODO: return all models in future
|
||||
if model_type != SDModelType.diffusers:
|
||||
continue
|
||||
|
||||
model_format = "ckpt" if "weights" in stanza else "diffusers"
|
||||
model_format = stanza.get('format')
|
||||
|
||||
# Common Attribs
|
||||
status = self.cache.status(
|
||||
@ -439,26 +467,25 @@ class ModelManager(object):
|
||||
subfolder=stanza.get('subfolder')
|
||||
)
|
||||
description = stanza.get("description", None)
|
||||
models[model_name].update(
|
||||
description=description,
|
||||
type=model_type,
|
||||
models[model_key].update(
|
||||
model_name=model_name,
|
||||
model_type=model_type.name,
|
||||
format=model_format,
|
||||
description=description,
|
||||
status=status.value
|
||||
)
|
||||
|
||||
|
||||
# Checkpoint Config Parse
|
||||
if model_format == "ckpt":
|
||||
models[model_name].update(
|
||||
if model_format in ["ckpt","safetensors"]:
|
||||
models[model_key].update(
|
||||
config = str(stanza.get("config", None)),
|
||||
weights = str(stanza.get("weights", None)),
|
||||
vae = str(stanza.get("vae", None)),
|
||||
width = str(stanza.get("width", 512)),
|
||||
height = str(stanza.get("height", 512)),
|
||||
)
|
||||
|
||||
# Diffusers Config Parse
|
||||
elif model_format == "diffusers":
|
||||
elif model_format == "folder":
|
||||
if vae := stanza.get("vae", None):
|
||||
if isinstance(vae, DictConfig):
|
||||
vae = dict(
|
||||
@ -467,7 +494,7 @@ class ModelManager(object):
|
||||
subfolder = str(vae.get("subfolder", None)),
|
||||
)
|
||||
|
||||
models[model_name].update(
|
||||
models[model_key].update(
|
||||
vae = vae,
|
||||
repo_id = str(stanza.get("repo_id", None)),
|
||||
path = str(stanza.get("path", None)),
|
||||
@ -479,12 +506,9 @@ class ModelManager(object):
|
||||
"""
|
||||
Print a table of models, their descriptions, and load status
|
||||
"""
|
||||
models = self.list_models()
|
||||
for name in models:
|
||||
if models[name]["type"] == "vae":
|
||||
continue
|
||||
line = f'{name:25s} {models[name]["status"]:>15s} {models[name]["type"]:10s} {models[name]["description"]}'
|
||||
if models[name]["status"] == "active":
|
||||
for model_key, model_info in self.list_models().items():
|
||||
line = f'{model_info["model_name"]:25s} {model_info["status"]:>15s} {model_info["model_type"]:10s} {model_info["description"]}'
|
||||
if model_info["status"] in ["in gpu","locked in gpu"]:
|
||||
line = f"\033[1m{line}\033[0m"
|
||||
print(line)
|
||||
|
||||
@ -511,9 +535,9 @@ class ModelManager(object):
|
||||
# self.stack.remove(model_name)
|
||||
|
||||
if delete_files:
|
||||
repo_id = conf.get("repo_id", None)
|
||||
path = self._abs_path(conf.get("path", None))
|
||||
weights = self._abs_path(conf.get("weights", None))
|
||||
repo_id = model_cfg.get("repo_id", None)
|
||||
path = self._abs_path(model_cfg.get("path", None))
|
||||
weights = self._abs_path(model_cfg.get("weights", None))
|
||||
if "weights" in model_cfg:
|
||||
weights = self._abs_path(model_cfg["weights"])
|
||||
self.logger.info(f"Deleting file {weights}")
|
||||
@ -558,7 +582,7 @@ class ModelManager(object):
|
||||
), 'model must have either the "path" or "repo_id" fields defined'
|
||||
|
||||
elif model_format == "ckpt":
|
||||
for field in ("description", "weights", "height", "width", "config"):
|
||||
for field in ("description", "weights", "config"):
|
||||
assert field in model_attributes, f"required field {field} is missing"
|
||||
|
||||
else:
|
||||
@ -579,11 +603,6 @@ class ModelManager(object):
|
||||
self.cache.uncache_model(self.cache_keys[model_key])
|
||||
del self.cache_keys[model_key]
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def import_diffuser_model(
|
||||
self,
|
||||
repo_or_path: Union[str, Path],
|
||||
@ -604,7 +623,6 @@ class ModelManager(object):
|
||||
models.yaml file.
|
||||
"""
|
||||
model_name = model_name or Path(repo_or_path).stem
|
||||
model_key = f'{model_name}/diffusers'
|
||||
model_description = description or f"Imported diffusers model {model_name}"
|
||||
new_config = dict(
|
||||
description=model_description,
|
||||
@ -616,10 +634,10 @@ class ModelManager(object):
|
||||
else:
|
||||
new_config.update(repo_id=repo_or_path)
|
||||
|
||||
self.add_model(model_key, new_config, True)
|
||||
self.add_model(model_name, SDModelType.diffusers, new_config, True)
|
||||
if commit_to_conf:
|
||||
self.commit(commit_to_conf)
|
||||
return model_key
|
||||
return self.create_key(model_name, SDModelType.diffusers)
|
||||
|
||||
def import_lora(
|
||||
self,
|
||||
@ -635,7 +653,8 @@ class ModelManager(object):
|
||||
model_name = model_name or path.stem
|
||||
model_description = description or f"LoRA model {model_name}"
|
||||
self.add_model(
|
||||
f'{model_name}/{SDModelType.lora.name}',
|
||||
model_name,
|
||||
SDModelType.lora,
|
||||
dict(
|
||||
format="lora",
|
||||
weights=str(path),
|
||||
@ -663,7 +682,8 @@ class ModelManager(object):
|
||||
model_name = model_name or path.stem
|
||||
model_description = description or f"Textual embedding model {model_name}"
|
||||
self.add_model(
|
||||
f'{model_name}/{SDModelType.textual_inversion.name}',
|
||||
model_name,
|
||||
SDModelType.textual_inversion,
|
||||
dict(
|
||||
format="textual_inversion",
|
||||
weights=str(weights),
|
||||
@ -1025,9 +1045,14 @@ class ModelManager(object):
|
||||
description=model_description,
|
||||
format="diffusers",
|
||||
)
|
||||
if model_name in self.config:
|
||||
self.del_model(model_name)
|
||||
self.add_model(model_name, new_config, True)
|
||||
if self.model_exists(model_name, SDModelType.diffusers):
|
||||
self.del_model(model_name, SDModelType.diffusers)
|
||||
self.add_model(
|
||||
model_name,
|
||||
SDModelType.diffusers,
|
||||
new_config,
|
||||
True
|
||||
)
|
||||
if commit_to_conf:
|
||||
self.commit(commit_to_conf)
|
||||
self.logger.debug("Conversion succeeded")
|
||||
@ -1081,11 +1106,6 @@ class ModelManager(object):
|
||||
"""\
|
||||
# This file describes the alternative machine learning models
|
||||
# available to InvokeAI script.
|
||||
#
|
||||
# To add a new model, follow the examples below. Each
|
||||
# model requires a model config file, a weights file,
|
||||
# and the width and height of the images it
|
||||
# was trained on.
|
||||
"""
|
||||
)
|
||||
|
||||
@ -1130,3 +1150,53 @@ class ModelManager(object):
|
||||
source = os.path.join(Globals.root, source)
|
||||
resolved_path = Path(source)
|
||||
return resolved_path
|
||||
|
||||
def _update_config_file_version(self):
|
||||
"""
|
||||
This gets called at object init time and will update
|
||||
from older versions of the config file to new ones
|
||||
as necessary.
|
||||
"""
|
||||
current_version = self.config.get("config_file_version","1.0.0")
|
||||
if version.parse(current_version) < version.parse(CONFIG_FILE_VERSION):
|
||||
self.logger.info(f'models.yaml version {current_version} detected. Updating to {CONFIG_FILE_VERSION}')
|
||||
|
||||
new_config = OmegaConf.create()
|
||||
new_config["config_file_version"] = CONFIG_FILE_VERSION
|
||||
|
||||
for model_key in self.config:
|
||||
|
||||
old_stanza = self.config[model_key]
|
||||
|
||||
# ignore old and ugly way of associating a legacy
|
||||
# vae with a legacy checkpont model
|
||||
if old_stanza.get("config") and '/VAE/' in old_stanza.get("config"):
|
||||
continue
|
||||
|
||||
# bare keys are updated to be prefixed with 'diffusers/'
|
||||
if '/' not in model_key:
|
||||
new_key = f'diffusers/{model_key}'
|
||||
else:
|
||||
new_key = model_key
|
||||
|
||||
if old_stanza.get('format')=='diffusers':
|
||||
model_format = 'folder'
|
||||
elif old_stanza.get('weights') and Path(old_stanza.get('weights')).suffix == '.ckpt':
|
||||
model_format = 'ckpt'
|
||||
elif old_stanza.get('weights') and Path(old_stanza.get('weights')).suffix == '.safetensors':
|
||||
model_format = 'safetensors'
|
||||
|
||||
# copy fields over manually rather than doing a copy() or deepcopy()
|
||||
# in order to avoid bringing in unwanted fields.
|
||||
new_config[new_key] = dict(
|
||||
description = old_stanza.get('description'),
|
||||
format = model_format,
|
||||
)
|
||||
for field in ["repo_id", "path", "weights", "config", "vae"]:
|
||||
if field_value := old_stanza.get(field):
|
||||
new_config[new_key].update({field: field_value})
|
||||
|
||||
self.config = new_config
|
||||
if self.config_path:
|
||||
self.commit()
|
||||
|
||||
|
@ -345,13 +345,14 @@ class Completer(object):
|
||||
partial = text
|
||||
matches = list()
|
||||
for s in self.models:
|
||||
name = self.models[s]["model_name"]
|
||||
format = self.models[s]["format"]
|
||||
if format == "vae":
|
||||
continue
|
||||
if ckpt_only and format != "ckpt":
|
||||
continue
|
||||
if s.startswith(partial):
|
||||
matches.append(switch + s)
|
||||
if name.startswith(partial):
|
||||
matches.append(switch + name)
|
||||
matches.sort()
|
||||
return matches
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user