mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
model manager now running as a service
This commit is contained in:
parent
8ad8c5c67a
commit
df5b968954
@ -13,18 +13,20 @@ from typing import (
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import Field
|
||||
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
import invokeai.version
|
||||
|
||||
from invokeai.app.services.metadata import PngMetadataService
|
||||
from .services.default_graphs import create_system_graphs
|
||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
from ..backend import Args
|
||||
from ..backend import Globals # this should go when pr 3340 merged
|
||||
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers
|
||||
from .cli.completer import set_autocompleter
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.events import EventServiceBase
|
||||
from .services.model_manager_initializer import get_model_manager
|
||||
from .services.restoration_services import RestorationServices
|
||||
from .services.graph import Edge, EdgeConnection, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
|
||||
from .services.default_graphs import default_text_to_image_graph_id
|
||||
@ -34,7 +36,7 @@ from .services.invocation_services import InvocationServices
|
||||
from .services.invoker import Invoker
|
||||
from .services.processor import DefaultInvocationProcessor
|
||||
from .services.sqlite import SqliteItemStorage
|
||||
|
||||
from .services.model_manager_service import ModelManagerService
|
||||
|
||||
class CliCommand(BaseModel):
|
||||
command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
|
||||
@ -191,7 +193,11 @@ def invoke_all(context: CliContext):
|
||||
def invoke_cli():
|
||||
config = Args()
|
||||
config.parse_args()
|
||||
model_manager = get_model_manager(config,logger=logger)
|
||||
|
||||
logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
|
||||
logger.info(f'InvokeAI runtime directory is "{Globals.root}"')
|
||||
|
||||
model_manager = ModelManagerService(config,logger)
|
||||
|
||||
# This initializes the autocompleter and returns it.
|
||||
# Currently nothing is done with the returned Completer
|
||||
|
@ -1,7 +1,6 @@
|
||||
from typing import Literal, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
@ -58,74 +57,74 @@ class CompelInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
|
||||
# TODO: load without model
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
pipeline = model.context.model
|
||||
tokenizer = pipeline.tokenizer
|
||||
text_encoder = pipeline.text_encoder
|
||||
model = context.services.model_manager.get_model(self.model)
|
||||
with model.context as pipeline:
|
||||
tokenizer = pipeline.tokenizer
|
||||
text_encoder = pipeline.text_encoder
|
||||
|
||||
# TODO: global? input?
|
||||
#use_full_precision = precision == "float32" or precision == "autocast"
|
||||
#use_full_precision = False
|
||||
# TODO: global? input?
|
||||
#use_full_precision = precision == "float32" or precision == "autocast"
|
||||
#use_full_precision = False
|
||||
|
||||
# TODO: redo TI when separate model loding implemented
|
||||
#textual_inversion_manager = TextualInversionManager(
|
||||
# tokenizer=tokenizer,
|
||||
# text_encoder=text_encoder,
|
||||
# full_precision=use_full_precision,
|
||||
#)
|
||||
# TODO: redo TI when separate model loding implemented
|
||||
#textual_inversion_manager = TextualInversionManager(
|
||||
# tokenizer=tokenizer,
|
||||
# text_encoder=text_encoder,
|
||||
# full_precision=use_full_precision,
|
||||
#)
|
||||
|
||||
def load_huggingface_concepts(concepts: list[str]):
|
||||
pipeline.textual_inversion_manager.load_huggingface_concepts(concepts)
|
||||
def load_huggingface_concepts(concepts: list[str]):
|
||||
pipeline.textual_inversion_manager.load_huggingface_concepts(concepts)
|
||||
|
||||
# apply the concepts library to the prompt
|
||||
prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
|
||||
self.prompt,
|
||||
lambda concepts: load_huggingface_concepts(concepts),
|
||||
pipeline.textual_inversion_manager.get_all_trigger_strings(),
|
||||
)
|
||||
# apply the concepts library to the prompt
|
||||
prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
|
||||
self.prompt,
|
||||
lambda concepts: load_huggingface_concepts(concepts),
|
||||
pipeline.textual_inversion_manager.get_all_trigger_strings(),
|
||||
)
|
||||
|
||||
# lazy-load any deferred textual inversions.
|
||||
# this might take a couple of seconds the first time a textual inversion is used.
|
||||
pipeline.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
|
||||
prompt_str
|
||||
)
|
||||
# lazy-load any deferred textual inversions.
|
||||
# this might take a couple of seconds the first time a textual inversion is used.
|
||||
pipeline.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
|
||||
prompt_str
|
||||
)
|
||||
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
textual_inversion_manager=pipeline.textual_inversion_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=True, # TODO:
|
||||
)
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
textual_inversion_manager=pipeline.textual_inversion_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=True, # TODO:
|
||||
)
|
||||
|
||||
# TODO: support legacy blend?
|
||||
# TODO: support legacy blend?
|
||||
|
||||
prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(prompt_str)
|
||||
prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(prompt_str)
|
||||
|
||||
if getattr(Globals, "log_tokenization", False):
|
||||
log_tokenization_for_prompt_object(prompt, tokenizer)
|
||||
if getattr(Globals, "log_tokenization", False):
|
||||
log_tokenization_for_prompt_object(prompt, tokenizer)
|
||||
|
||||
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
|
||||
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
|
||||
|
||||
# TODO: long prompt support
|
||||
#if not self.truncate_long_prompts:
|
||||
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
||||
# TODO: long prompt support
|
||||
#if not self.truncate_long_prompts:
|
||||
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
||||
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt),
|
||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
||||
)
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt),
|
||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
||||
)
|
||||
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
|
||||
# TODO: hacky but works ;D maybe rename latents somehow?
|
||||
context.services.latents.set(conditioning_name, (c, ec))
|
||||
# TODO: hacky but works ;D maybe rename latents somehow?
|
||||
context.services.latents.set(conditioning_name, (c, ec))
|
||||
|
||||
return CompelOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
return CompelOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_max_token_count(
|
||||
|
@ -9,7 +9,6 @@ from torch import Tensor
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.models.image import ColorField, ImageField, ImageType
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.backend.generator.inpaint import infill_methods
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
@ -72,7 +71,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Handle invalid model parameter
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
model = context.services.model_manager.get_model(self.model)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
|
@ -6,7 +6,6 @@ import einops
|
||||
from pydantic import BaseModel, Field
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
@ -177,7 +176,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
|
||||
model_info = choose_model(model_manager, self.model)
|
||||
model_info = model_manager.get_model(self.model)
|
||||
model_name = model_info.name
|
||||
model_hash = model_info.hash
|
||||
model_ctx: StableDiffusionGeneratorPipeline = model_info.context
|
||||
|
@ -1,14 +0,0 @@
|
||||
from invokeai.backend.model_management.model_manager_service import ModelManagerService, SDModelType
|
||||
|
||||
|
||||
def choose_model(model_manager: ModelManagerService, model_name: str, model_type: SDModelType=SDModelType.diffusers):
|
||||
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
|
||||
logger = model_manager.logger
|
||||
if model_name and not model_manager.valid_model(model_name, model_type):
|
||||
default_model_name = model_manager.default_model()
|
||||
logger.warning(f"\'{model_name}\' is not a valid model name. Using default model \'{default_model_name}\' instead.")
|
||||
model = model_manager.get_model()
|
||||
else:
|
||||
model = model_manager.get_model(model_name, model_type)
|
||||
|
||||
return model
|
@ -1,211 +0,0 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Union, Callable
|
||||
|
||||
from invokeai.backend import ModelManager, SDModelType, SDModelInfo
|
||||
|
||||
class ModelManagerServiceBase(ABC):
|
||||
"""Responsible for managing models on disk and in memory"""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, model_name: str, submodel: SDModelType=None)->SDModelInfo:
|
||||
"""Retrieve the indicated model. submodel can be used to get a
|
||||
part (such as the vae) of a diffusers mode.l"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def valid_model(self, model_name: str) -> bool:
|
||||
"""
|
||||
Given a model name, returns True if it is a valid
|
||||
identifier.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def default_model(self) -> Union[str,None]:
|
||||
"""
|
||||
Returns the name of the default model, or None
|
||||
if none is defined.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_default_model(self, model_name:str):
|
||||
"""Sets the default model to the indicated name."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_info(self, model_name: str)->dict:
|
||||
"""
|
||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_names(self)->list[str]:
|
||||
"""
|
||||
Returns a list of all the model names known.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_models(self)->dict:
|
||||
"""
|
||||
Return a dict of models in the format:
|
||||
{ model_name1: {'status': ('active'|'cached'|'not loaded'),
|
||||
'description': description,
|
||||
'format': ('ckpt'|'diffusers'|'vae'|'text_encoder'|'tokenizer'|'lora'...),
|
||||
},
|
||||
model_name2: { etc }
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def add_model(
|
||||
self, model_name: str, model_attributes: dict, clobber: bool = False)->None:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def del_model(self, model_name: str, delete_files: bool = False) -> None:
|
||||
"""
|
||||
Delete the named model from configuration. If delete_files is true,
|
||||
then the underlying weight file or diffusers directory will be deleted
|
||||
as well. Call commit() to write to disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def import_diffuser_model(
|
||||
repo_or_path: Union[str, Path],
|
||||
model_name: str = None,
|
||||
description: str = None,
|
||||
vae: dict = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Install the indicated diffuser model and returns True if successful.
|
||||
|
||||
"repo_or_path" can be either a repo-id or a path-like object corresponding to the
|
||||
top of a downloaded diffusers directory.
|
||||
|
||||
You can optionally provide a model name and/or description. If not provided,
|
||||
then these will be derived from the repo name. Call commit() to write to disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def import_lora(
|
||||
self,
|
||||
path: Path,
|
||||
model_name: str=None,
|
||||
description: str=None,
|
||||
):
|
||||
"""
|
||||
Creates an entry for the indicated lora file. Call
|
||||
mgr.commit() to write out the configuration to models.yaml
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def import_embedding(
|
||||
self,
|
||||
path: Path,
|
||||
model_name: str=None,
|
||||
description: str=None,
|
||||
):
|
||||
"""
|
||||
Creates an entry for the indicated textual inversion embedding file.
|
||||
Call commit() to write out the configuration to models.yaml
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def heuristic_import(
|
||||
self,
|
||||
path_url_or_repo: str,
|
||||
model_name: str = None,
|
||||
description: str = None,
|
||||
model_config_file: Path = None,
|
||||
commit_to_conf: Path = None,
|
||||
config_file_callback: Callable[[Path], Path] = None,
|
||||
) -> str:
|
||||
"""Accept a string which could be:
|
||||
- a HF diffusers repo_id
|
||||
- a URL pointing to a legacy .ckpt or .safetensors file
|
||||
- a local path pointing to a legacy .ckpt or .safetensors file
|
||||
- a local directory containing .ckpt and .safetensors files
|
||||
- a local directory containing a diffusers model
|
||||
|
||||
After determining the nature of the model and downloading it
|
||||
(if necessary), the file is probed to determine the correct
|
||||
configuration file (if needed) and it is imported.
|
||||
|
||||
The model_name and/or description can be provided. If not, they will
|
||||
be generated automatically.
|
||||
|
||||
If commit_to_conf is provided, the newly loaded model will be written
|
||||
to the `models.yaml` file at the indicated path. Otherwise, the changes
|
||||
will only remain in memory.
|
||||
|
||||
The routine will do its best to figure out the config file
|
||||
needed to convert legacy checkpoint file, but if it can't it
|
||||
will call the config_file_callback routine, if provided. The
|
||||
callback accepts a single argument, the Path to the checkpoint
|
||||
file, and returns a Path to the config file to use.
|
||||
|
||||
The (potentially derived) name of the model is returned on
|
||||
success, or None on failure. When multiple models are added
|
||||
from a directory, only the last imported one is returned.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def commit(self, conf_file: Path=None) -> None:
|
||||
"""
|
||||
Write current configuration out to the indicated file.
|
||||
If no conf_file is provided, then replaces the
|
||||
original file/database used to initialize the object.
|
||||
"""
|
||||
pass
|
||||
|
||||
# simple implementation
|
||||
class ModelManagerService(ModelManagerServiceBase):
|
||||
"""Responsible for managing models on disk and in memory"""
|
||||
def __init__(
|
||||
self,
|
||||
config: Union[Path, DictConfig, str],
|
||||
device_type: torch.device = CUDA_DEVICE,
|
||||
precision: torch.dtype = torch.float16,
|
||||
max_cache_size=MAX_CACHE_SIZE,
|
||||
sequential_offload=False,
|
||||
logger: types.ModuleType = logger,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file.
|
||||
Optional parameters are the torch device type, precision, max_models,
|
||||
and sequential_offload boolean. Note that the default device
|
||||
type and precision are set up for a CUDA system running at half precision.
|
||||
"""
|
||||
self.mgr = ModelManager(config=config,
|
||||
device_type=device_type,
|
||||
precision=precision,
|
||||
max_cache_size=max_cache_size,
|
||||
sequential_offload=sequential_offload,
|
||||
logger=logger
|
||||
)
|
||||
|
||||
def get(self, model_name: str, submodel: SDModelType=None)->SDModelInfo:
|
||||
"""Retrieve the indicated model. submodel can be used to get a
|
||||
part (such as the vae) of a diffusers mode."""
|
||||
self.mgr.get_model(
|
||||
|
@ -1,124 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
from argparse import Namespace
|
||||
from invokeai.backend import Args
|
||||
from omegaconf import OmegaConf
|
||||
from pathlib import Path
|
||||
from typing import types
|
||||
|
||||
import invokeai.version
|
||||
from .model_manager_service import ModelManagerService
|
||||
from ...backend.util import choose_precision, choose_torch_device
|
||||
from ...backend import Globals
|
||||
|
||||
# temporary function - should call ModelManagerService() directly
|
||||
def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManagerService:
|
||||
if not config.conf:
|
||||
config_file = os.path.join(Globals.root, "configs", "models.yaml")
|
||||
if not os.path.exists(config_file):
|
||||
report_model_error(
|
||||
config, FileNotFoundError(f"The file {config_file} could not be found."), logger
|
||||
)
|
||||
|
||||
logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
|
||||
logger.info(f'InvokeAI runtime directory is "{Globals.root}"')
|
||||
|
||||
# these two lines prevent a horrible warning message from appearing
|
||||
# when the frozen CLIP tokenizer is imported
|
||||
import transformers # type: ignore
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
import diffusers
|
||||
|
||||
diffusers.logging.set_verbosity_error()
|
||||
|
||||
# normalize the config directory relative to root
|
||||
if not os.path.isabs(config.conf):
|
||||
config.conf = os.path.normpath(os.path.join(Globals.root, config.conf))
|
||||
|
||||
if config.embeddings:
|
||||
if not os.path.isabs(config.embedding_path):
|
||||
embedding_path = os.path.normpath(
|
||||
os.path.join(Globals.root, config.embedding_path)
|
||||
)
|
||||
else:
|
||||
embedding_path = config.embedding_path
|
||||
else:
|
||||
embedding_path = None
|
||||
|
||||
# creating the model manager
|
||||
try:
|
||||
device = torch.device(choose_torch_device())
|
||||
if config.precision=="auto":
|
||||
precision = choose_precision(device)
|
||||
dtype = torch.float32 if precision=='float32' \
|
||||
else torch.float16
|
||||
|
||||
max_cache_size = config.max_cache_size \
|
||||
if hasattr(config,'max_cache_size') \
|
||||
else config.max_loaded_models * 2.5
|
||||
|
||||
model_manager = ModelManagerService(
|
||||
config.conf,
|
||||
precision=dtype,
|
||||
device_type=device,
|
||||
max_cache_size=config.max_cache_size,
|
||||
# temporarily disabled until model manager stabilizes
|
||||
# embedding_path = Path(embedding_path),
|
||||
logger = logger,
|
||||
)
|
||||
except (FileNotFoundError, TypeError, AssertionError) as e:
|
||||
report_model_error(config, e, logger)
|
||||
except (IOError, KeyError) as e:
|
||||
logger.error(f"{e}. Aborting.")
|
||||
sys.exit(-1)
|
||||
|
||||
# try to autoconvert new models
|
||||
# autoimport new .ckpt files
|
||||
if path := config.autoconvert:
|
||||
model_manager.autoconvert_weights(
|
||||
conf_path=config.conf,
|
||||
weights_directory=path,
|
||||
)
|
||||
logger.info('Model manager initialized')
|
||||
return model_manager
|
||||
|
||||
def report_model_error(opt: Namespace, e: Exception, logger: types.ModuleType):
|
||||
logger.error(f'An error occurred while attempting to initialize the model: "{str(e)}"')
|
||||
logger.error(
|
||||
"This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
|
||||
)
|
||||
yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
|
||||
if yes_to_all:
|
||||
logger.warning(
|
||||
"Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
|
||||
)
|
||||
else:
|
||||
response = input(
|
||||
"Do you want to run invokeai-configure script to select and/or reinstall models? [y] "
|
||||
)
|
||||
if response.startswith(("n", "N")):
|
||||
return
|
||||
|
||||
logger.info("invokeai-configure is launching....\n")
|
||||
|
||||
# Match arguments that were set on the CLI
|
||||
# only the arguments accepted by the configuration script are parsed
|
||||
root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
|
||||
config = ["--config", opt.conf] if opt.conf is not None else []
|
||||
sys.argv = ["invokeai-configure"]
|
||||
sys.argv.extend(root_dir)
|
||||
sys.argv.extend(config.to_dict())
|
||||
if yes_to_all is not None:
|
||||
for arg in yes_to_all.split():
|
||||
sys.argv.append(arg)
|
||||
|
||||
from invokeai.frontend.install import invokeai_configure
|
||||
|
||||
invokeai_configure()
|
||||
# TODO: Figure out how to restart
|
||||
# print('** InvokeAI will now restart')
|
||||
# sys.argv = previous_args
|
||||
# main() # would rather do a os.exec(), but doesn't exist?
|
||||
# sys.exit(0)
|
@ -3,23 +3,35 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Union, Callable
|
||||
from typing import Union, Callable, types
|
||||
|
||||
from invokeai.backend.util import CUDA_DEVICE
|
||||
from invokeai.backend.model_management.model_manager import (
|
||||
ModelManager,
|
||||
SDModelType,
|
||||
SDModelInfo,
|
||||
DictConfig,
|
||||
MAX_CACHE_SIZE,
|
||||
types,
|
||||
torch,
|
||||
logger,
|
||||
)
|
||||
from ...backend import Args,Globals # this must go when pr 3340 merged
|
||||
from ...backend.util import choose_precision, choose_torch_device
|
||||
|
||||
class ModelManagerServiceBase(ABC):
|
||||
"""Responsible for managing models on disk and in memory"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
config: Args,
|
||||
logger: types.ModuleType
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file.
|
||||
Optional parameters are the torch device type, precision, max_models,
|
||||
and sequential_offload boolean. Note that the default device
|
||||
type and precision are set up for a CUDA system running at half precision.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model(self,
|
||||
model_name: str,
|
||||
@ -207,26 +219,50 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
"""Responsible for managing models on disk and in memory"""
|
||||
def __init__(
|
||||
self,
|
||||
config: Union[Path, DictConfig, str],
|
||||
device_type: torch.device = CUDA_DEVICE,
|
||||
precision: torch.dtype = torch.float16,
|
||||
max_cache_size=MAX_CACHE_SIZE,
|
||||
sequential_offload=False,
|
||||
logger: types.ModuleType = logger,
|
||||
):
|
||||
config: Args,
|
||||
logger: types.ModuleType
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file.
|
||||
Optional parameters are the torch device type, precision, max_models,
|
||||
and sequential_offload boolean. Note that the default device
|
||||
type and precision are set up for a CUDA system running at half precision.
|
||||
"""
|
||||
self.mgr = ModelManager(config=config,
|
||||
device_type=device_type,
|
||||
precision=precision,
|
||||
if config.conf and Path(config.conf).exists():
|
||||
config_file = config.conf
|
||||
else:
|
||||
config_file = Path(Globals.root, "configs", "models.yaml")
|
||||
if not config_file.exists():
|
||||
raise IOError(f"The file {config_file} could not be found.")
|
||||
|
||||
logger.debug(f'config file={config_file}')
|
||||
|
||||
device = torch.device(choose_torch_device())
|
||||
if config.precision=="auto":
|
||||
precision = choose_precision(device)
|
||||
dtype = torch.float32 if precision=='float32' \
|
||||
else torch.float16
|
||||
|
||||
# this is transitional backward compatibility
|
||||
# support for the deprecated `max_loaded_models`
|
||||
# configuration value. If present, then the
|
||||
# cache size is set to 2.5 GB times
|
||||
# the number of max_loaded_models. Otherwise
|
||||
# use new `max_cache_size` config setting
|
||||
max_cache_size = config.max_cache_size \
|
||||
if hasattr(config,'max_cache_size') \
|
||||
else config.max_loaded_models * 2.5
|
||||
|
||||
sequential_offload = config.sequential_guidance
|
||||
|
||||
self.mgr = ModelManager(config=config_file,
|
||||
device_type=device,
|
||||
precision=dtype,
|
||||
max_cache_size=max_cache_size,
|
||||
sequential_offload=sequential_offload,
|
||||
logger=logger
|
||||
)
|
||||
logger.info('Model manager service initialized')
|
||||
|
||||
def get_model(self,
|
||||
model_name: str,
|
||||
|
@ -209,7 +209,7 @@ class ModelManager(object):
|
||||
if isinstance(config, DictConfig):
|
||||
self.config = config
|
||||
self.config_path = None
|
||||
elif type(config) in [str,DictConfig]:
|
||||
elif isinstance(config,(str,Path)):
|
||||
self.config_path = config
|
||||
self.config = OmegaConf.load(self.config_path)
|
||||
else:
|
||||
|
@ -38,7 +38,7 @@ dependencies = [
|
||||
"albumentations",
|
||||
"click",
|
||||
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
||||
"compel~=1.1.5",
|
||||
"compel~=1.0.5",
|
||||
"datasets",
|
||||
"diffusers[torch]~=0.16.1",
|
||||
"dnspython==2.2.1",
|
||||
|
Loading…
Reference in New Issue
Block a user