Model manager draft

This commit is contained in:
Sergey Borisov 2023-05-18 03:56:52 +03:00
parent e971a7f35c
commit fd82763412
5 changed files with 561 additions and 520 deletions

View File

@ -10,7 +10,7 @@ from .generator import (
Img2Img, Img2Img,
Inpaint Inpaint
) )
from .model_management import ModelManager, ModelCache, ModelStatus, SDModelType, SDModelInfo from .model_management import ModelManager, ModelCache, SDModelType, SDModelInfo
from .safety_checker import SafetyChecker from .safety_checker import SafetyChecker
from .args import Args from .args import Args
from .globals import Globals from .globals import Globals

View File

@ -2,4 +2,4 @@
Initialization file for invokeai.backend.model_management Initialization file for invokeai.backend.model_management
""" """
from .model_manager import ModelManager, SDModelInfo from .model_manager import ModelManager, SDModelInfo
from .model_cache import ModelCache, ModelStatus, SDModelType from .model_cache import ModelCache, SDModelType

File diff suppressed because it is too large Load Diff

View File

@ -15,8 +15,6 @@ return a SDModelInfo object that contains the following attributes:
* revision -- revision of the model if coming from a repo id, * revision -- revision of the model if coming from a repo id,
e.g. 'fp16' e.g. 'fp16'
* precision -- torch precision of the model * precision -- torch precision of the model
* status -- a ModelStatus enum corresponding to one of
'not_loaded', 'in_ram', 'in_vram' or 'active'
Typical usage: Typical usage:
@ -157,8 +155,8 @@ from invokeai.backend.globals import (Globals, global_cache_dir,
from invokeai.backend.util import download_with_resume from invokeai.backend.util import download_with_resume
from ..util import CUDA_DEVICE from ..util import CUDA_DEVICE
from .model_cache import (ModelCache, ModelLocker, ModelStatus, SDModelType, from .model_cache import (ModelCache, ModelLocker, SDModelType,
SilenceWarnings, DIFFUSERS_PARTS) SilenceWarnings)
# We are only starting to number the config file with release 3. # We are only starting to number the config file with release 3.
# The config file version doesn't have to start at release version, but it will help # The config file version doesn't have to start at release version, but it will help
@ -174,7 +172,6 @@ class SDModelInfo():
hash: str hash: str
location: Union[Path,str] location: Union[Path,str]
precision: torch.dtype precision: torch.dtype
subfolder: Path = None
revision: str = None revision: str = None
_cache: ModelCache = None _cache: ModelCache = None
@ -183,17 +180,6 @@ class SDModelInfo():
def __exit__(self,*args, **kwargs): def __exit__(self,*args, **kwargs):
self.context.__exit__(*args, **kwargs) self.context.__exit__(*args, **kwargs)
@property
def status(self)->ModelStatus:
'''Return load status of this model as a model_cache.ModelStatus enum'''
if not self._cache:
return ModelStatus.unknown
return self._cache.status(
self.location,
revision = self.revision,
subfolder = self.subfolder
)
class InvalidModelError(Exception): class InvalidModelError(Exception):
"Raised when an invalid model is requested" "Raised when an invalid model is requested"
@ -355,7 +341,6 @@ class ModelManager(object):
or global_resolve_path(mconfig.get('weights') or global_resolve_path(mconfig.get('weights')
) )
subfolder = mconfig.get('subfolder')
revision = mconfig.get('revision') revision = mconfig.get('revision')
hash = self.cache.model_hash(location, revision) hash = self.cache.model_hash(location, revision)
@ -367,31 +352,18 @@ class ModelManager(object):
model_type = submodel model_type = submodel
submodel = None submodel = None
# We don't need to load whole model if the user is asking for just a piece of it
elif model_type == SDModelType.Diffusers and submodel and not subfolder:
model_type = submodel
subfolder = submodel.value
submodel = None
# to support the traditional way of attaching a VAE # to support the traditional way of attaching a VAE
# to a model, we hacked in `attach_model_part` # to a model, we hacked in `attach_model_part`
# TODO: generalize this # TODO:
external_parts = set() if model_type == SDModelType.Vae and "vae" in mconfig:
if model_type == SDModelType.Diffusers: print("NOT_IMPLEMENTED - RETURN CUSTOM VAE")
for part in DIFFUSERS_PARTS:
with suppress(Exception):
if part_config := mconfig.get(part):
id = part_config.get('path') or part_config.get('repo_id')
subfolder = part_config.get('subfolder')
external_parts.add((part, id, subfolder))
model_context = self.cache.get_model( model_context = self.cache.get_model(
location, location,
model_type = model_type, model_type = model_type,
revision = revision, revision = revision,
subfolder = subfolder,
submodel = submodel, submodel = submodel,
attach_model_parts = external_parts,
) )
# in case we need to communicate information about this # in case we need to communicate information about this
@ -407,7 +379,6 @@ class ModelManager(object):
location = location, location = location,
revision = revision, revision = revision,
precision = self.cache.precision, precision = self.cache.precision,
subfolder = subfolder,
_cache = self.cache _cache = self.cache
) )
@ -513,18 +484,13 @@ class ModelManager(object):
model_format = stanza.get('format') model_format = stanza.get('format')
# Common Attribs # Common Attribs
status = self.cache.status(
stanza.get('weights') or stanza.get('repo_id'),
revision=stanza.get('revision'),
subfolder=stanza.get('subfolder'),
)
description = stanza.get("description", None) description = stanza.get("description", None)
models[stanza_type][model_name].update( models[stanza_type][model_name].update(
model_name=model_name, model_name=model_name,
model_type=stanza_type, model_type=stanza_type,
format=model_format, format=model_format,
description=description, description=description,
status=status.value, status="unknown", # TODO: status is no longer tracked here because models are loaded separately
) )
# Checkpoint Config Parse # Checkpoint Config Parse

View File

@ -14,7 +14,7 @@ export const receivedModels = createAppAsyncThunk(
const response = await ModelsService.listModels(); const response = await ModelsService.listModels();
const deserializedModels = reduce( const deserializedModels = reduce(
response.models, response.models['diffusers'],
(modelsAccumulator, model, modelName) => { (modelsAccumulator, model, modelName) => {
modelsAccumulator[modelName] = { ...model, name: modelName }; modelsAccumulator[modelName] = { ...model, name: modelName };
@ -23,7 +23,10 @@ export const receivedModels = createAppAsyncThunk(
{} as Record<string, Model> {} as Record<string, Model>
); );
models.info({ response }, `Received ${size(response.models)} models`); models.info(
{ response },
`Received ${size(response.models['diffusers'])} models`
);
return deserializedModels; return deserializedModels;
} }