Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit df5b968954: model manager now running as a service
Parent commit: 8ad8c5c67a
@@ -13,18 +13,20 @@ from typing import (
 from pydantic import BaseModel
 from pydantic.fields import Field


 import invokeai.backend.util.logging as logger
+import invokeai.version

 from invokeai.app.services.metadata import PngMetadataService
 from .services.default_graphs import create_system_graphs
 from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage

 from ..backend import Args
+from ..backend import Globals # this should go when pr 3340 merged

 from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers
 from .cli.completer import set_autocompleter
 from .invocations.baseinvocation import BaseInvocation
 from .services.events import EventServiceBase
-from .services.model_manager_initializer import get_model_manager
 from .services.restoration_services import RestorationServices
 from .services.graph import Edge, EdgeConnection, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
 from .services.default_graphs import default_text_to_image_graph_id
@@ -34,7 +36,7 @@ from .services.invocation_services import InvocationServices
 from .services.invoker import Invoker
 from .services.processor import DefaultInvocationProcessor
 from .services.sqlite import SqliteItemStorage
+from .services.model_manager_service import ModelManagerService

 class CliCommand(BaseModel):
     command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
@@ -191,7 +193,11 @@ def invoke_all(context: CliContext):
 def invoke_cli():
     config = Args()
     config.parse_args()
-    model_manager = get_model_manager(config,logger=logger)
+    logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
+    logger.info(f'InvokeAI runtime directory is "{Globals.root}"')

+    model_manager = ModelManagerService(config,logger)

     # This initializes the autocompleter and returns it.
     # Currently nothing is done with the returned Completer
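
For orientation, a minimal sketch of the new CLI wiring shown in the hunk above: the model manager is constructed directly as a service rather than through get_model_manager(). The absolute module paths are inferred from the relative imports in this commit and should be treated as assumptions.

# Sketch only; mirrors invoke_cli() above.
import invokeai.backend.util.logging as logger
import invokeai.version
from invokeai.backend import Args, Globals
from invokeai.app.services.model_manager_service import ModelManagerService


def build_model_manager() -> ModelManagerService:
    config = Args()
    config.parse_args()
    logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
    logger.info(f'InvokeAI runtime directory is "{Globals.root}"')
    # Device, precision and cache size are now resolved inside the service,
    # so the caller only passes the parsed Args and a logger module.
    return ModelManagerService(config, logger)
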
@@ -1,7 +1,6 @@
 from typing import Literal, Optional, Union
 from pydantic import BaseModel, Field

-from invokeai.app.invocations.util.choose_model import choose_model
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig

 from ...backend.util.devices import choose_torch_device, torch_dtype
@@ -58,74 +57,74 @@ class CompelInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> CompelOutput:

         # TODO: load without model
-        model = choose_model(context.services.model_manager, self.model)
-        pipeline = model.context.model
+        model = context.services.model_manager.get_model(self.model)
+        with model.context as pipeline:
             tokenizer = pipeline.tokenizer
             text_encoder = pipeline.text_encoder

             # TODO: global? input?
             #use_full_precision = precision == "float32" or precision == "autocast"
             #use_full_precision = False

             # TODO: redo TI when separate model loding implemented
             #textual_inversion_manager = TextualInversionManager(
             # tokenizer=tokenizer,
             # text_encoder=text_encoder,
             # full_precision=use_full_precision,
             #)

             def load_huggingface_concepts(concepts: list[str]):
                 pipeline.textual_inversion_manager.load_huggingface_concepts(concepts)

             # apply the concepts library to the prompt
             prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
                 self.prompt,
                 lambda concepts: load_huggingface_concepts(concepts),
                 pipeline.textual_inversion_manager.get_all_trigger_strings(),
             )

             # lazy-load any deferred textual inversions.
             # this might take a couple of seconds the first time a textual inversion is used.
             pipeline.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
                 prompt_str
             )

             compel = Compel(
                 tokenizer=tokenizer,
                 text_encoder=text_encoder,
                 textual_inversion_manager=pipeline.textual_inversion_manager,
                 dtype_for_device_getter=torch_dtype,
                 truncate_long_prompts=True, # TODO:
             )

             # TODO: support legacy blend?

             prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(prompt_str)

             if getattr(Globals, "log_tokenization", False):
                 log_tokenization_for_prompt_object(prompt, tokenizer)

             c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)

             # TODO: long prompt support
             #if not self.truncate_long_prompts:
             # [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])

             ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
                 tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt),
                 cross_attention_control_args=options.get("cross_attention_control", None),
             )

             conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"

             # TODO: hacky but works ;D maybe rename latents somehow?
             context.services.latents.set(conditioning_name, (c, ec))

             return CompelOutput(
                 conditioning=ConditioningField(
                     conditioning_name=conditioning_name,
                 ),
             )


 def get_max_token_count(
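
The change above replaces the choose_model() helper with a direct lookup on the model manager service and wraps use of the loaded pipeline in a context manager. A minimal sketch of that access pattern, assuming an InvocationContext as in this hunk and a tokenizer with the usual Hugging Face encode() method:

from invokeai.app.invocations.baseinvocation import InvocationContext


def tokenize_prompt(context: InvocationContext, model_name: str, prompt: str) -> list[int]:
    # Look the model up through the service; there is no silent fallback here.
    model = context.services.model_manager.get_model(model_name)
    # model.context keeps the pipeline resident in the model cache for the
    # duration of the with-block (the pattern used by CompelInvocation above).
    with model.context as pipeline:
        return pipeline.tokenizer.encode(prompt)
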
@@ -9,7 +9,6 @@ from torch import Tensor
 from pydantic import BaseModel, Field

 from invokeai.app.models.image import ColorField, ImageField, ImageType
-from invokeai.app.invocations.util.choose_model import choose_model
 from invokeai.app.util.misc import SEED_MAX, get_random_seed
 from invokeai.backend.generator.inpaint import infill_methods
 from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
@@ -72,7 +71,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):

     def invoke(self, context: InvocationContext) -> ImageOutput:
         # Handle invalid model parameter
-        model = choose_model(context.services.model_manager, self.model)
+        model = context.services.model_manager.get_model(self.model)

         # Get the source node id (we are invoking the prepared node)
         graph_execution_state = context.services.graph_execution_manager.get(
@@ -6,7 +6,6 @@ import einops
 from pydantic import BaseModel, Field
 import torch

-from invokeai.app.invocations.util.choose_model import choose_model
 from invokeai.app.util.misc import SEED_MAX, get_random_seed

 from invokeai.app.util.step_callback import stable_diffusion_step_callback
@@ -177,7 +176,7 @@ class TextToLatentsInvocation(BaseInvocation):
     )

     def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
-        model_info = choose_model(model_manager, self.model)
+        model_info = model_manager.get_model(self.model)
         model_name = model_info.name
         model_hash = model_info.hash
         model_ctx: StableDiffusionGeneratorPipeline = model_info.context
@@ -1,14 +0,0 @@
-from invokeai.backend.model_management.model_manager_service import ModelManagerService, SDModelType
-
-
-def choose_model(model_manager: ModelManagerService, model_name: str, model_type: SDModelType=SDModelType.diffusers):
-    """Returns the default model if the `model_name` not a valid model, else returns the selected model."""
-    logger = model_manager.logger
-    if model_name and not model_manager.valid_model(model_name, model_type):
-        default_model_name = model_manager.default_model()
-        logger.warning(f"\'{model_name}\' is not a valid model name. Using default model \'{default_model_name}\' instead.")
-        model = model_manager.get_model()
-    else:
-        model = model_manager.get_model(model_name, model_type)
-
-    return model
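
With this helper deleted, invocations no longer fall back silently to the default model; they call get_model() directly. If a caller still wants the old fallback behaviour, a sketch along these lines would reproduce it, using only the service methods that appear in this diff (the exact signatures on the relocated service are an assumption):

import invokeai.backend.util.logging as logger


def get_model_with_fallback(model_manager, model_name: str):
    # Reproduce the behaviour of the removed choose_model() helper:
    # fall back to the default model when the requested name is unknown.
    if model_name and not model_manager.valid_model(model_name):
        default_name = model_manager.default_model()
        logger.warning(
            f"'{model_name}' is not a valid model name. "
            f"Using default model '{default_name}' instead."
        )
        return model_manager.get_model()
    return model_manager.get_model(model_name)
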
@@ -1,211 +0,0 @@
-# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
-
-from abc import ABC, abstractmethod
-from pathlib import Path
-from typing import Union, Callable
-
-from invokeai.backend import ModelManager, SDModelType, SDModelInfo
-
-class ModelManagerServiceBase(ABC):
-    """Responsible for managing models on disk and in memory"""
-
-    @abstractmethod
-    def get(self, model_name: str, submodel: SDModelType=None)->SDModelInfo:
-        """Retrieve the indicated model. submodel can be used to get a
-        part (such as the vae) of a diffusers mode.l"""
-        pass
-
-    @abstractmethod
-    def valid_model(self, model_name: str) -> bool:
-        """
-        Given a model name, returns True if it is a valid
-        identifier.
-        """
-        pass
-
-    @abstractmethod
-    def default_model(self) -> Union[str,None]:
-        """
-        Returns the name of the default model, or None
-        if none is defined.
-        """
-        pass
-
-    @abstractmethod
-    def set_default_model(self, model_name:str):
-        """Sets the default model to the indicated name."""
-        pass
-
-    @abstractmethod
-    def model_info(self, model_name: str)->dict:
-        """
-        Given a model name returns a dict-like (OmegaConf) object describing it.
-        """
-        pass
-
-    @abstractmethod
-    def model_names(self)->list[str]:
-        """
-        Returns a list of all the model names known.
-        """
-        pass
-
-    @abstractmethod
-    def list_models(self)->dict:
-        """
-        Return a dict of models in the format:
-        { model_name1: {'status': ('active'|'cached'|'not loaded'),
-                        'description': description,
-                        'format': ('ckpt'|'diffusers'|'vae'|'text_encoder'|'tokenizer'|'lora'...),
-                        },
-          model_name2: { etc }
-        """
-        pass
-
-
-    @abstractmethod
-    def add_model(
-        self, model_name: str, model_attributes: dict, clobber: bool = False)->None:
-        """
-        Update the named model with a dictionary of attributes. Will fail with an
-        assertion error if the name already exists. Pass clobber=True to overwrite.
-        On a successful update, the config will be changed in memory. Will fail
-        with an assertion error if provided attributes are incorrect or
-        the model name is missing. Call commit() to write changes to disk.
-        """
-        pass
-
-    @abstractmethod
-    def del_model(self, model_name: str, delete_files: bool = False) -> None:
-        """
-        Delete the named model from configuration. If delete_files is true,
-        then the underlying weight file or diffusers directory will be deleted
-        as well. Call commit() to write to disk.
-        """
-        pass
-
-    @abstractmethod
-    def import_diffuser_model(
-        repo_or_path: Union[str, Path],
-        model_name: str = None,
-        description: str = None,
-        vae: dict = None,
-    ) -> bool:
-        """
-        Install the indicated diffuser model and returns True if successful.
-
-        "repo_or_path" can be either a repo-id or a path-like object corresponding to the
-        top of a downloaded diffusers directory.
-
-        You can optionally provide a model name and/or description. If not provided,
-        then these will be derived from the repo name. Call commit() to write to disk.
-        """
-        pass
-
-    @abstractmethod
-    def import_lora(
-        self,
-        path: Path,
-        model_name: str=None,
-        description: str=None,
-    ):
-        """
-        Creates an entry for the indicated lora file. Call
-        mgr.commit() to write out the configuration to models.yaml
-        """
-        pass
-
-    @abstractmethod
-    def import_embedding(
-        self,
-        path: Path,
-        model_name: str=None,
-        description: str=None,
-    ):
-        """
-        Creates an entry for the indicated textual inversion embedding file.
-        Call commit() to write out the configuration to models.yaml
-        """
-        pass
-
-    @abstractmethod
-    def heuristic_import(
-        self,
-        path_url_or_repo: str,
-        model_name: str = None,
-        description: str = None,
-        model_config_file: Path = None,
-        commit_to_conf: Path = None,
-        config_file_callback: Callable[[Path], Path] = None,
-    ) -> str:
-        """Accept a string which could be:
-        - a HF diffusers repo_id
-        - a URL pointing to a legacy .ckpt or .safetensors file
-        - a local path pointing to a legacy .ckpt or .safetensors file
-        - a local directory containing .ckpt and .safetensors files
-        - a local directory containing a diffusers model
-
-        After determining the nature of the model and downloading it
-        (if necessary), the file is probed to determine the correct
-        configuration file (if needed) and it is imported.
-
-        The model_name and/or description can be provided. If not, they will
-        be generated automatically.
-
-        If commit_to_conf is provided, the newly loaded model will be written
-        to the `models.yaml` file at the indicated path. Otherwise, the changes
-        will only remain in memory.
-
-        The routine will do its best to figure out the config file
-        needed to convert legacy checkpoint file, but if it can't it
-        will call the config_file_callback routine, if provided. The
-        callback accepts a single argument, the Path to the checkpoint
-        file, and returns a Path to the config file to use.
-
-        The (potentially derived) name of the model is returned on
-        success, or None on failure. When multiple models are added
-        from a directory, only the last imported one is returned.
-
-        """
-        pass
-
-    @abstractmethod
-    def commit(self, conf_file: Path=None) -> None:
-        """
-        Write current configuration out to the indicated file.
-        If no conf_file is provided, then replaces the
-        original file/database used to initialize the object.
-        """
-        pass
-
-# simple implementation
-class ModelManagerService(ModelManagerServiceBase):
-    """Responsible for managing models on disk and in memory"""
-    def __init__(
-        self,
-        config: Union[Path, DictConfig, str],
-        device_type: torch.device = CUDA_DEVICE,
-        precision: torch.dtype = torch.float16,
-        max_cache_size=MAX_CACHE_SIZE,
-        sequential_offload=False,
-        logger: types.ModuleType = logger,
-    ):
-        """
-        Initialize with the path to the models.yaml config file.
-        Optional parameters are the torch device type, precision, max_models,
-        and sequential_offload boolean. Note that the default device
-        type and precision are set up for a CUDA system running at half precision.
-        """
-        self.mgr = ModelManager(config=config,
-                                device_type=device_type,
-                                precision=precision,
-                                max_cache_size=max_cache_size,
-                                sequential_offload=sequential_offload,
-                                logger=logger
-                                )
-
-    def get(self, model_name: str, submodel: SDModelType=None)->SDModelInfo:
-        """Retrieve the indicated model. submodel can be used to get a
-        part (such as the vae) of a diffusers mode."""
-        self.mgr.get_model(
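
The abstract base above defines the management API that the relocated app-level service keeps (get/valid_model/default_model/list_models/add_model/del_model/import_*/heuristic_import/commit). A minimal sketch of a consumer of that API; whether the new service keeps these exact signatures is an assumption:

from pathlib import Path


def summarize_and_import(mgr, checkpoint: Path) -> str:
    # list_models() returns a dict keyed by model name, with 'status',
    # 'description' and 'format' fields per the docstring above.
    for name, info in mgr.list_models().items():
        print(f"{name}: status={info.get('status')} format={info.get('format')}")
    # heuristic_import() accepts a repo id, URL, file or directory and returns
    # the (possibly derived) model name.
    imported = mgr.heuristic_import(str(checkpoint))
    mgr.commit()  # persist the updated models.yaml
    return imported
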
@@ -1,124 +0,0 @@
-import os
-import sys
-import torch
-from argparse import Namespace
-from invokeai.backend import Args
-from omegaconf import OmegaConf
-from pathlib import Path
-from typing import types
-
-import invokeai.version
-from .model_manager_service import ModelManagerService
-from ...backend.util import choose_precision, choose_torch_device
-from ...backend import Globals
-
-# temporary function - should call ModelManagerService() directly
-def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManagerService:
-    if not config.conf:
-        config_file = os.path.join(Globals.root, "configs", "models.yaml")
-        if not os.path.exists(config_file):
-            report_model_error(
-                config, FileNotFoundError(f"The file {config_file} could not be found."), logger
-            )
-
-    logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
-    logger.info(f'InvokeAI runtime directory is "{Globals.root}"')
-
-    # these two lines prevent a horrible warning message from appearing
-    # when the frozen CLIP tokenizer is imported
-    import transformers # type: ignore
-
-    transformers.logging.set_verbosity_error()
-    import diffusers
-
-    diffusers.logging.set_verbosity_error()
-
-    # normalize the config directory relative to root
-    if not os.path.isabs(config.conf):
-        config.conf = os.path.normpath(os.path.join(Globals.root, config.conf))
-
-    if config.embeddings:
-        if not os.path.isabs(config.embedding_path):
-            embedding_path = os.path.normpath(
-                os.path.join(Globals.root, config.embedding_path)
-            )
-        else:
-            embedding_path = config.embedding_path
-    else:
-        embedding_path = None
-
-    # creating the model manager
-    try:
-        device = torch.device(choose_torch_device())
-        if config.precision=="auto":
-            precision = choose_precision(device)
-        dtype = torch.float32 if precision=='float32' \
-            else torch.float16
-
-        max_cache_size = config.max_cache_size \
-            if hasattr(config,'max_cache_size') \
-            else config.max_loaded_models * 2.5
-
-        model_manager = ModelManagerService(
-            config.conf,
-            precision=dtype,
-            device_type=device,
-            max_cache_size=config.max_cache_size,
-            # temporarily disabled until model manager stabilizes
-            # embedding_path = Path(embedding_path),
-            logger = logger,
-        )
-    except (FileNotFoundError, TypeError, AssertionError) as e:
-        report_model_error(config, e, logger)
-    except (IOError, KeyError) as e:
-        logger.error(f"{e}. Aborting.")
-        sys.exit(-1)
-
-    # try to autoconvert new models
-    # autoimport new .ckpt files
-    if path := config.autoconvert:
-        model_manager.autoconvert_weights(
-            conf_path=config.conf,
-            weights_directory=path,
-        )
-    logger.info('Model manager initialized')
-    return model_manager
-
-def report_model_error(opt: Namespace, e: Exception, logger: types.ModuleType):
-    logger.error(f'An error occurred while attempting to initialize the model: "{str(e)}"')
-    logger.error(
-        "This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
-    )
-    yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
-    if yes_to_all:
-        logger.warning(
-            "Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
-        )
-    else:
-        response = input(
-            "Do you want to run invokeai-configure script to select and/or reinstall models? [y] "
-        )
-        if response.startswith(("n", "N")):
-            return
-
-    logger.info("invokeai-configure is launching....\n")
-
-    # Match arguments that were set on the CLI
-    # only the arguments accepted by the configuration script are parsed
-    root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
-    config = ["--config", opt.conf] if opt.conf is not None else []
-    sys.argv = ["invokeai-configure"]
-    sys.argv.extend(root_dir)
-    sys.argv.extend(config.to_dict())
-    if yes_to_all is not None:
-        for arg in yes_to_all.split():
-            sys.argv.append(arg)
-
-    from invokeai.frontend.install import invokeai_configure
-
-    invokeai_configure()
-    # TODO: Figure out how to restart
-    # print('** InvokeAI will now restart')
-    # sys.argv = previous_args
-    # main() # would rather do a os.exec(), but doesn't exist?
-    # sys.exit(0)
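
The deleted initializer wrapped service construction in try/except and offered to launch invokeai-configure; as the hunk below shows, the relocated service now raises IOError itself when models.yaml cannot be found. A minimal sketch of how a caller might keep comparable error reporting, assuming only the constructor and the exception visible in this diff:

import sys

import invokeai.backend.util.logging as logger
from invokeai.backend import Args
from invokeai.app.services.model_manager_service import ModelManagerService


def safe_model_manager(config: Args) -> ModelManagerService:
    try:
        return ModelManagerService(config, logger)
    except IOError as e:
        # Previously report_model_error() offered to run invokeai-configure here.
        logger.error(f"{e} Run invokeai-configure to (re)install models.")
        sys.exit(-1)
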
@@ -3,23 +3,35 @@

 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Union, Callable
+from typing import Union, Callable, types

-from invokeai.backend.util import CUDA_DEVICE
 from invokeai.backend.model_management.model_manager import (
     ModelManager,
     SDModelType,
     SDModelInfo,
-    DictConfig,
-    MAX_CACHE_SIZE,
     types,
     torch,
-    logger,
 )
+from ...backend import Args,Globals # this must go when pr 3340 merged
+from ...backend.util import choose_precision, choose_torch_device

 class ModelManagerServiceBase(ABC):
     """Responsible for managing models on disk and in memory"""

+    @abstractmethod
+    def __init__(
+        self,
+        config: Args,
+        logger: types.ModuleType
+    ):
+        """
+        Initialize with the path to the models.yaml config file.
+        Optional parameters are the torch device type, precision, max_models,
+        and sequential_offload boolean. Note that the default device
+        type and precision are set up for a CUDA system running at half precision.
+        """
+        pass

     @abstractmethod
     def get_model(self,
                   model_name: str,
@@ -207,26 +219,50 @@ class ModelManagerService(ModelManagerServiceBase):
     """Responsible for managing models on disk and in memory"""
     def __init__(
         self,
-        config: Union[Path, DictConfig, str],
-        device_type: torch.device = CUDA_DEVICE,
-        precision: torch.dtype = torch.float16,
-        max_cache_size=MAX_CACHE_SIZE,
-        sequential_offload=False,
-        logger: types.ModuleType = logger,
-    ):
+        config: Args,
+        logger: types.ModuleType
+    ):
         """
         Initialize with the path to the models.yaml config file.
         Optional parameters are the torch device type, precision, max_models,
         and sequential_offload boolean. Note that the default device
         type and precision are set up for a CUDA system running at half precision.
         """
-        self.mgr = ModelManager(config=config,
-                                device_type=device_type,
-                                precision=precision,
+        if config.conf and Path(config.conf).exists():
+            config_file = config.conf
+        else:
+            config_file = Path(Globals.root, "configs", "models.yaml")
+            if not config_file.exists():
+                raise IOError(f"The file {config_file} could not be found.")
+
+        logger.debug(f'config file={config_file}')
+
+        device = torch.device(choose_torch_device())
+        if config.precision=="auto":
+            precision = choose_precision(device)
+        dtype = torch.float32 if precision=='float32' \
+            else torch.float16
+
+        # this is transitional backward compatibility
+        # support for the deprecated `max_loaded_models`
+        # configuration value. If present, then the
+        # cache size is set to 2.5 GB times
+        # the number of max_loaded_models. Otherwise
+        # use new `max_cache_size` config setting
+        max_cache_size = config.max_cache_size \
+            if hasattr(config,'max_cache_size') \
+            else config.max_loaded_models * 2.5
+
+        sequential_offload = config.sequential_guidance
+
+        self.mgr = ModelManager(config=config_file,
+                                device_type=device,
+                                precision=dtype,
                                 max_cache_size=max_cache_size,
                                 sequential_offload=sequential_offload,
                                 logger=logger
                                 )
+        logger.info('Model manager service initialized')

     def get_model(self,
                   model_name: str,
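
The constructor above now derives device, dtype and cache size from the parsed Args instead of taking them as parameters. A standalone sketch of that resolution logic, assuming the choose_torch_device/choose_precision helpers imported earlier in this file; unlike the snippet above, the sketch also keeps an explicit precision when config.precision is not "auto":

import torch

from invokeai.backend.util import choose_precision, choose_torch_device


def resolve_runtime_settings(config):
    # Pick the torch device first, then derive the working dtype from it.
    device = torch.device(choose_torch_device())
    precision = config.precision
    if precision == "auto":
        precision = choose_precision(device)
    dtype = torch.float32 if precision == "float32" else torch.float16
    # Transitional: honour the deprecated max_loaded_models setting by
    # budgeting roughly 2.5 GB of cache per loaded model.
    max_cache_size = (
        config.max_cache_size
        if hasattr(config, "max_cache_size")
        else config.max_loaded_models * 2.5
    )
    return device, dtype, max_cache_size
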
@@ -209,7 +209,7 @@ class ModelManager(object):
         if isinstance(config, DictConfig):
             self.config = config
             self.config_path = None
-        elif type(config) in [str,DictConfig]:
+        elif isinstance(config,(str,Path)):
             self.config_path = config
             self.config = OmegaConf.load(self.config_path)
         else:
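
With this fix, ModelManager accepts either an in-memory DictConfig or a str/Path pointing at models.yaml. A short sketch of both call styles; the example config contents and the reliance on constructor defaults for the remaining arguments are assumptions:

from pathlib import Path

from omegaconf import OmegaConf
from invokeai.backend import ModelManager

# From a YAML file on disk; str and Path now both take the OmegaConf.load path.
mgr_from_file = ModelManager(config=Path("configs/models.yaml"))

# From an already-parsed DictConfig, e.g. for tests (illustrative entry only).
conf = OmegaConf.create({"stable-diffusion-1.5": {"format": "diffusers"}})
mgr_from_dict = ModelManager(config=conf)
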
@@ -38,7 +38,7 @@ dependencies = [
   "albumentations",
   "click",
   "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
-  "compel~=1.1.5",
+  "compel~=1.0.5",
   "datasets",
   "diffusers[torch]~=0.16.1",
   "dnspython==2.2.1",