Model manager draft

commit fd82763412
parent e971a7f35c
@@ -10,7 +10,7 @@ from .generator import (
     Img2Img,
     Inpaint
 )
-from .model_management import ModelManager, ModelCache, ModelStatus, SDModelType, SDModelInfo
+from .model_management import ModelManager, ModelCache, SDModelType, SDModelInfo
 from .safety_checker import SafetyChecker
 from .args import Args
 from .globals import Globals
@@ -2,4 +2,4 @@
 Initialization file for invokeai.backend.model_management
 """
 from .model_manager import ModelManager, SDModelInfo
-from .model_cache import ModelCache, ModelStatus, SDModelType
+from .model_cache import ModelCache, SDModelType
File diff suppressed because it is too large
@@ -15,8 +15,6 @@ return a SDModelInfo object that contains the following attributes:
 * revision -- revision of the model if coming from a repo id,
               e.g. 'fp16'
 * precision -- torch precision of the model
-* status -- a ModelStatus enum corresponding to one of
-            'not_loaded', 'in_ram', 'in_vram' or 'active'

 Typical usage:

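The hunk ends just where the docstring's usage example begins, so the example itself is not shown. For orientation, a minimal sketch of what such usage might look like after this change; the config path, model name, and the manager-level get_model method name are assumptions, not code from this commit:

    from invokeai.backend.model_management import ModelManager

    # assumed config location; the manager reads the models.yaml registry
    manager = ModelManager('configs/models.yaml')

    # returns the SDModelInfo described in the docstring above
    # (method name assumed; this commit only shows the method's internals)
    model_info = manager.get_model('stable-diffusion-1.5')
    print(model_info.precision, model_info.location)

    # SDModelInfo delegates __enter__/__exit__ to the cache context
    # (see the SDModelInfo hunk further down), so it works as a context manager
    with model_info as model:
        ...  # use the loaded model while it is locked in memory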
@@ -157,8 +155,8 @@ from invokeai.backend.globals import (Globals, global_cache_dir,
 from invokeai.backend.util import download_with_resume

 from ..util import CUDA_DEVICE
-from .model_cache import (ModelCache, ModelLocker, ModelStatus, SDModelType,
-                          SilenceWarnings, DIFFUSERS_PARTS)
+from .model_cache import (ModelCache, ModelLocker, SDModelType,
+                          SilenceWarnings)

 # We are only starting to number the config file with release 3.
 # The config file version doesn't have to start at release version, but it will help
@@ -174,7 +172,6 @@ class SDModelInfo():
     hash: str
     location: Union[Path,str]
     precision: torch.dtype
     subfolder: Path = None
     revision: str = None
     _cache: ModelCache = None

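To make the field list concrete, here is a self-contained stand-in mirroring only the attributes visible in this hunk; the real dataclass carries more state, and all values below are purely illustrative:

    from dataclasses import dataclass
    from pathlib import Path
    from typing import Union

    import torch

    @dataclass
    class SDModelInfoSketch:
        # illustrative stand-in for the fields shown in the hunk above
        hash: str
        location: Union[Path, str]
        precision: torch.dtype
        subfolder: Path = None
        revision: str = None

    info = SDModelInfoSketch(
        hash='abc123',                              # illustrative hash value
        location='runwayml/stable-diffusion-v1-5',  # repo id or local path
        precision=torch.float16,
    )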
@@ -183,17 +180,6 @@ class SDModelInfo():

     def __exit__(self,*args, **kwargs):
         self.context.__exit__(*args, **kwargs)

-    @property
-    def status(self)->ModelStatus:
-        '''Return load status of this model as a model_cache.ModelStatus enum'''
-        if not self._cache:
-            return ModelStatus.unknown
-        return self._cache.status(
-            self.location,
-            revision = self.revision,
-            subfolder = self.subfolder
-        )
-
 class InvalidModelError(Exception):
     "Raised when an invalid model is requested"
@@ -355,7 +341,6 @@ class ModelManager(object):
                 or global_resolve_path(mconfig.get('weights')
             )

         subfolder = mconfig.get('subfolder')
         revision = mconfig.get('revision')
         hash = self.cache.model_hash(location, revision)

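The mconfig read here is one stanza of the models.yaml registry. A sketch of the keys this code consults, written as the equivalent Python mapping; the keys match the mconfig.get() and stanza.get() calls in this diff, while every value is illustrative:

    # illustrative stanza; keys match the lookups in the surrounding hunks
    mconfig = {
        'weights': 'models/ldm/stable-diffusion-v1/v1-5-pruned.ckpt',  # or 'repo_id'
        'subfolder': None,
        'revision': 'fp16',
        'format': 'diffusers',
        'description': 'Stable Diffusion 1.5',
    }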
@@ -367,31 +352,18 @@ class ModelManager(object):
             model_type = submodel
             submodel = None

         # We don't need to load whole model if the user is asking for just a piece of it
         elif model_type == SDModelType.Diffusers and submodel and not subfolder:
             model_type = submodel
             subfolder = submodel.value
             submodel = None

-        # to support the traditional way of attaching a VAE
-        # to a model, we hacked in `attach_model_part`
-        # TODO: generalize this
-        external_parts = set()
-        if model_type == SDModelType.Diffusers:
-            for part in DIFFUSERS_PARTS:
-                with suppress(Exception):
-                    if part_config := mconfig.get(part):
-                        id = part_config.get('path') or part_config.get('repo_id')
-                        subfolder = part_config.get('subfolder')
-                        external_parts.add((part, id, subfolder))
+        # TODO:
+        if model_type == SDModelType.Vae and "vae" in mconfig:
+            print("NOT_IMPLEMENTED - RETURN CUSTOM VAE")

         model_context = self.cache.get_model(
             location,
             model_type = model_type,
             revision = revision,
             subfolder = subfolder,
             submodel = submodel,
-            attach_model_parts = external_parts,
         )

         # in case we need to communicate information about this
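With attach_model_parts dropped, the cache call narrows to the arguments shown. A hedged sketch of the call and of how the returned context might be used; the with-statement protocol is inferred from the ModelLocker import and the __enter__/__exit__ delegation above, and all argument values are illustrative:

    # 'cache' is a ModelCache instance obtained elsewhere (construction omitted)
    model_context = cache.get_model(
        'runwayml/stable-diffusion-v1-5',   # location: repo id or local path
        model_type=SDModelType.Diffusers,
        revision='fp16',
        subfolder=None,
        submodel=None,
    )
    with model_context as model:
        # model stays locked in memory for the duration of the block
        run_inference(model)                # hypothetical caller function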
@@ -407,7 +379,6 @@ class ModelManager(object):
             location = location,
             revision = revision,
             precision = self.cache.precision,
             subfolder = subfolder,
             _cache = self.cache
         )

@@ -513,18 +484,13 @@ class ModelManager(object):
             model_format = stanza.get('format')

             # Common Attribs
-            status = self.cache.status(
-                stanza.get('weights') or stanza.get('repo_id'),
-                revision=stanza.get('revision'),
-                subfolder=stanza.get('subfolder'),
-            )
             description = stanza.get("description", None)
             models[stanza_type][model_name].update(
                 model_name=model_name,
                 model_type=stanza_type,
                 format=model_format,
                 description=description,
-                status=status.value,
+                status="unknown", # TODO: no more status as model loaded separately
             )

             # Checkpoint Config Parse
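Since the per-model cache query is gone, every listing entry now reports status 'unknown'. A sketch of the record one stanza contributes after the update() call above; the keys come from the diff, while the model name and description are illustrative:

    models = {
        'diffusers': {
            'stable-diffusion-1.5': {
                'model_name': 'stable-diffusion-1.5',
                'model_type': 'diffusers',
                'format': 'diffusers',
                'description': 'Stable Diffusion 1.5',
                'status': 'unknown',  # no live status; models are loaded separately now
            }
        }
    }

This nesting by stanza type is also what the frontend hunks below index into via response.models['diffusers'].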
@@ -14,7 +14,7 @@ export const receivedModels = createAppAsyncThunk(
     const response = await ModelsService.listModels();

     const deserializedModels = reduce(
-      response.models,
+      response.models['diffusers'],
       (modelsAccumulator, model, modelName) => {
         modelsAccumulator[modelName] = { ...model, name: modelName };

@@ -23,7 +23,10 @@
       {} as Record<string, Model>
     );

-    models.info({ response }, `Received ${size(response.models)} models`);
+    models.info(
+      { response },
+      `Received ${size(response.models['diffusers'])} models`
+    );

     return deserializedModels;
   }