mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
convert add_model(), del_model(), list_models() etc to use bifurcated names
This commit is contained in:
parent
bc96727cbe
commit
72967bf118
@ -1,12 +1,9 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
|
||||||
|
|
||||||
import shutil
|
from typing import Annotated, Literal, Optional, Union
|
||||||
import asyncio
|
|
||||||
from typing import Annotated, Any, List, Literal, Optional, Union
|
|
||||||
|
|
||||||
from fastapi.routing import APIRouter, HTTPException
|
from fastapi.routing import APIRouter, HTTPException
|
||||||
from pydantic import BaseModel, Field, parse_obj_as
|
from pydantic import BaseModel, Field, parse_obj_as
|
||||||
from pathlib import Path
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||||
@ -19,6 +16,15 @@ class VaeRepo(BaseModel):
|
|||||||
|
|
||||||
class ModelInfo(BaseModel):
|
class ModelInfo(BaseModel):
|
||||||
description: Optional[str] = Field(description="A description of the model")
|
description: Optional[str] = Field(description="A description of the model")
|
||||||
|
model_name: str = Field(description="The name of the model")
|
||||||
|
model_type: str = Field(description="The type of the model")
|
||||||
|
|
||||||
|
class DiffusersModelInfo(ModelInfo):
|
||||||
|
format: Literal['folder'] = 'folder'
|
||||||
|
|
||||||
|
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
|
||||||
|
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
|
||||||
|
path: Optional[str] = Field(description="The path to the model")
|
||||||
|
|
||||||
class CkptModelInfo(ModelInfo):
|
class CkptModelInfo(ModelInfo):
|
||||||
format: Literal['ckpt'] = 'ckpt'
|
format: Literal['ckpt'] = 'ckpt'
|
||||||
@ -29,12 +35,8 @@ class CkptModelInfo(ModelInfo):
|
|||||||
width: Optional[int] = Field(description="The width of the model")
|
width: Optional[int] = Field(description="The width of the model")
|
||||||
height: Optional[int] = Field(description="The height of the model")
|
height: Optional[int] = Field(description="The height of the model")
|
||||||
|
|
||||||
class DiffusersModelInfo(ModelInfo):
|
class SafetensorsModelInfo(CkptModelInfo):
|
||||||
format: Literal['diffusers'] = 'diffusers'
|
format: Literal['safetensors'] = 'safetensors'
|
||||||
|
|
||||||
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
|
|
||||||
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
|
|
||||||
path: Optional[str] = Field(description="The path to the model")
|
|
||||||
|
|
||||||
class CreateModelRequest(BaseModel):
|
class CreateModelRequest(BaseModel):
|
||||||
name: str = Field(description="The name of the model")
|
name: str = Field(description="The name of the model")
|
||||||
@ -56,7 +58,7 @@ class ConvertedModelResponse(BaseModel):
|
|||||||
info: DiffusersModelInfo = Field(description="The converted model info")
|
info: DiffusersModelInfo = Field(description="The converted model info")
|
||||||
|
|
||||||
class ModelsList(BaseModel):
|
class ModelsList(BaseModel):
|
||||||
models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]]
|
models: dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]
|
||||||
|
|
||||||
|
|
||||||
@models_router.get(
|
@models_router.get(
|
||||||
@ -121,7 +123,7 @@ async def delete_model(model_name: str) -> None:
|
|||||||
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
|
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.error(f"Model not found")
|
logger.error("Model not found")
|
||||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||||
|
|
||||||
|
|
||||||
|
@ -60,14 +60,14 @@ class ModelLoaderInvocation(BaseInvocation):
|
|||||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||||
|
|
||||||
# TODO: not found exceptions
|
# TODO: not found exceptions
|
||||||
if not context.services.model_manager.valid_model(
|
if not context.services.model_manager.model_exists(
|
||||||
model_name=self.model_name,
|
model_name=self.model_name,
|
||||||
model_type=SDModelType.diffusers,
|
model_type=SDModelType.diffusers,
|
||||||
):
|
):
|
||||||
raise Exception(f"Unkown model name: {self.model_name}!")
|
raise Exception(f"Unkown model name: {self.model_name}!")
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not context.services.model_manager.valid_model(
|
if not context.services.model_manager.model_exists(
|
||||||
model_name=self.model_name,
|
model_name=self.model_name,
|
||||||
model_type=SDModelType.diffusers,
|
model_type=SDModelType.diffusers,
|
||||||
submodel=SDModelType.tokenizer,
|
submodel=SDModelType.tokenizer,
|
||||||
@ -76,7 +76,7 @@ class ModelLoaderInvocation(BaseInvocation):
|
|||||||
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
||||||
)
|
)
|
||||||
|
|
||||||
if not context.services.model_manager.valid_model(
|
if not context.services.model_manager.model_exists(
|
||||||
model_name=self.model_name,
|
model_name=self.model_name,
|
||||||
model_type=SDModelType.diffusers,
|
model_type=SDModelType.diffusers,
|
||||||
submodel=SDModelType.text_encoder,
|
submodel=SDModelType.text_encoder,
|
||||||
@ -85,7 +85,7 @@ class ModelLoaderInvocation(BaseInvocation):
|
|||||||
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
||||||
)
|
)
|
||||||
|
|
||||||
if not context.services.model_manager.valid_model(
|
if not context.services.model_manager.model_exists(
|
||||||
model_name=self.model_name,
|
model_name=self.model_name,
|
||||||
model_type=SDModelType.diffusers,
|
model_type=SDModelType.diffusers,
|
||||||
submodel=SDModelType.unet,
|
submodel=SDModelType.unet,
|
||||||
|
@ -59,35 +59,35 @@ class ModelManagerServiceBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def valid_model(self, model_name: str) -> bool:
|
def model_exists(
|
||||||
"""
|
self,
|
||||||
Given a model name, returns True if it is a valid
|
model_name: str,
|
||||||
identifier.
|
model_type: SDModelType
|
||||||
"""
|
) -> bool:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def default_model(self) -> Union[str,None]:
|
def default_model(self) -> Union[Tuple(str, SDModelType),None],
|
||||||
"""
|
"""
|
||||||
Returns the name of the default model, or None
|
Returns the name and typeof the default model, or None
|
||||||
if none is defined.
|
if none is defined.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def set_default_model(self, model_name:str):
|
def set_default_model(self, model_name: str, model_type: SDModelType):
|
||||||
"""Sets the default model to the indicated name."""
|
"""Sets the default model to the indicated name."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def model_info(self, model_name: str)->dict:
|
def model_info(self, model_name: str, model_type: SDModelType)->dict:
|
||||||
"""
|
"""
|
||||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def model_names(self)->list[str]:
|
def model_names(self)->List[Tuple(str, SDModelType)]:
|
||||||
"""
|
"""
|
||||||
Returns a list of all the model names known.
|
Returns a list of all the model names known.
|
||||||
"""
|
"""
|
||||||
@ -97,18 +97,25 @@ class ModelManagerServiceBase(ABC):
|
|||||||
def list_models(self)->dict:
|
def list_models(self)->dict:
|
||||||
"""
|
"""
|
||||||
Return a dict of models in the format:
|
Return a dict of models in the format:
|
||||||
{ model_name1: {'status': ('active'|'cached'|'not loaded'),
|
{ model_key1: {'status': 'active'|'cached'|'not loaded',
|
||||||
'description': description,
|
'model_name' : name,
|
||||||
'format': ('ckpt'|'diffusers'|'vae'|'text_encoder'|'tokenizer'|'lora'...),
|
'model_type' : SDModelType,
|
||||||
|
'description': description,
|
||||||
|
'format': 'folder'|'safetensors'|'ckpt'
|
||||||
},
|
},
|
||||||
model_name2: { etc }
|
model_key2: { etc }
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def add_model(
|
def add_model(
|
||||||
self, model_name: str, model_attributes: dict, clobber: bool = False)->None:
|
self,
|
||||||
|
model_name: str,
|
||||||
|
model_type: SDModelType,
|
||||||
|
model_attributes: dict,
|
||||||
|
clobber: bool = False
|
||||||
|
)->None:
|
||||||
"""
|
"""
|
||||||
Update the named model with a dictionary of attributes. Will fail with an
|
Update the named model with a dictionary of attributes. Will fail with an
|
||||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||||
@ -121,7 +128,7 @@ class ModelManagerServiceBase(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def del_model(self,
|
def del_model(self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_type: SDModelType=SDModelType.diffusers,
|
model_type: SDModelType,
|
||||||
delete_files: bool = False):
|
delete_files: bool = False):
|
||||||
"""
|
"""
|
||||||
Delete the named model from configuration. If delete_files is true,
|
Delete the named model from configuration. If delete_files is true,
|
||||||
@ -332,33 +339,37 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
|
|
||||||
return model_info
|
return model_info
|
||||||
|
|
||||||
def valid_model(self, model_name: str, model_type: SDModelType=SDModelType.diffusers) -> bool:
|
def model_exists(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
model_type: SDModelType
|
||||||
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Given a model name, returns True if it is a valid
|
Given a model name, returns True if it is a valid
|
||||||
identifier.
|
identifier.
|
||||||
"""
|
"""
|
||||||
return self.mgr.valid_model(
|
return self.mgr.model_exists(
|
||||||
model_name,
|
model_name,
|
||||||
model_type)
|
model_type)
|
||||||
|
|
||||||
def default_model(self) -> Union[str,None]:
|
def default_model(self) -> Union[Tuple(str, SDModelType),None],
|
||||||
"""
|
"""
|
||||||
Returns the name of the default model, or None
|
Returns the name of the default model, or None
|
||||||
if none is defined.
|
if none is defined.
|
||||||
"""
|
"""
|
||||||
return self.mgr.default_model()
|
return self.mgr.default_model()
|
||||||
|
|
||||||
def set_default_model(self, model_name:str):
|
def set_default_model(self, model_name:str, model_type: SDModelType):
|
||||||
"""Sets the default model to the indicated name."""
|
"""Sets the default model to the indicated name."""
|
||||||
self.mgr.set_default_model(model_name)
|
self.mgr.set_default_model(model_name)
|
||||||
|
|
||||||
def model_info(self, model_name: str)->dict:
|
def model_info(self, model_name: str, model_type: SDModelType)->dict:
|
||||||
"""
|
"""
|
||||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||||
"""
|
"""
|
||||||
return self.mgr.model_info(model_name)
|
return self.mgr.model_info(model_name)
|
||||||
|
|
||||||
def model_names(self)->list[str]:
|
def model_names(self)->List[Tuple(str, SDModelType)]:
|
||||||
"""
|
"""
|
||||||
Returns a list of all the model names known.
|
Returns a list of all the model names known.
|
||||||
"""
|
"""
|
||||||
@ -367,16 +378,21 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
def list_models(self)->dict:
|
def list_models(self)->dict:
|
||||||
"""
|
"""
|
||||||
Return a dict of models in the format:
|
Return a dict of models in the format:
|
||||||
{ model_name1: {'status': ('active'|'cached'|'not loaded'),
|
{ model_key: {'status': 'active'|'cached'|'not loaded',
|
||||||
|
'model_name' : name,
|
||||||
|
'model_type' : SDModelType,
|
||||||
'description': description,
|
'description': description,
|
||||||
'format': ('ckpt'|'diffusers'|'vae'|'text_encoder'|'tokenizer'|'lora'...),
|
'format': 'folder'|'safetensors'|'ckpt'
|
||||||
},
|
},
|
||||||
model_name2: { etc }
|
|
||||||
"""
|
"""
|
||||||
return self.mgr.list_models()
|
return self.mgr.list_models()
|
||||||
|
|
||||||
def add_model(
|
def add_model(
|
||||||
self, model_name: str, model_attributes: dict, clobber: bool = False)->None:
|
self,
|
||||||
|
model_name: str,
|
||||||
|
model_type: SDModelType,
|
||||||
|
model_attributes: dict,
|
||||||
|
clobber: bool = False)->None:
|
||||||
"""
|
"""
|
||||||
Update the named model with a dictionary of attributes. Will fail with an
|
Update the named model with a dictionary of attributes. Will fail with an
|
||||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||||
@ -384,7 +400,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
with an assertion error if provided attributes are incorrect or
|
with an assertion error if provided attributes are incorrect or
|
||||||
the model name is missing. Call commit() to write changes to disk.
|
the model name is missing. Call commit() to write changes to disk.
|
||||||
"""
|
"""
|
||||||
return self.mgr.add_model(model_name, model_attributes, dict, clobber)
|
return self.mgr.add_model(model_name, model_type, model_attributes, dict, clobber)
|
||||||
|
|
||||||
|
|
||||||
def del_model(self,
|
def del_model(self,
|
||||||
@ -439,7 +455,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
Creates an entry for the indicated textual inversion embedding file.
|
Creates an entry for the indicated textual inversion embedding file.
|
||||||
Call commit() to write out the configuration to models.yaml
|
Call commit() to write out the configuration to models.yaml
|
||||||
"""
|
"""
|
||||||
self.mgr(path, model_name, description)
|
self.mgr.import_embedding(path, model_name, description)
|
||||||
|
|
||||||
def heuristic_import(
|
def heuristic_import(
|
||||||
self,
|
self,
|
||||||
|
@ -227,7 +227,7 @@ class Generate:
|
|||||||
# don't accept invalid models
|
# don't accept invalid models
|
||||||
fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
|
fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
|
||||||
model = model or fallback
|
model = model or fallback
|
||||||
if not self.model_manager.valid_model(model):
|
if not self.model_manager.model_exists(model):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f'"{model}" is not a known model name; falling back to {fallback}.'
|
f'"{model}" is not a known model name; falling back to {fallback}.'
|
||||||
)
|
)
|
||||||
@ -877,7 +877,7 @@ class Generate:
|
|||||||
|
|
||||||
# the model cache does the loading and offloading
|
# the model cache does the loading and offloading
|
||||||
cache = self.model_manager
|
cache = self.model_manager
|
||||||
if not cache.valid_model(model_name):
|
if not cache.model_exists(model_name):
|
||||||
raise KeyError(
|
raise KeyError(
|
||||||
f'** "{model_name}" is not a known model name. Cannot change.'
|
f'** "{model_name}" is not a known model name. Cannot change.'
|
||||||
)
|
)
|
||||||
|
@ -9,6 +9,7 @@ return a SDModelInfo object that contains the following attributes:
|
|||||||
model into GPU VRAM and returns the model for use.
|
model into GPU VRAM and returns the model for use.
|
||||||
See below for usage.
|
See below for usage.
|
||||||
* name -- symbolic name of the model
|
* name -- symbolic name of the model
|
||||||
|
* type -- SDModelType of the model
|
||||||
* hash -- unique hash for the model
|
* hash -- unique hash for the model
|
||||||
* location -- path or repo_id of the model
|
* location -- path or repo_id of the model
|
||||||
* revision -- revision of the model if coming from a repo id,
|
* revision -- revision of the model if coming from a repo id,
|
||||||
@ -26,7 +27,7 @@ Typical usage:
|
|||||||
max_cache_size=8
|
max_cache_size=8
|
||||||
) # gigabytes
|
) # gigabytes
|
||||||
|
|
||||||
model_info = manager.get_model('stable-diffusion-1.5')
|
model_info = manager.get_model('stable-diffusion-1.5', SDModelType.diffusers)
|
||||||
with model_info.context as my_model:
|
with model_info.context as my_model:
|
||||||
my_model.latents_from_embeddings(...)
|
my_model.latents_from_embeddings(...)
|
||||||
|
|
||||||
@ -54,15 +55,20 @@ MODELS.YAML
|
|||||||
|
|
||||||
The general format of a models.yaml section is:
|
The general format of a models.yaml section is:
|
||||||
|
|
||||||
name-of-model:
|
type-of-model/name-of-model:
|
||||||
format: diffusers|ckpt|vae|text_encoder|tokenizer...
|
format: folder|ckpt|safetensors
|
||||||
repo_id: owner/repo
|
repo_id: owner/repo
|
||||||
path: /path/to/local/file/or/directory
|
path: /path/to/local/file/or/directory
|
||||||
subfolder: subfolder-name
|
subfolder: subfolder-name
|
||||||
|
|
||||||
The format is one of {diffusers, ckpt, vae, text_encoder, tokenizer,
|
The type of model is given in the stanza key, and is one of
|
||||||
unet, scheduler, safety_checker, feature_extractor}, and correspond to
|
{diffusers, ckpt, vae, text_encoder, tokenizer, unet, scheduler,
|
||||||
items in the SDModelType enum defined in model_cache.py
|
safety_checker, feature_extractor, lora, textual_inversion}, and
|
||||||
|
correspond to items in the SDModelType enum defined in model_cache.py
|
||||||
|
|
||||||
|
The format indicates whether the model is organized as a folder with
|
||||||
|
model subdirectories, or is contained in a single checkpoint or
|
||||||
|
safetensors file.
|
||||||
|
|
||||||
One, but not both, of repo_id and path are provided. repo_id is the
|
One, but not both, of repo_id and path are provided. repo_id is the
|
||||||
HuggingFace repository ID of the model, and path points to the file or
|
HuggingFace repository ID of the model, and path points to the file or
|
||||||
@ -74,13 +80,13 @@ the main model. These are usually named after the model type, such as
|
|||||||
|
|
||||||
This example summarizes the two ways of getting a non-diffuser model:
|
This example summarizes the two ways of getting a non-diffuser model:
|
||||||
|
|
||||||
clip-test-1:
|
text_encoder/clip-test-1:
|
||||||
format: text_encoder
|
format: folder
|
||||||
repo_id: openai/clip-vit-large-patch14
|
repo_id: openai/clip-vit-large-patch14
|
||||||
description: Returns standalone CLIPTextModel
|
description: Returns standalone CLIPTextModel
|
||||||
|
|
||||||
clip-test-2:
|
text_encoder/clip-test-2:
|
||||||
format: text_encoder
|
format: folder
|
||||||
repo_id: stabilityai/stable-diffusion-2
|
repo_id: stabilityai/stable-diffusion-2
|
||||||
subfolder: text_encoder
|
subfolder: text_encoder
|
||||||
description: Returns the text_encoder in the subfolder of the diffusers model (just the encoder in RAM)
|
description: Returns the text_encoder in the subfolder of the diffusers model (just the encoder in RAM)
|
||||||
@ -101,12 +107,12 @@ You may wish to use the same name for a related family of models. To
|
|||||||
do this, disambiguate the stanza key with the model and and format
|
do this, disambiguate the stanza key with the model and and format
|
||||||
separated by "/". Example:
|
separated by "/". Example:
|
||||||
|
|
||||||
clip-large/tokenizer:
|
tokenizer/clip-large:
|
||||||
format: tokenizer
|
format: tokenizer
|
||||||
repo_id: openai/clip-vit-large-patch14
|
repo_id: openai/clip-vit-large-patch14
|
||||||
description: Returns standalone tokenizer
|
description: Returns standalone tokenizer
|
||||||
|
|
||||||
clip-large/text_encoder:
|
text_encoder/clip-large:
|
||||||
format: text_encoder
|
format: text_encoder
|
||||||
repo_id: openai/clip-vit-large-patch14
|
repo_id: openai/clip-vit-large-patch14
|
||||||
description: Returns standalone text encoder
|
description: Returns standalone text encoder
|
||||||
@ -128,31 +134,41 @@ from __future__ import annotations
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import textwrap
|
import textwrap
|
||||||
|
from contextlib import suppress
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
|
from packaging import version
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import rmtree
|
from shutil import rmtree
|
||||||
from typing import Union, Callable, types
|
from typing import Callable, Optional, List, Tuple, Union, types
|
||||||
from contextlib import suppress
|
|
||||||
|
|
||||||
import safetensors
|
import safetensors
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
import torch
|
import torch
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from huggingface_hub import scan_cache_dir
|
from huggingface_hub import scan_cache_dir
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from omegaconf.dictconfig import DictConfig
|
from omegaconf.dictconfig import DictConfig
|
||||||
|
|
||||||
from invokeai.backend.globals import Globals, global_cache_dir, global_resolve_path
|
import invokeai.backend.util.logging as logger
|
||||||
from .model_cache import ModelCache, ModelLocker, SDModelType, ModelStatus, SilenceWarnings
|
from invokeai.backend.globals import (Globals, global_cache_dir,
|
||||||
|
global_resolve_path)
|
||||||
|
from invokeai.backend.util import download_with_resume
|
||||||
|
|
||||||
from ..util import CUDA_DEVICE
|
from ..util import CUDA_DEVICE
|
||||||
|
from .model_cache import (ModelCache, ModelLocker, ModelStatus, SDModelType,
|
||||||
|
SilenceWarnings)
|
||||||
|
|
||||||
|
# We are only starting to number the config file with release 3.
|
||||||
|
# The config file version doesn't have to start at release version, but it will help
|
||||||
|
# reduce confusion.
|
||||||
|
CONFIG_FILE_VERSION='3.0.0'
|
||||||
|
|
||||||
# wanted to use pydantic here, but Generator objects not supported
|
# wanted to use pydantic here, but Generator objects not supported
|
||||||
@dataclass
|
@dataclass
|
||||||
class SDModelInfo():
|
class SDModelInfo():
|
||||||
context: ModelLocker
|
context: ModelLocker
|
||||||
name: str
|
name: str
|
||||||
|
type: SDModelType
|
||||||
hash: str
|
hash: str
|
||||||
location: Union[Path,str]
|
location: Union[Path,str]
|
||||||
precision: torch.dtype
|
precision: torch.dtype
|
||||||
@ -208,14 +224,17 @@ class ModelManager(object):
|
|||||||
type and precision are set up for a CUDA system running at half precision.
|
type and precision are set up for a CUDA system running at half precision.
|
||||||
"""
|
"""
|
||||||
if isinstance(config, DictConfig):
|
if isinstance(config, DictConfig):
|
||||||
self.config = config
|
|
||||||
self.config_path = None
|
self.config_path = None
|
||||||
|
self.config = config
|
||||||
elif isinstance(config,(str,Path)):
|
elif isinstance(config,(str,Path)):
|
||||||
self.config_path = config
|
self.config_path = config
|
||||||
self.config = OmegaConf.load(self.config_path)
|
self.config = OmegaConf.load(self.config_path)
|
||||||
else:
|
else:
|
||||||
raise ValueError('config argument must be an OmegaConf object, a Path or a string')
|
raise ValueError('config argument must be an OmegaConf object, a Path or a string')
|
||||||
|
|
||||||
|
# check config version number and update on disk/RAM if necessary
|
||||||
|
self._update_config_file_version()
|
||||||
|
|
||||||
self.cache = ModelCache(
|
self.cache = ModelCache(
|
||||||
max_cache_size=max_cache_size,
|
max_cache_size=max_cache_size,
|
||||||
execution_device = device_type,
|
execution_device = device_type,
|
||||||
@ -226,8 +245,7 @@ class ModelManager(object):
|
|||||||
self.cache_keys = dict()
|
self.cache_keys = dict()
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
|
||||||
# TODO: rename to smth like - is_model_exists
|
def model_exists(
|
||||||
def valid_model(
|
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_type: SDModelType = SDModelType.diffusers,
|
model_type: SDModelType = SDModelType.diffusers,
|
||||||
@ -246,14 +264,14 @@ class ModelManager(object):
|
|||||||
model_type_str, model_name = model_key.split('/', 1)
|
model_type_str, model_name = model_key.split('/', 1)
|
||||||
if model_type_str not in SDModelType.__members__:
|
if model_type_str not in SDModelType.__members__:
|
||||||
# TODO:
|
# TODO:
|
||||||
raise Exception(f"Unkown model type: {model_type_str}")
|
raise Exception(f"Unknown model type: {model_type_str}")
|
||||||
|
|
||||||
return (model_name, SDModelType[model_type_str])
|
return (model_name, SDModelType[model_type_str])
|
||||||
|
|
||||||
def get_model(
|
def get_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_type: SDModelType=None,
|
model_type: SDModelType=SDModelType.diffusers,
|
||||||
submodel: SDModelType=None,
|
submodel: SDModelType=None,
|
||||||
) -> SDModelInfo:
|
) -> SDModelInfo:
|
||||||
"""Given a model named identified in models.yaml, return
|
"""Given a model named identified in models.yaml, return
|
||||||
@ -311,7 +329,7 @@ class ModelManager(object):
|
|||||||
if model_type == SDModelType.diffusers:
|
if model_type == SDModelType.diffusers:
|
||||||
# intercept stanzas that point to checkpoint weights and replace them
|
# intercept stanzas that point to checkpoint weights and replace them
|
||||||
# with the equivalent diffusers model
|
# with the equivalent diffusers model
|
||||||
if 'weights' in mconfig:
|
if mconfig.format in ["ckpt", "diffusers"]:
|
||||||
location = self.convert_ckpt_and_cache(mconfig)
|
location = self.convert_ckpt_and_cache(mconfig)
|
||||||
else:
|
else:
|
||||||
location = global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id')
|
location = global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id')
|
||||||
@ -350,6 +368,7 @@ class ModelManager(object):
|
|||||||
return SDModelInfo(
|
return SDModelInfo(
|
||||||
context = model_context,
|
context = model_context,
|
||||||
name = model_name,
|
name = model_name,
|
||||||
|
type = submodel or model_type,
|
||||||
hash = hash,
|
hash = hash,
|
||||||
location = location,
|
location = location,
|
||||||
revision = revision,
|
revision = revision,
|
||||||
@ -358,43 +377,50 @@ class ModelManager(object):
|
|||||||
_cache = self.cache
|
_cache = self.cache
|
||||||
)
|
)
|
||||||
|
|
||||||
def default_model(self) -> Union[str,None]:
|
def default_model(self) -> Union[Tuple(str, SDModelType),None]:
|
||||||
"""
|
"""
|
||||||
Returns the name of the default model, or None
|
Returns the name of the default model, or None
|
||||||
if none is defined.
|
if none is defined.
|
||||||
"""
|
"""
|
||||||
for model_name in self.config:
|
for model_name, model_type in self.model_names():
|
||||||
if self.config[model_name].get("default"):
|
model_key = self.create_key(model_name, model_type)
|
||||||
return model_name
|
if self.config[model_key].get("default"):
|
||||||
return list(self.config.keys())[0] # first one
|
return (model_name, model_type)
|
||||||
|
return self.model_names()[0][0]
|
||||||
|
|
||||||
def set_default_model(self, model_name: str) -> None:
|
def set_default_model(self, model_name: str, model_type: SDModelType=SDModelType.diffusers) -> None:
|
||||||
"""
|
"""
|
||||||
Set the default model. The change will not take
|
Set the default model. The change will not take
|
||||||
effect until you call model_manager.commit()
|
effect until you call model_manager.commit()
|
||||||
"""
|
"""
|
||||||
assert model_name in self.model_names(), f"unknown model '{model_name}'"
|
assert self.model_exists(model_name, model_type), f"unknown model '{model_name}'"
|
||||||
|
|
||||||
config = self.config
|
config = self.config
|
||||||
for model in config:
|
for model_name, model_type in self.model_names():
|
||||||
config[model].pop("default", None)
|
key = self.create_key(model_name, model_type)
|
||||||
config[model_name]["default"] = True
|
config[key].pop("default", None)
|
||||||
|
config[self.create_key(model_name, model_type)]["default"] = True
|
||||||
|
|
||||||
def model_info(self, model_name: str) -> dict:
|
def model_info(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
model_type: SDModelType=SDModelType.diffusers
|
||||||
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Given a model name returns the OmegaConf (dict-like) object describing it.
|
Given a model name returns the OmegaConf (dict-like) object describing it.
|
||||||
"""
|
"""
|
||||||
if model_name not in self.config:
|
if not self.exists(model_name, model_type):
|
||||||
return None
|
return None
|
||||||
return self.config[model_name]
|
return self.config[self.create_key(model_name,model_type)]
|
||||||
|
|
||||||
def model_names(self) -> list[str]:
|
def model_names(self) -> List[Tuple(str, SDModelType)]:
|
||||||
"""
|
"""
|
||||||
Return a list consisting of all the names of models defined in models.yaml
|
Return a list of (str, SDModelType) corresponding to all models
|
||||||
|
known to the configuration.
|
||||||
"""
|
"""
|
||||||
return list(self.config.keys())
|
return [(self.parse_key(x)) for x in self.config.keys() if isinstance(self.config[x],DictConfig)]
|
||||||
|
|
||||||
def is_legacy(self, model_name: str) -> bool:
|
def is_legacy(self, model_name: str, model_type: SDModelType.diffusers) -> bool:
|
||||||
"""
|
"""
|
||||||
Return true if this is a legacy (.ckpt) model
|
Return true if this is a legacy (.ckpt) model
|
||||||
"""
|
"""
|
||||||
@ -402,7 +428,7 @@ class ModelManager(object):
|
|||||||
# there are no legacy ckpts!
|
# there are no legacy ckpts!
|
||||||
if Globals.ckpt_convert:
|
if Globals.ckpt_convert:
|
||||||
return False
|
return False
|
||||||
info = self.model_info(model_name)
|
info = self.model_info(model_name, model_type)
|
||||||
if "weights" in info and info["weights"].endswith((".ckpt", ".safetensors")):
|
if "weights" in info and info["weights"].endswith((".ckpt", ".safetensors")):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
@ -422,15 +448,17 @@ class ModelManager(object):
|
|||||||
# don't include VAEs in listing (legacy style)
|
# don't include VAEs in listing (legacy style)
|
||||||
if "config" in stanza and "/VAE/" in stanza["config"]:
|
if "config" in stanza and "/VAE/" in stanza["config"]:
|
||||||
continue
|
continue
|
||||||
|
if model_key=='config_file_version':
|
||||||
|
continue
|
||||||
|
|
||||||
model_name, model_type = self.parse_key(model_key)
|
model_name, model_type = self.parse_key(model_key)
|
||||||
models[model_name] = dict()
|
models[model_key] = dict()
|
||||||
|
|
||||||
# TODO: return all models in future
|
# TODO: return all models in future
|
||||||
if model_type != SDModelType.diffusers:
|
if model_type != SDModelType.diffusers:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
model_format = "ckpt" if "weights" in stanza else "diffusers"
|
model_format = stanza.get('format')
|
||||||
|
|
||||||
# Common Attribs
|
# Common Attribs
|
||||||
status = self.cache.status(
|
status = self.cache.status(
|
||||||
@ -439,26 +467,25 @@ class ModelManager(object):
|
|||||||
subfolder=stanza.get('subfolder')
|
subfolder=stanza.get('subfolder')
|
||||||
)
|
)
|
||||||
description = stanza.get("description", None)
|
description = stanza.get("description", None)
|
||||||
models[model_name].update(
|
models[model_key].update(
|
||||||
description=description,
|
model_name=model_name,
|
||||||
type=model_type,
|
model_type=model_type.name,
|
||||||
format=model_format,
|
format=model_format,
|
||||||
|
description=description,
|
||||||
status=status.value
|
status=status.value
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Checkpoint Config Parse
|
# Checkpoint Config Parse
|
||||||
if model_format == "ckpt":
|
if model_format in ["ckpt","safetensors"]:
|
||||||
models[model_name].update(
|
models[model_key].update(
|
||||||
config = str(stanza.get("config", None)),
|
config = str(stanza.get("config", None)),
|
||||||
weights = str(stanza.get("weights", None)),
|
weights = str(stanza.get("weights", None)),
|
||||||
vae = str(stanza.get("vae", None)),
|
vae = str(stanza.get("vae", None)),
|
||||||
width = str(stanza.get("width", 512)),
|
|
||||||
height = str(stanza.get("height", 512)),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Diffusers Config Parse
|
# Diffusers Config Parse
|
||||||
elif model_format == "diffusers":
|
elif model_format == "folder":
|
||||||
if vae := stanza.get("vae", None):
|
if vae := stanza.get("vae", None):
|
||||||
if isinstance(vae, DictConfig):
|
if isinstance(vae, DictConfig):
|
||||||
vae = dict(
|
vae = dict(
|
||||||
@ -467,7 +494,7 @@ class ModelManager(object):
|
|||||||
subfolder = str(vae.get("subfolder", None)),
|
subfolder = str(vae.get("subfolder", None)),
|
||||||
)
|
)
|
||||||
|
|
||||||
models[model_name].update(
|
models[model_key].update(
|
||||||
vae = vae,
|
vae = vae,
|
||||||
repo_id = str(stanza.get("repo_id", None)),
|
repo_id = str(stanza.get("repo_id", None)),
|
||||||
path = str(stanza.get("path", None)),
|
path = str(stanza.get("path", None)),
|
||||||
@ -479,12 +506,9 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
Print a table of models, their descriptions, and load status
|
Print a table of models, their descriptions, and load status
|
||||||
"""
|
"""
|
||||||
models = self.list_models()
|
for model_key, model_info in self.list_models().items():
|
||||||
for name in models:
|
line = f'{model_info["model_name"]:25s} {model_info["status"]:>15s} {model_info["model_type"]:10s} {model_info["description"]}'
|
||||||
if models[name]["type"] == "vae":
|
if model_info["status"] in ["in gpu","locked in gpu"]:
|
||||||
continue
|
|
||||||
line = f'{name:25s} {models[name]["status"]:>15s} {models[name]["type"]:10s} {models[name]["description"]}'
|
|
||||||
if models[name]["status"] == "active":
|
|
||||||
line = f"\033[1m{line}\033[0m"
|
line = f"\033[1m{line}\033[0m"
|
||||||
print(line)
|
print(line)
|
||||||
|
|
||||||
@ -511,9 +535,9 @@ class ModelManager(object):
|
|||||||
# self.stack.remove(model_name)
|
# self.stack.remove(model_name)
|
||||||
|
|
||||||
if delete_files:
|
if delete_files:
|
||||||
repo_id = conf.get("repo_id", None)
|
repo_id = model_cfg.get("repo_id", None)
|
||||||
path = self._abs_path(conf.get("path", None))
|
path = self._abs_path(model_cfg.get("path", None))
|
||||||
weights = self._abs_path(conf.get("weights", None))
|
weights = self._abs_path(model_cfg.get("weights", None))
|
||||||
if "weights" in model_cfg:
|
if "weights" in model_cfg:
|
||||||
weights = self._abs_path(model_cfg["weights"])
|
weights = self._abs_path(model_cfg["weights"])
|
||||||
self.logger.info(f"Deleting file {weights}")
|
self.logger.info(f"Deleting file {weights}")
|
||||||
@ -558,7 +582,7 @@ class ModelManager(object):
|
|||||||
), 'model must have either the "path" or "repo_id" fields defined'
|
), 'model must have either the "path" or "repo_id" fields defined'
|
||||||
|
|
||||||
elif model_format == "ckpt":
|
elif model_format == "ckpt":
|
||||||
for field in ("description", "weights", "height", "width", "config"):
|
for field in ("description", "weights", "config"):
|
||||||
assert field in model_attributes, f"required field {field} is missing"
|
assert field in model_attributes, f"required field {field} is missing"
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -579,11 +603,6 @@ class ModelManager(object):
|
|||||||
self.cache.uncache_model(self.cache_keys[model_key])
|
self.cache.uncache_model(self.cache_keys[model_key])
|
||||||
del self.cache_keys[model_key]
|
del self.cache_keys[model_key]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def import_diffuser_model(
|
def import_diffuser_model(
|
||||||
self,
|
self,
|
||||||
repo_or_path: Union[str, Path],
|
repo_or_path: Union[str, Path],
|
||||||
@ -604,7 +623,6 @@ class ModelManager(object):
|
|||||||
models.yaml file.
|
models.yaml file.
|
||||||
"""
|
"""
|
||||||
model_name = model_name or Path(repo_or_path).stem
|
model_name = model_name or Path(repo_or_path).stem
|
||||||
model_key = f'{model_name}/diffusers'
|
|
||||||
model_description = description or f"Imported diffusers model {model_name}"
|
model_description = description or f"Imported diffusers model {model_name}"
|
||||||
new_config = dict(
|
new_config = dict(
|
||||||
description=model_description,
|
description=model_description,
|
||||||
@ -616,10 +634,10 @@ class ModelManager(object):
|
|||||||
else:
|
else:
|
||||||
new_config.update(repo_id=repo_or_path)
|
new_config.update(repo_id=repo_or_path)
|
||||||
|
|
||||||
self.add_model(model_key, new_config, True)
|
self.add_model(model_name, SDModelType.diffusers, new_config, True)
|
||||||
if commit_to_conf:
|
if commit_to_conf:
|
||||||
self.commit(commit_to_conf)
|
self.commit(commit_to_conf)
|
||||||
return model_key
|
return self.create_key(model_name, SDModelType.diffusers)
|
||||||
|
|
||||||
def import_lora(
|
def import_lora(
|
||||||
self,
|
self,
|
||||||
@ -635,7 +653,8 @@ class ModelManager(object):
|
|||||||
model_name = model_name or path.stem
|
model_name = model_name or path.stem
|
||||||
model_description = description or f"LoRA model {model_name}"
|
model_description = description or f"LoRA model {model_name}"
|
||||||
self.add_model(
|
self.add_model(
|
||||||
f'{model_name}/{SDModelType.lora.name}',
|
model_name,
|
||||||
|
SDModelType.lora,
|
||||||
dict(
|
dict(
|
||||||
format="lora",
|
format="lora",
|
||||||
weights=str(path),
|
weights=str(path),
|
||||||
@ -663,7 +682,8 @@ class ModelManager(object):
|
|||||||
model_name = model_name or path.stem
|
model_name = model_name or path.stem
|
||||||
model_description = description or f"Textual embedding model {model_name}"
|
model_description = description or f"Textual embedding model {model_name}"
|
||||||
self.add_model(
|
self.add_model(
|
||||||
f'{model_name}/{SDModelType.textual_inversion.name}',
|
model_name,
|
||||||
|
SDModelType.textual_inversion,
|
||||||
dict(
|
dict(
|
||||||
format="textual_inversion",
|
format="textual_inversion",
|
||||||
weights=str(weights),
|
weights=str(weights),
|
||||||
@ -1025,9 +1045,14 @@ class ModelManager(object):
|
|||||||
description=model_description,
|
description=model_description,
|
||||||
format="diffusers",
|
format="diffusers",
|
||||||
)
|
)
|
||||||
if model_name in self.config:
|
if self.model_exists(model_name, SDModelType.diffusers):
|
||||||
self.del_model(model_name)
|
self.del_model(model_name, SDModelType.diffusers)
|
||||||
self.add_model(model_name, new_config, True)
|
self.add_model(
|
||||||
|
model_name,
|
||||||
|
SDModelType.diffusers,
|
||||||
|
new_config,
|
||||||
|
True
|
||||||
|
)
|
||||||
if commit_to_conf:
|
if commit_to_conf:
|
||||||
self.commit(commit_to_conf)
|
self.commit(commit_to_conf)
|
||||||
self.logger.debug("Conversion succeeded")
|
self.logger.debug("Conversion succeeded")
|
||||||
@ -1081,11 +1106,6 @@ class ModelManager(object):
|
|||||||
"""\
|
"""\
|
||||||
# This file describes the alternative machine learning models
|
# This file describes the alternative machine learning models
|
||||||
# available to InvokeAI script.
|
# available to InvokeAI script.
|
||||||
#
|
|
||||||
# To add a new model, follow the examples below. Each
|
|
||||||
# model requires a model config file, a weights file,
|
|
||||||
# and the width and height of the images it
|
|
||||||
# was trained on.
|
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1130,3 +1150,53 @@ class ModelManager(object):
|
|||||||
source = os.path.join(Globals.root, source)
|
source = os.path.join(Globals.root, source)
|
||||||
resolved_path = Path(source)
|
resolved_path = Path(source)
|
||||||
return resolved_path
|
return resolved_path
|
||||||
|
|
||||||
|
def _update_config_file_version(self):
|
||||||
|
"""
|
||||||
|
This gets called at object init time and will update
|
||||||
|
from older versions of the config file to new ones
|
||||||
|
as necessary.
|
||||||
|
"""
|
||||||
|
current_version = self.config.get("config_file_version","1.0.0")
|
||||||
|
if version.parse(current_version) < version.parse(CONFIG_FILE_VERSION):
|
||||||
|
self.logger.info(f'models.yaml version {current_version} detected. Updating to {CONFIG_FILE_VERSION}')
|
||||||
|
|
||||||
|
new_config = OmegaConf.create()
|
||||||
|
new_config["config_file_version"] = CONFIG_FILE_VERSION
|
||||||
|
|
||||||
|
for model_key in self.config:
|
||||||
|
|
||||||
|
old_stanza = self.config[model_key]
|
||||||
|
|
||||||
|
# ignore old and ugly way of associating a legacy
|
||||||
|
# vae with a legacy checkpont model
|
||||||
|
if old_stanza.get("config") and '/VAE/' in old_stanza.get("config"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# bare keys are updated to be prefixed with 'diffusers/'
|
||||||
|
if '/' not in model_key:
|
||||||
|
new_key = f'diffusers/{model_key}'
|
||||||
|
else:
|
||||||
|
new_key = model_key
|
||||||
|
|
||||||
|
if old_stanza.get('format')=='diffusers':
|
||||||
|
model_format = 'folder'
|
||||||
|
elif old_stanza.get('weights') and Path(old_stanza.get('weights')).suffix == '.ckpt':
|
||||||
|
model_format = 'ckpt'
|
||||||
|
elif old_stanza.get('weights') and Path(old_stanza.get('weights')).suffix == '.safetensors':
|
||||||
|
model_format = 'safetensors'
|
||||||
|
|
||||||
|
# copy fields over manually rather than doing a copy() or deepcopy()
|
||||||
|
# in order to avoid bringing in unwanted fields.
|
||||||
|
new_config[new_key] = dict(
|
||||||
|
description = old_stanza.get('description'),
|
||||||
|
format = model_format,
|
||||||
|
)
|
||||||
|
for field in ["repo_id", "path", "weights", "config", "vae"]:
|
||||||
|
if field_value := old_stanza.get(field):
|
||||||
|
new_config[new_key].update({field: field_value})
|
||||||
|
|
||||||
|
self.config = new_config
|
||||||
|
if self.config_path:
|
||||||
|
self.commit()
|
||||||
|
|
||||||
|
@ -345,13 +345,14 @@ class Completer(object):
|
|||||||
partial = text
|
partial = text
|
||||||
matches = list()
|
matches = list()
|
||||||
for s in self.models:
|
for s in self.models:
|
||||||
|
name = self.models[s]["model_name"]
|
||||||
format = self.models[s]["format"]
|
format = self.models[s]["format"]
|
||||||
if format == "vae":
|
if format == "vae":
|
||||||
continue
|
continue
|
||||||
if ckpt_only and format != "ckpt":
|
if ckpt_only and format != "ckpt":
|
||||||
continue
|
continue
|
||||||
if s.startswith(partial):
|
if name.startswith(partial):
|
||||||
matches.append(switch + s)
|
matches.append(switch + name)
|
||||||
matches.sort()
|
matches.sort()
|
||||||
return matches
|
return matches
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user