convert add_model(), del_model(), list_models() etc to use bifurcated names

This commit is contained in:
Lincoln Stein 2023-05-13 14:44:44 -04:00
parent bc96727cbe
commit 72967bf118
6 changed files with 217 additions and 128 deletions

View File

@ -1,12 +1,9 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername) # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
import shutil from typing import Annotated, Literal, Optional, Union
import asyncio
from typing import Annotated, Any, List, Literal, Optional, Union
from fastapi.routing import APIRouter, HTTPException from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as from pydantic import BaseModel, Field, parse_obj_as
from pathlib import Path
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"]) models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -19,6 +16,15 @@ class VaeRepo(BaseModel):
class ModelInfo(BaseModel): class ModelInfo(BaseModel):
description: Optional[str] = Field(description="A description of the model") description: Optional[str] = Field(description="A description of the model")
model_name: str = Field(description="The name of the model")
model_type: str = Field(description="The type of the model")
class DiffusersModelInfo(ModelInfo):
format: Literal['folder'] = 'folder'
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
path: Optional[str] = Field(description="The path to the model")
class CkptModelInfo(ModelInfo): class CkptModelInfo(ModelInfo):
format: Literal['ckpt'] = 'ckpt' format: Literal['ckpt'] = 'ckpt'
@ -29,12 +35,8 @@ class CkptModelInfo(ModelInfo):
width: Optional[int] = Field(description="The width of the model") width: Optional[int] = Field(description="The width of the model")
height: Optional[int] = Field(description="The height of the model") height: Optional[int] = Field(description="The height of the model")
class DiffusersModelInfo(ModelInfo): class SafetensorsModelInfo(CkptModelInfo):
format: Literal['diffusers'] = 'diffusers' format: Literal['safetensors'] = 'safetensors'
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
path: Optional[str] = Field(description="The path to the model")
class CreateModelRequest(BaseModel): class CreateModelRequest(BaseModel):
name: str = Field(description="The name of the model") name: str = Field(description="The name of the model")
@ -56,7 +58,7 @@ class ConvertedModelResponse(BaseModel):
info: DiffusersModelInfo = Field(description="The converted model info") info: DiffusersModelInfo = Field(description="The converted model info")
class ModelsList(BaseModel): class ModelsList(BaseModel):
models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]] models: dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]
@models_router.get( @models_router.get(
@ -121,7 +123,7 @@ async def delete_model(model_name: str) -> None:
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully") raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
else: else:
logger.error(f"Model not found") logger.error("Model not found")
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")

View File

@ -60,14 +60,14 @@ class ModelLoaderInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ModelLoaderOutput: def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
# TODO: not found exceptions # TODO: not found exceptions
if not context.services.model_manager.valid_model( if not context.services.model_manager.model_exists(
model_name=self.model_name, model_name=self.model_name,
model_type=SDModelType.diffusers, model_type=SDModelType.diffusers,
): ):
raise Exception(f"Unkown model name: {self.model_name}!") raise Exception(f"Unkown model name: {self.model_name}!")
""" """
if not context.services.model_manager.valid_model( if not context.services.model_manager.model_exists(
model_name=self.model_name, model_name=self.model_name,
model_type=SDModelType.diffusers, model_type=SDModelType.diffusers,
submodel=SDModelType.tokenizer, submodel=SDModelType.tokenizer,
@ -76,7 +76,7 @@ class ModelLoaderInvocation(BaseInvocation):
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted" f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
) )
if not context.services.model_manager.valid_model( if not context.services.model_manager.model_exists(
model_name=self.model_name, model_name=self.model_name,
model_type=SDModelType.diffusers, model_type=SDModelType.diffusers,
submodel=SDModelType.text_encoder, submodel=SDModelType.text_encoder,
@ -85,7 +85,7 @@ class ModelLoaderInvocation(BaseInvocation):
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted" f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
) )
if not context.services.model_manager.valid_model( if not context.services.model_manager.model_exists(
model_name=self.model_name, model_name=self.model_name,
model_type=SDModelType.diffusers, model_type=SDModelType.diffusers,
submodel=SDModelType.unet, submodel=SDModelType.unet,

View File

@ -59,35 +59,35 @@ class ModelManagerServiceBase(ABC):
pass pass
@abstractmethod @abstractmethod
def valid_model(self, model_name: str) -> bool: def model_exists(
""" self,
Given a model name, returns True if it is a valid model_name: str,
identifier. model_type: SDModelType
""" ) -> bool:
pass pass
@abstractmethod @abstractmethod
def default_model(self) -> Union[str,None]: def default_model(self) -> Union[Tuple(str, SDModelType),None],
""" """
Returns the name of the default model, or None Returns the name and typeof the default model, or None
if none is defined. if none is defined.
""" """
pass pass
@abstractmethod @abstractmethod
def set_default_model(self, model_name:str): def set_default_model(self, model_name: str, model_type: SDModelType):
"""Sets the default model to the indicated name.""" """Sets the default model to the indicated name."""
pass pass
@abstractmethod @abstractmethod
def model_info(self, model_name: str)->dict: def model_info(self, model_name: str, model_type: SDModelType)->dict:
""" """
Given a model name returns a dict-like (OmegaConf) object describing it. Given a model name returns a dict-like (OmegaConf) object describing it.
""" """
pass pass
@abstractmethod @abstractmethod
def model_names(self)->list[str]: def model_names(self)->List[Tuple(str, SDModelType)]:
""" """
Returns a list of all the model names known. Returns a list of all the model names known.
""" """
@ -97,18 +97,25 @@ class ModelManagerServiceBase(ABC):
def list_models(self)->dict: def list_models(self)->dict:
""" """
Return a dict of models in the format: Return a dict of models in the format:
{ model_name1: {'status': ('active'|'cached'|'not loaded'), { model_key1: {'status': 'active'|'cached'|'not loaded',
'description': description, 'model_name' : name,
'format': ('ckpt'|'diffusers'|'vae'|'text_encoder'|'tokenizer'|'lora'...), 'model_type' : SDModelType,
'description': description,
'format': 'folder'|'safetensors'|'ckpt'
}, },
model_name2: { etc } model_key2: { etc }
""" """
pass pass
@abstractmethod @abstractmethod
def add_model( def add_model(
self, model_name: str, model_attributes: dict, clobber: bool = False)->None: self,
model_name: str,
model_type: SDModelType,
model_attributes: dict,
clobber: bool = False
)->None:
""" """
Update the named model with a dictionary of attributes. Will fail with an Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite. assertion error if the name already exists. Pass clobber=True to overwrite.
@ -121,7 +128,7 @@ class ModelManagerServiceBase(ABC):
@abstractmethod @abstractmethod
def del_model(self, def del_model(self,
model_name: str, model_name: str,
model_type: SDModelType=SDModelType.diffusers, model_type: SDModelType,
delete_files: bool = False): delete_files: bool = False):
""" """
Delete the named model from configuration. If delete_files is true, Delete the named model from configuration. If delete_files is true,
@ -332,33 +339,37 @@ class ModelManagerService(ModelManagerServiceBase):
return model_info return model_info
def valid_model(self, model_name: str, model_type: SDModelType=SDModelType.diffusers) -> bool: def model_exists(
self,
model_name: str,
model_type: SDModelType
) -> bool:
""" """
Given a model name, returns True if it is a valid Given a model name, returns True if it is a valid
identifier. identifier.
""" """
return self.mgr.valid_model( return self.mgr.model_exists(
model_name, model_name,
model_type) model_type)
def default_model(self) -> Union[str,None]: def default_model(self) -> Union[Tuple(str, SDModelType),None],
""" """
Returns the name of the default model, or None Returns the name of the default model, or None
if none is defined. if none is defined.
""" """
return self.mgr.default_model() return self.mgr.default_model()
def set_default_model(self, model_name:str): def set_default_model(self, model_name:str, model_type: SDModelType):
"""Sets the default model to the indicated name.""" """Sets the default model to the indicated name."""
self.mgr.set_default_model(model_name) self.mgr.set_default_model(model_name)
def model_info(self, model_name: str)->dict: def model_info(self, model_name: str, model_type: SDModelType)->dict:
""" """
Given a model name returns a dict-like (OmegaConf) object describing it. Given a model name returns a dict-like (OmegaConf) object describing it.
""" """
return self.mgr.model_info(model_name) return self.mgr.model_info(model_name)
def model_names(self)->list[str]: def model_names(self)->List[Tuple(str, SDModelType)]:
""" """
Returns a list of all the model names known. Returns a list of all the model names known.
""" """
@ -367,16 +378,21 @@ class ModelManagerService(ModelManagerServiceBase):
def list_models(self)->dict: def list_models(self)->dict:
""" """
Return a dict of models in the format: Return a dict of models in the format:
{ model_name1: {'status': ('active'|'cached'|'not loaded'), { model_key: {'status': 'active'|'cached'|'not loaded',
'model_name' : name,
'model_type' : SDModelType,
'description': description, 'description': description,
'format': ('ckpt'|'diffusers'|'vae'|'text_encoder'|'tokenizer'|'lora'...), 'format': 'folder'|'safetensors'|'ckpt'
}, },
model_name2: { etc }
""" """
return self.mgr.list_models() return self.mgr.list_models()
def add_model( def add_model(
self, model_name: str, model_attributes: dict, clobber: bool = False)->None: self,
model_name: str,
model_type: SDModelType,
model_attributes: dict,
clobber: bool = False)->None:
""" """
Update the named model with a dictionary of attributes. Will fail with an Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite. assertion error if the name already exists. Pass clobber=True to overwrite.
@ -384,7 +400,7 @@ class ModelManagerService(ModelManagerServiceBase):
with an assertion error if provided attributes are incorrect or with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk. the model name is missing. Call commit() to write changes to disk.
""" """
return self.mgr.add_model(model_name, model_attributes, dict, clobber) return self.mgr.add_model(model_name, model_type, model_attributes, dict, clobber)
def del_model(self, def del_model(self,
@ -439,7 +455,7 @@ class ModelManagerService(ModelManagerServiceBase):
Creates an entry for the indicated textual inversion embedding file. Creates an entry for the indicated textual inversion embedding file.
Call commit() to write out the configuration to models.yaml Call commit() to write out the configuration to models.yaml
""" """
self.mgr(path, model_name, description) self.mgr.import_embedding(path, model_name, description)
def heuristic_import( def heuristic_import(
self, self,

View File

@ -227,7 +227,7 @@ class Generate:
# don't accept invalid models # don't accept invalid models
fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
model = model or fallback model = model or fallback
if not self.model_manager.valid_model(model): if not self.model_manager.model_exists(model):
logger.warning( logger.warning(
f'"{model}" is not a known model name; falling back to {fallback}.' f'"{model}" is not a known model name; falling back to {fallback}.'
) )
@ -877,7 +877,7 @@ class Generate:
# the model cache does the loading and offloading # the model cache does the loading and offloading
cache = self.model_manager cache = self.model_manager
if not cache.valid_model(model_name): if not cache.model_exists(model_name):
raise KeyError( raise KeyError(
f'** "{model_name}" is not a known model name. Cannot change.' f'** "{model_name}" is not a known model name. Cannot change.'
) )

View File

@ -9,6 +9,7 @@ return a SDModelInfo object that contains the following attributes:
model into GPU VRAM and returns the model for use. model into GPU VRAM and returns the model for use.
See below for usage. See below for usage.
* name -- symbolic name of the model * name -- symbolic name of the model
* type -- SDModelType of the model
* hash -- unique hash for the model * hash -- unique hash for the model
* location -- path or repo_id of the model * location -- path or repo_id of the model
* revision -- revision of the model if coming from a repo id, * revision -- revision of the model if coming from a repo id,
@ -26,7 +27,7 @@ Typical usage:
max_cache_size=8 max_cache_size=8
) # gigabytes ) # gigabytes
model_info = manager.get_model('stable-diffusion-1.5') model_info = manager.get_model('stable-diffusion-1.5', SDModelType.diffusers)
with model_info.context as my_model: with model_info.context as my_model:
my_model.latents_from_embeddings(...) my_model.latents_from_embeddings(...)
@ -54,15 +55,20 @@ MODELS.YAML
The general format of a models.yaml section is: The general format of a models.yaml section is:
name-of-model: type-of-model/name-of-model:
format: diffusers|ckpt|vae|text_encoder|tokenizer... format: folder|ckpt|safetensors
repo_id: owner/repo repo_id: owner/repo
path: /path/to/local/file/or/directory path: /path/to/local/file/or/directory
subfolder: subfolder-name subfolder: subfolder-name
The format is one of {diffusers, ckpt, vae, text_encoder, tokenizer, The type of model is given in the stanza key, and is one of
unet, scheduler, safety_checker, feature_extractor}, and correspond to {diffusers, ckpt, vae, text_encoder, tokenizer, unet, scheduler,
items in the SDModelType enum defined in model_cache.py safety_checker, feature_extractor, lora, textual_inversion}, and
correspond to items in the SDModelType enum defined in model_cache.py
The format indicates whether the model is organized as a folder with
model subdirectories, or is contained in a single checkpoint or
safetensors file.
One, but not both, of repo_id and path are provided. repo_id is the One, but not both, of repo_id and path are provided. repo_id is the
HuggingFace repository ID of the model, and path points to the file or HuggingFace repository ID of the model, and path points to the file or
@ -74,13 +80,13 @@ the main model. These are usually named after the model type, such as
This example summarizes the two ways of getting a non-diffuser model: This example summarizes the two ways of getting a non-diffuser model:
clip-test-1: text_encoder/clip-test-1:
format: text_encoder format: folder
repo_id: openai/clip-vit-large-patch14 repo_id: openai/clip-vit-large-patch14
description: Returns standalone CLIPTextModel description: Returns standalone CLIPTextModel
clip-test-2: text_encoder/clip-test-2:
format: text_encoder format: folder
repo_id: stabilityai/stable-diffusion-2 repo_id: stabilityai/stable-diffusion-2
subfolder: text_encoder subfolder: text_encoder
description: Returns the text_encoder in the subfolder of the diffusers model (just the encoder in RAM) description: Returns the text_encoder in the subfolder of the diffusers model (just the encoder in RAM)
@ -101,12 +107,12 @@ You may wish to use the same name for a related family of models. To
do this, disambiguate the stanza key with the model and and format do this, disambiguate the stanza key with the model and and format
separated by "/". Example: separated by "/". Example:
clip-large/tokenizer: tokenizer/clip-large:
format: tokenizer format: tokenizer
repo_id: openai/clip-vit-large-patch14 repo_id: openai/clip-vit-large-patch14
description: Returns standalone tokenizer description: Returns standalone tokenizer
clip-large/text_encoder: text_encoder/clip-large:
format: text_encoder format: text_encoder
repo_id: openai/clip-vit-large-patch14 repo_id: openai/clip-vit-large-patch14
description: Returns standalone text encoder description: Returns standalone text encoder
@ -128,31 +134,41 @@ from __future__ import annotations
import os import os
import re import re
import textwrap import textwrap
from contextlib import suppress
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto from enum import Enum, auto
from packaging import version
from pathlib import Path from pathlib import Path
from shutil import rmtree from shutil import rmtree
from typing import Union, Callable, types from typing import Callable, Optional, List, Tuple, Union, types
from contextlib import suppress
import safetensors import safetensors
import safetensors.torch import safetensors.torch
import torch import torch
import invokeai.backend.util.logging as logger
from huggingface_hub import scan_cache_dir from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig from omegaconf.dictconfig import DictConfig
from invokeai.backend.globals import Globals, global_cache_dir, global_resolve_path import invokeai.backend.util.logging as logger
from .model_cache import ModelCache, ModelLocker, SDModelType, ModelStatus, SilenceWarnings from invokeai.backend.globals import (Globals, global_cache_dir,
global_resolve_path)
from invokeai.backend.util import download_with_resume
from ..util import CUDA_DEVICE from ..util import CUDA_DEVICE
from .model_cache import (ModelCache, ModelLocker, ModelStatus, SDModelType,
SilenceWarnings)
# We are only starting to number the config file with release 3.
# The config file version doesn't have to start at release version, but it will help
# reduce confusion.
CONFIG_FILE_VERSION='3.0.0'
# wanted to use pydantic here, but Generator objects not supported # wanted to use pydantic here, but Generator objects not supported
@dataclass @dataclass
class SDModelInfo(): class SDModelInfo():
context: ModelLocker context: ModelLocker
name: str name: str
type: SDModelType
hash: str hash: str
location: Union[Path,str] location: Union[Path,str]
precision: torch.dtype precision: torch.dtype
@ -208,14 +224,17 @@ class ModelManager(object):
type and precision are set up for a CUDA system running at half precision. type and precision are set up for a CUDA system running at half precision.
""" """
if isinstance(config, DictConfig): if isinstance(config, DictConfig):
self.config = config
self.config_path = None self.config_path = None
self.config = config
elif isinstance(config,(str,Path)): elif isinstance(config,(str,Path)):
self.config_path = config self.config_path = config
self.config = OmegaConf.load(self.config_path) self.config = OmegaConf.load(self.config_path)
else: else:
raise ValueError('config argument must be an OmegaConf object, a Path or a string') raise ValueError('config argument must be an OmegaConf object, a Path or a string')
# check config version number and update on disk/RAM if necessary
self._update_config_file_version()
self.cache = ModelCache( self.cache = ModelCache(
max_cache_size=max_cache_size, max_cache_size=max_cache_size,
execution_device = device_type, execution_device = device_type,
@ -226,8 +245,7 @@ class ModelManager(object):
self.cache_keys = dict() self.cache_keys = dict()
self.logger = logger self.logger = logger
# TODO: rename to smth like - is_model_exists def model_exists(
def valid_model(
self, self,
model_name: str, model_name: str,
model_type: SDModelType = SDModelType.diffusers, model_type: SDModelType = SDModelType.diffusers,
@ -246,14 +264,14 @@ class ModelManager(object):
model_type_str, model_name = model_key.split('/', 1) model_type_str, model_name = model_key.split('/', 1)
if model_type_str not in SDModelType.__members__: if model_type_str not in SDModelType.__members__:
# TODO: # TODO:
raise Exception(f"Unkown model type: {model_type_str}") raise Exception(f"Unknown model type: {model_type_str}")
return (model_name, SDModelType[model_type_str]) return (model_name, SDModelType[model_type_str])
def get_model( def get_model(
self, self,
model_name: str, model_name: str,
model_type: SDModelType=None, model_type: SDModelType=SDModelType.diffusers,
submodel: SDModelType=None, submodel: SDModelType=None,
) -> SDModelInfo: ) -> SDModelInfo:
"""Given a model named identified in models.yaml, return """Given a model named identified in models.yaml, return
@ -311,7 +329,7 @@ class ModelManager(object):
if model_type == SDModelType.diffusers: if model_type == SDModelType.diffusers:
# intercept stanzas that point to checkpoint weights and replace them # intercept stanzas that point to checkpoint weights and replace them
# with the equivalent diffusers model # with the equivalent diffusers model
if 'weights' in mconfig: if mconfig.format in ["ckpt", "diffusers"]:
location = self.convert_ckpt_and_cache(mconfig) location = self.convert_ckpt_and_cache(mconfig)
else: else:
location = global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id') location = global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id')
@ -350,6 +368,7 @@ class ModelManager(object):
return SDModelInfo( return SDModelInfo(
context = model_context, context = model_context,
name = model_name, name = model_name,
type = submodel or model_type,
hash = hash, hash = hash,
location = location, location = location,
revision = revision, revision = revision,
@ -358,43 +377,50 @@ class ModelManager(object):
_cache = self.cache _cache = self.cache
) )
def default_model(self) -> Union[str,None]: def default_model(self) -> Union[Tuple(str, SDModelType),None]:
""" """
Returns the name of the default model, or None Returns the name of the default model, or None
if none is defined. if none is defined.
""" """
for model_name in self.config: for model_name, model_type in self.model_names():
if self.config[model_name].get("default"): model_key = self.create_key(model_name, model_type)
return model_name if self.config[model_key].get("default"):
return list(self.config.keys())[0] # first one return (model_name, model_type)
return self.model_names()[0][0]
def set_default_model(self, model_name: str) -> None: def set_default_model(self, model_name: str, model_type: SDModelType=SDModelType.diffusers) -> None:
""" """
Set the default model. The change will not take Set the default model. The change will not take
effect until you call model_manager.commit() effect until you call model_manager.commit()
""" """
assert model_name in self.model_names(), f"unknown model '{model_name}'" assert self.model_exists(model_name, model_type), f"unknown model '{model_name}'"
config = self.config config = self.config
for model in config: for model_name, model_type in self.model_names():
config[model].pop("default", None) key = self.create_key(model_name, model_type)
config[model_name]["default"] = True config[key].pop("default", None)
config[self.create_key(model_name, model_type)]["default"] = True
def model_info(self, model_name: str) -> dict: def model_info(
self,
model_name: str,
model_type: SDModelType=SDModelType.diffusers
) -> dict:
""" """
Given a model name returns the OmegaConf (dict-like) object describing it. Given a model name returns the OmegaConf (dict-like) object describing it.
""" """
if model_name not in self.config: if not self.exists(model_name, model_type):
return None return None
return self.config[model_name] return self.config[self.create_key(model_name,model_type)]
def model_names(self) -> list[str]: def model_names(self) -> List[Tuple(str, SDModelType)]:
""" """
Return a list consisting of all the names of models defined in models.yaml Return a list of (str, SDModelType) corresponding to all models
known to the configuration.
""" """
return list(self.config.keys()) return [(self.parse_key(x)) for x in self.config.keys() if isinstance(self.config[x],DictConfig)]
def is_legacy(self, model_name: str) -> bool: def is_legacy(self, model_name: str, model_type: SDModelType.diffusers) -> bool:
""" """
Return true if this is a legacy (.ckpt) model Return true if this is a legacy (.ckpt) model
""" """
@ -402,7 +428,7 @@ class ModelManager(object):
# there are no legacy ckpts! # there are no legacy ckpts!
if Globals.ckpt_convert: if Globals.ckpt_convert:
return False return False
info = self.model_info(model_name) info = self.model_info(model_name, model_type)
if "weights" in info and info["weights"].endswith((".ckpt", ".safetensors")): if "weights" in info and info["weights"].endswith((".ckpt", ".safetensors")):
return True return True
return False return False
@ -422,15 +448,17 @@ class ModelManager(object):
# don't include VAEs in listing (legacy style) # don't include VAEs in listing (legacy style)
if "config" in stanza and "/VAE/" in stanza["config"]: if "config" in stanza and "/VAE/" in stanza["config"]:
continue continue
if model_key=='config_file_version':
continue
model_name, model_type = self.parse_key(model_key) model_name, model_type = self.parse_key(model_key)
models[model_name] = dict() models[model_key] = dict()
# TODO: return all models in future # TODO: return all models in future
if model_type != SDModelType.diffusers: if model_type != SDModelType.diffusers:
continue continue
model_format = "ckpt" if "weights" in stanza else "diffusers" model_format = stanza.get('format')
# Common Attribs # Common Attribs
status = self.cache.status( status = self.cache.status(
@ -439,26 +467,25 @@ class ModelManager(object):
subfolder=stanza.get('subfolder') subfolder=stanza.get('subfolder')
) )
description = stanza.get("description", None) description = stanza.get("description", None)
models[model_name].update( models[model_key].update(
description=description, model_name=model_name,
type=model_type, model_type=model_type.name,
format=model_format, format=model_format,
description=description,
status=status.value status=status.value
) )
# Checkpoint Config Parse # Checkpoint Config Parse
if model_format == "ckpt": if model_format in ["ckpt","safetensors"]:
models[model_name].update( models[model_key].update(
config = str(stanza.get("config", None)), config = str(stanza.get("config", None)),
weights = str(stanza.get("weights", None)), weights = str(stanza.get("weights", None)),
vae = str(stanza.get("vae", None)), vae = str(stanza.get("vae", None)),
width = str(stanza.get("width", 512)),
height = str(stanza.get("height", 512)),
) )
# Diffusers Config Parse # Diffusers Config Parse
elif model_format == "diffusers": elif model_format == "folder":
if vae := stanza.get("vae", None): if vae := stanza.get("vae", None):
if isinstance(vae, DictConfig): if isinstance(vae, DictConfig):
vae = dict( vae = dict(
@ -467,7 +494,7 @@ class ModelManager(object):
subfolder = str(vae.get("subfolder", None)), subfolder = str(vae.get("subfolder", None)),
) )
models[model_name].update( models[model_key].update(
vae = vae, vae = vae,
repo_id = str(stanza.get("repo_id", None)), repo_id = str(stanza.get("repo_id", None)),
path = str(stanza.get("path", None)), path = str(stanza.get("path", None)),
@ -479,12 +506,9 @@ class ModelManager(object):
""" """
Print a table of models, their descriptions, and load status Print a table of models, their descriptions, and load status
""" """
models = self.list_models() for model_key, model_info in self.list_models().items():
for name in models: line = f'{model_info["model_name"]:25s} {model_info["status"]:>15s} {model_info["model_type"]:10s} {model_info["description"]}'
if models[name]["type"] == "vae": if model_info["status"] in ["in gpu","locked in gpu"]:
continue
line = f'{name:25s} {models[name]["status"]:>15s} {models[name]["type"]:10s} {models[name]["description"]}'
if models[name]["status"] == "active":
line = f"\033[1m{line}\033[0m" line = f"\033[1m{line}\033[0m"
print(line) print(line)
@ -511,9 +535,9 @@ class ModelManager(object):
# self.stack.remove(model_name) # self.stack.remove(model_name)
if delete_files: if delete_files:
repo_id = conf.get("repo_id", None) repo_id = model_cfg.get("repo_id", None)
path = self._abs_path(conf.get("path", None)) path = self._abs_path(model_cfg.get("path", None))
weights = self._abs_path(conf.get("weights", None)) weights = self._abs_path(model_cfg.get("weights", None))
if "weights" in model_cfg: if "weights" in model_cfg:
weights = self._abs_path(model_cfg["weights"]) weights = self._abs_path(model_cfg["weights"])
self.logger.info(f"Deleting file {weights}") self.logger.info(f"Deleting file {weights}")
@ -558,7 +582,7 @@ class ModelManager(object):
), 'model must have either the "path" or "repo_id" fields defined' ), 'model must have either the "path" or "repo_id" fields defined'
elif model_format == "ckpt": elif model_format == "ckpt":
for field in ("description", "weights", "height", "width", "config"): for field in ("description", "weights", "config"):
assert field in model_attributes, f"required field {field} is missing" assert field in model_attributes, f"required field {field} is missing"
else: else:
@ -579,11 +603,6 @@ class ModelManager(object):
self.cache.uncache_model(self.cache_keys[model_key]) self.cache.uncache_model(self.cache_keys[model_key])
del self.cache_keys[model_key] del self.cache_keys[model_key]
def import_diffuser_model( def import_diffuser_model(
self, self,
repo_or_path: Union[str, Path], repo_or_path: Union[str, Path],
@ -604,7 +623,6 @@ class ModelManager(object):
models.yaml file. models.yaml file.
""" """
model_name = model_name or Path(repo_or_path).stem model_name = model_name or Path(repo_or_path).stem
model_key = f'{model_name}/diffusers'
model_description = description or f"Imported diffusers model {model_name}" model_description = description or f"Imported diffusers model {model_name}"
new_config = dict( new_config = dict(
description=model_description, description=model_description,
@ -616,10 +634,10 @@ class ModelManager(object):
else: else:
new_config.update(repo_id=repo_or_path) new_config.update(repo_id=repo_or_path)
self.add_model(model_key, new_config, True) self.add_model(model_name, SDModelType.diffusers, new_config, True)
if commit_to_conf: if commit_to_conf:
self.commit(commit_to_conf) self.commit(commit_to_conf)
return model_key return self.create_key(model_name, SDModelType.diffusers)
def import_lora( def import_lora(
self, self,
@ -635,7 +653,8 @@ class ModelManager(object):
model_name = model_name or path.stem model_name = model_name or path.stem
model_description = description or f"LoRA model {model_name}" model_description = description or f"LoRA model {model_name}"
self.add_model( self.add_model(
f'{model_name}/{SDModelType.lora.name}', model_name,
SDModelType.lora,
dict( dict(
format="lora", format="lora",
weights=str(path), weights=str(path),
@ -663,7 +682,8 @@ class ModelManager(object):
model_name = model_name or path.stem model_name = model_name or path.stem
model_description = description or f"Textual embedding model {model_name}" model_description = description or f"Textual embedding model {model_name}"
self.add_model( self.add_model(
f'{model_name}/{SDModelType.textual_inversion.name}', model_name,
SDModelType.textual_inversion,
dict( dict(
format="textual_inversion", format="textual_inversion",
weights=str(weights), weights=str(weights),
@ -1025,9 +1045,14 @@ class ModelManager(object):
description=model_description, description=model_description,
format="diffusers", format="diffusers",
) )
if model_name in self.config: if self.model_exists(model_name, SDModelType.diffusers):
self.del_model(model_name) self.del_model(model_name, SDModelType.diffusers)
self.add_model(model_name, new_config, True) self.add_model(
model_name,
SDModelType.diffusers,
new_config,
True
)
if commit_to_conf: if commit_to_conf:
self.commit(commit_to_conf) self.commit(commit_to_conf)
self.logger.debug("Conversion succeeded") self.logger.debug("Conversion succeeded")
@ -1081,11 +1106,6 @@ class ModelManager(object):
"""\ """\
# This file describes the alternative machine learning models # This file describes the alternative machine learning models
# available to InvokeAI script. # available to InvokeAI script.
#
# To add a new model, follow the examples below. Each
# model requires a model config file, a weights file,
# and the width and height of the images it
# was trained on.
""" """
) )
@ -1130,3 +1150,53 @@ class ModelManager(object):
source = os.path.join(Globals.root, source) source = os.path.join(Globals.root, source)
resolved_path = Path(source) resolved_path = Path(source)
return resolved_path return resolved_path
def _update_config_file_version(self):
"""
This gets called at object init time and will update
from older versions of the config file to new ones
as necessary.
"""
current_version = self.config.get("config_file_version","1.0.0")
if version.parse(current_version) < version.parse(CONFIG_FILE_VERSION):
self.logger.info(f'models.yaml version {current_version} detected. Updating to {CONFIG_FILE_VERSION}')
new_config = OmegaConf.create()
new_config["config_file_version"] = CONFIG_FILE_VERSION
for model_key in self.config:
old_stanza = self.config[model_key]
# ignore old and ugly way of associating a legacy
# vae with a legacy checkpont model
if old_stanza.get("config") and '/VAE/' in old_stanza.get("config"):
continue
# bare keys are updated to be prefixed with 'diffusers/'
if '/' not in model_key:
new_key = f'diffusers/{model_key}'
else:
new_key = model_key
if old_stanza.get('format')=='diffusers':
model_format = 'folder'
elif old_stanza.get('weights') and Path(old_stanza.get('weights')).suffix == '.ckpt':
model_format = 'ckpt'
elif old_stanza.get('weights') and Path(old_stanza.get('weights')).suffix == '.safetensors':
model_format = 'safetensors'
# copy fields over manually rather than doing a copy() or deepcopy()
# in order to avoid bringing in unwanted fields.
new_config[new_key] = dict(
description = old_stanza.get('description'),
format = model_format,
)
for field in ["repo_id", "path", "weights", "config", "vae"]:
if field_value := old_stanza.get(field):
new_config[new_key].update({field: field_value})
self.config = new_config
if self.config_path:
self.commit()

View File

@ -345,13 +345,14 @@ class Completer(object):
partial = text partial = text
matches = list() matches = list()
for s in self.models: for s in self.models:
name = self.models[s]["model_name"]
format = self.models[s]["format"] format = self.models[s]["format"]
if format == "vae": if format == "vae":
continue continue
if ckpt_only and format != "ckpt": if ckpt_only and format != "ckpt":
continue continue
if s.startswith(partial): if name.startswith(partial):
matches.append(switch + s) matches.append(switch + name)
matches.sort() matches.sort()
return matches return matches