mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Model manager draft
This commit is contained in:
parent
e971a7f35c
commit
fd82763412
@ -10,7 +10,7 @@ from .generator import (
|
|||||||
Img2Img,
|
Img2Img,
|
||||||
Inpaint
|
Inpaint
|
||||||
)
|
)
|
||||||
from .model_management import ModelManager, ModelCache, ModelStatus, SDModelType, SDModelInfo
|
from .model_management import ModelManager, ModelCache, SDModelType, SDModelInfo
|
||||||
from .safety_checker import SafetyChecker
|
from .safety_checker import SafetyChecker
|
||||||
from .args import Args
|
from .args import Args
|
||||||
from .globals import Globals
|
from .globals import Globals
|
||||||
|
@ -2,4 +2,4 @@
|
|||||||
Initialization file for invokeai.backend.model_management
|
Initialization file for invokeai.backend.model_management
|
||||||
"""
|
"""
|
||||||
from .model_manager import ModelManager, SDModelInfo
|
from .model_manager import ModelManager, SDModelInfo
|
||||||
from .model_cache import ModelCache, ModelStatus, SDModelType
|
from .model_cache import ModelCache, SDModelType
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -15,8 +15,6 @@ return a SDModelInfo object that contains the following attributes:
|
|||||||
* revision -- revision of the model if coming from a repo id,
|
* revision -- revision of the model if coming from a repo id,
|
||||||
e.g. 'fp16'
|
e.g. 'fp16'
|
||||||
* precision -- torch precision of the model
|
* precision -- torch precision of the model
|
||||||
* status -- a ModelStatus enum corresponding to one of
|
|
||||||
'not_loaded', 'in_ram', 'in_vram' or 'active'
|
|
||||||
|
|
||||||
Typical usage:
|
Typical usage:
|
||||||
|
|
||||||
@ -157,8 +155,8 @@ from invokeai.backend.globals import (Globals, global_cache_dir,
|
|||||||
from invokeai.backend.util import download_with_resume
|
from invokeai.backend.util import download_with_resume
|
||||||
|
|
||||||
from ..util import CUDA_DEVICE
|
from ..util import CUDA_DEVICE
|
||||||
from .model_cache import (ModelCache, ModelLocker, ModelStatus, SDModelType,
|
from .model_cache import (ModelCache, ModelLocker, SDModelType,
|
||||||
SilenceWarnings, DIFFUSERS_PARTS)
|
SilenceWarnings)
|
||||||
|
|
||||||
# We are only starting to number the config file with release 3.
|
# We are only starting to number the config file with release 3.
|
||||||
# The config file version doesn't have to start at release version, but it will help
|
# The config file version doesn't have to start at release version, but it will help
|
||||||
@ -174,7 +172,6 @@ class SDModelInfo():
|
|||||||
hash: str
|
hash: str
|
||||||
location: Union[Path,str]
|
location: Union[Path,str]
|
||||||
precision: torch.dtype
|
precision: torch.dtype
|
||||||
subfolder: Path = None
|
|
||||||
revision: str = None
|
revision: str = None
|
||||||
_cache: ModelCache = None
|
_cache: ModelCache = None
|
||||||
|
|
||||||
@ -183,17 +180,6 @@ class SDModelInfo():
|
|||||||
|
|
||||||
def __exit__(self,*args, **kwargs):
|
def __exit__(self,*args, **kwargs):
|
||||||
self.context.__exit__(*args, **kwargs)
|
self.context.__exit__(*args, **kwargs)
|
||||||
|
|
||||||
@property
|
|
||||||
def status(self)->ModelStatus:
|
|
||||||
'''Return load status of this model as a model_cache.ModelStatus enum'''
|
|
||||||
if not self._cache:
|
|
||||||
return ModelStatus.unknown
|
|
||||||
return self._cache.status(
|
|
||||||
self.location,
|
|
||||||
revision = self.revision,
|
|
||||||
subfolder = self.subfolder
|
|
||||||
)
|
|
||||||
|
|
||||||
class InvalidModelError(Exception):
|
class InvalidModelError(Exception):
|
||||||
"Raised when an invalid model is requested"
|
"Raised when an invalid model is requested"
|
||||||
@ -355,7 +341,6 @@ class ModelManager(object):
|
|||||||
or global_resolve_path(mconfig.get('weights')
|
or global_resolve_path(mconfig.get('weights')
|
||||||
)
|
)
|
||||||
|
|
||||||
subfolder = mconfig.get('subfolder')
|
|
||||||
revision = mconfig.get('revision')
|
revision = mconfig.get('revision')
|
||||||
hash = self.cache.model_hash(location, revision)
|
hash = self.cache.model_hash(location, revision)
|
||||||
|
|
||||||
@ -367,31 +352,18 @@ class ModelManager(object):
|
|||||||
model_type = submodel
|
model_type = submodel
|
||||||
submodel = None
|
submodel = None
|
||||||
|
|
||||||
# We don't need to load whole model if the user is asking for just a piece of it
|
|
||||||
elif model_type == SDModelType.Diffusers and submodel and not subfolder:
|
|
||||||
model_type = submodel
|
|
||||||
subfolder = submodel.value
|
|
||||||
submodel = None
|
|
||||||
|
|
||||||
# to support the traditional way of attaching a VAE
|
# to support the traditional way of attaching a VAE
|
||||||
# to a model, we hacked in `attach_model_part`
|
# to a model, we hacked in `attach_model_part`
|
||||||
# TODO: generalize this
|
# TODO:
|
||||||
external_parts = set()
|
if model_type == SDModelType.Vae and "vae" in mconfig:
|
||||||
if model_type == SDModelType.Diffusers:
|
print("NOT_IMPLEMENTED - RETURN CUSTOM VAE")
|
||||||
for part in DIFFUSERS_PARTS:
|
|
||||||
with suppress(Exception):
|
|
||||||
if part_config := mconfig.get(part):
|
|
||||||
id = part_config.get('path') or part_config.get('repo_id')
|
|
||||||
subfolder = part_config.get('subfolder')
|
|
||||||
external_parts.add((part, id, subfolder))
|
|
||||||
|
|
||||||
model_context = self.cache.get_model(
|
model_context = self.cache.get_model(
|
||||||
location,
|
location,
|
||||||
model_type = model_type,
|
model_type = model_type,
|
||||||
revision = revision,
|
revision = revision,
|
||||||
subfolder = subfolder,
|
|
||||||
submodel = submodel,
|
submodel = submodel,
|
||||||
attach_model_parts = external_parts,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# in case we need to communicate information about this
|
# in case we need to communicate information about this
|
||||||
@ -407,7 +379,6 @@ class ModelManager(object):
|
|||||||
location = location,
|
location = location,
|
||||||
revision = revision,
|
revision = revision,
|
||||||
precision = self.cache.precision,
|
precision = self.cache.precision,
|
||||||
subfolder = subfolder,
|
|
||||||
_cache = self.cache
|
_cache = self.cache
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -513,18 +484,13 @@ class ModelManager(object):
|
|||||||
model_format = stanza.get('format')
|
model_format = stanza.get('format')
|
||||||
|
|
||||||
# Common Attribs
|
# Common Attribs
|
||||||
status = self.cache.status(
|
|
||||||
stanza.get('weights') or stanza.get('repo_id'),
|
|
||||||
revision=stanza.get('revision'),
|
|
||||||
subfolder=stanza.get('subfolder'),
|
|
||||||
)
|
|
||||||
description = stanza.get("description", None)
|
description = stanza.get("description", None)
|
||||||
models[stanza_type][model_name].update(
|
models[stanza_type][model_name].update(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
model_type=stanza_type,
|
model_type=stanza_type,
|
||||||
format=model_format,
|
format=model_format,
|
||||||
description=description,
|
description=description,
|
||||||
status=status.value,
|
status="unknown", # TODO: no more status as model loaded separately
|
||||||
)
|
)
|
||||||
|
|
||||||
# Checkpoint Config Parse
|
# Checkpoint Config Parse
|
||||||
|
@ -14,7 +14,7 @@ export const receivedModels = createAppAsyncThunk(
|
|||||||
const response = await ModelsService.listModels();
|
const response = await ModelsService.listModels();
|
||||||
|
|
||||||
const deserializedModels = reduce(
|
const deserializedModels = reduce(
|
||||||
response.models,
|
response.models['diffusers'],
|
||||||
(modelsAccumulator, model, modelName) => {
|
(modelsAccumulator, model, modelName) => {
|
||||||
modelsAccumulator[modelName] = { ...model, name: modelName };
|
modelsAccumulator[modelName] = { ...model, name: modelName };
|
||||||
|
|
||||||
@ -23,7 +23,10 @@ export const receivedModels = createAppAsyncThunk(
|
|||||||
{} as Record<string, Model>
|
{} as Record<string, Model>
|
||||||
);
|
);
|
||||||
|
|
||||||
models.info({ response }, `Received ${size(response.models)} models`);
|
models.info(
|
||||||
|
{ response },
|
||||||
|
`Received ${size(response.models['diffusers'])} models`
|
||||||
|
);
|
||||||
|
|
||||||
return deserializedModels;
|
return deserializedModels;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user