mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
restore printing of version when invokeai-web and invokeai called with --version
@@ -213,6 +213,10 @@ def invoke_api():
 
     from invokeai.backend.install.check_root import check_invokeai_root
 
+    if app_config.version:
+        print(f"InvokeAI version {__version__}")
+        return
+
     check_invokeai_root(app_config)  # note, may exit with an exception if root not set up
 
     if app_config.dev_reload:
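
Note on the hunk above: the restored version check runs before check_invokeai_root(), so `invokeai-web --version` can answer even when no root directory has been set up. A minimal standalone sketch of the same early-return pattern (hypothetical names, not the actual api_app module):

    import sys

    __version__ = "0.0.0"  # stand-in for invokeai.version.__version__

    def main() -> None:
        # answer --version before any heavy initialization or root checks
        if "--version" in sys.argv:
            print(f"InvokeAI version {__version__}")
            return
        ...  # normal startup would continue here
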
@@ -63,7 +63,10 @@ def add_parsers(
     for command in commands:
         hints = get_type_hints(command)
         cmd_name = get_args(hints[command_field])[0]
-        command_parser = subparsers.add_parser(cmd_name, help=command.__doc__)
+        try:
+            command_parser = subparsers.add_parser(cmd_name, help=command.__doc__)
+        except argparse.ArgumentError:
+            continue
 
         if add_arguments is not None:
             add_arguments(command_parser)
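
The try/except above guards against registering the same subcommand twice: argparse raises argparse.ArgumentError for a conflicting subparser name. A self-contained demonstration of that behavior:

    import argparse

    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers()
    subparsers.add_parser("draw")
    try:
        subparsers.add_parser("draw")  # same name again -> ArgumentError
    except argparse.ArgumentError as err:
        print(f"skipping duplicate command: {err}")
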
@@ -10,8 +10,7 @@ from pathlib import Path
 from typing import Dict, List, Literal, get_args, get_origin, get_type_hints
 
 import invokeai.backend.util.logging as logger
-from ...backend import ModelManager, ModelType
+from invokeai.backend.model_manager import ModelType
 from ..invocations.baseinvocation import BaseInvocation
 from ..services.invocation_services import InvocationServices
 from ..services.model_record_service import ModelRecordServiceBase
@@ -22,10 +22,12 @@ if True:  # hack to make flake8 happy with imports coming after setting up the c
     from pydantic.fields import Field
 
     import invokeai.backend.util.hotfixes  # noqa: F401 (monkeypatching on import)
+    from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
     from invokeai.app.services.board_image_record_storage import SqliteBoardImageRecordStorage
     from invokeai.app.services.board_images import BoardImagesService, BoardImagesServiceDependencies
     from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
     from invokeai.app.services.boards import BoardService, BoardServiceDependencies
+    from invokeai.app.services.session_processor.session_processor_default import DefaultSessionProcessor
     from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
     from invokeai.app.services.images import ImageService, ImageServiceDependencies
     from invokeai.app.services.invocation_stats import InvocationStatsService
@@ -47,6 +49,7 @@ if True:  # hack to make flake8 happy with imports coming after setting up the c
         LibraryGraph,
         are_connection_types_compatible,
     )
+    from .services.thread import lock
     from .services.image_file_storage import DiskImageFileStorage
     from .services.invocation_queue import MemoryInvocationQueue
     from .services.invocation_services import InvocationServices
@@ -230,7 +233,12 @@ def invoke_all(context: CliContext):
 
 
 def invoke_cli():
+    if config.version:
+        print(f"InvokeAI version {__version__}")
+        return
+
     logger.info(f"InvokeAI version {__version__}")
 
     # get the optional list of invocations to execute on the command line
     parser = config.get_parser()
     parser.add_argument("commands", nargs="*")
@@ -255,18 +263,18 @@ def invoke_cli():
     logger.info(f'InvokeAI database location is "{db_location}"')
 
     model_record_store = ModelRecordServiceBase.get_impl(config, conn=db_conn, lock=None)
-    model_loader = ModelLoadService(config, model_record_store, events)
+    model_loader = ModelLoadService(config, model_record_store)
     model_installer = ModelInstallService(config, model_record_store, events)
 
-    graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
+    graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions", lock=lock)
 
     urls = LocalUrlService()
-    image_record_storage = SqliteImageRecordStorage(conn=db_conn)
+    image_record_storage = SqliteImageRecordStorage(conn=db_conn, lock=lock)
     image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
     names = SimpleNameService()
 
-    board_record_storage = SqliteBoardRecordStorage(conn=db_conn)
-    board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn)
+    board_record_storage = SqliteBoardRecordStorage(conn=db_conn, lock=lock)
+    board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn, lock=lock)
 
     boards = BoardService(
         services=BoardServiceDependencies(
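
The recurring `lock=lock` argument threads one shared lock through every service that reuses the single db_conn. Presumably this is because a sqlite3 connection shared across threads must be serialized by the caller; a sketch of that idea, with illustrative names:

    import sqlite3
    import threading

    lock = threading.RLock()
    conn = sqlite3.connect("invokeai.db", check_same_thread=False)

    def execute(sql: str, params: tuple = ()) -> None:
        # every service that shares `conn` must also share `lock`
        with lock:
            conn.execute(sql, params)
            conn.commit()
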
@@ -307,7 +315,7 @@ def invoke_cli():
         boards=boards,
         board_images=board_images,
         queue=MemoryInvocationQueue(),
-        graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
+        graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs", lock=lock),
         graph_execution_manager=graph_execution_manager,
         processor=DefaultInvocationProcessor(),
         performance_statistics=InvocationStatsService(graph_execution_manager),
@@ -317,6 +325,8 @@ def invoke_cli():
         model_installer=model_installer,
         configuration=config,
         invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size),
+        session_queue=SqliteSessionQueue(conn=db_conn, lock=lock),
+        session_processor=DefaultSessionProcessor(),
     )
 
     system_graphs = create_system_graphs(services.graph_library)
@@ -484,7 +494,4 @@ def invoke_cli():
 
 
 if __name__ == "__main__":
-    if config.version:
-        print(f"InvokeAI version {__version__}")
-    else:
-        invoke_cli()
+    invoke_cli()
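
Moving the version check out of the `__main__` guard and into invoke_cli() itself (see the earlier hunk) matters because the `invokeai-node-cli` console script declared in pyproject.toml calls invoke_cli() directly and never executes the `__main__` block. A runnable sketch with stand-in objects showing how both entry routes now take the same path:

    from types import SimpleNamespace

    __version__ = "0.0.0"                    # stand-in
    config = SimpleNamespace(version=True)   # stand-in for the parsed app config

    def invoke_cli() -> None:
        if config.version:                   # now honored for console-script users too
            print(f"InvokeAI version {__version__}")
            return
        print("normal CLI startup would continue here")

    if __name__ == "__main__":
        invoke_cli()                         # same function the console script calls
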
@@ -426,7 +426,6 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
         processed_image = zoe_depth_processor(image)
         return processed_image
 
-
 @invocation(
     "mediapipe_face_processor",
     title="Mediapipe Face Processor",
@@ -3,13 +3,15 @@ Migrate the models directory and models.yaml file from an existing
 InvokeAI 2.3 installation to 3.0.0.
 """
 
+#### NOTE: THIS SCRIPT NO LONGER WORKS WITH REFACTORED MODEL MANAGER, AND WILL NOT BE UPDATED.
+
 import argparse
 import os
 import shutil
 import warnings
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Union
+from typing import Union, Optional
 
 import diffusers
 import transformers
@@ -22,6 +24,7 @@ from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel,
 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.app.services.model_record_service import ModelRecordServiceBase
+from invokeai.app.services.model_install_service import ModelInstallService
 from invokeai.backend.model_manager import BaseModelType, ModelProbe, ModelProbeInfo, ModelType
 
 warnings.filterwarnings("ignore")
@@ -43,19 +46,14 @@ class MigrateTo3(object):
         self,
         from_root: Path,
         to_models: Path,
-        model_manager: ModelRecordServiceBase,
+        installer: ModelInstallService,
         src_paths: ModelPaths,
     ):
         self.root_directory = from_root
         self.dest_models = to_models
-        self.mgr = model_manager
+        self.installer = installer
         self.src_paths = src_paths
 
-    @classmethod
-    def initialize_yaml(cls, yaml_file: Path):
-        with open(yaml_file, "w") as file:
-            file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
-
     def create_directory_structure(self):
         """
         Create the basic directory structure for the models folder.
@@ -107,44 +105,10 @@ class MigrateTo3(object):
         Recursively walk through src directory, probe anything
         that looks like a model, and copy the model into the
         appropriate location within the destination models directory.
+
+        This is now trivially easy using the installer service.
         """
-        directories_scanned = set()
-        for root, dirs, files in os.walk(src_dir, followlinks=True):
-            for d in dirs:
-                try:
-                    model = Path(root, d)
-                    info = ModelProbe().heuristic_probe(model)
-                    if not info:
-                        continue
-                    dest = self._model_probe_to_path(info) / model.name
-                    self.copy_dir(model, dest)
-                    directories_scanned.add(model)
-                except Exception as e:
-                    logger.error(str(e))
-                except KeyboardInterrupt:
-                    raise
-                except Exception as e:
-                    logger.error(str(e))
-            for f in files:
-                # don't copy raw learned_embeds.bin or pytorch_lora_weights.bin
-                # let them be copied as part of a tree copy operation
-                try:
-                    if f in {"learned_embeds.bin", "pytorch_lora_weights.bin"}:
-                        continue
-                    model = Path(root, f)
-                    if model.parent in directories_scanned:
-                        continue
-                    info = ModelProbe().heuristic_probe(model)
-                    if not info:
-                        continue
-                    dest = self._model_probe_to_path(info) / f
-                    self.copy_file(model, dest)
-                except Exception as e:
-                    logger.error(str(e))
-                except KeyboardInterrupt:
-                    raise
-                except Exception as e:
-                    logger.error(str(e))
+        self.installer.scan_directory(src_dir)
 
     def migrate_support_models(self):
         """
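
The removed os.walk loop (note that its second `except Exception`, placed after `except KeyboardInterrupt`, could never fire) is collapsed into one installer call. The rough shape of what scan_directory takes over, as a sketch under the assumption that it probes and installs anything model-like it finds:

    from pathlib import Path
    from typing import Callable, Optional

    def scan_directory(src_dir: Path,
                       probe: Callable[[Path], Optional[object]],
                       install: Callable[[Path], None]) -> None:
        # walk everything once; the probe decides what counts as a model
        for path in src_dir.rglob("*"):
            if probe(path) is not None:
                install(path)
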
@@ -260,23 +224,21 @@ class MigrateTo3(object):
         model.save_pretrained(download_path, safe_serialization=True)
         download_path.replace(dest)
 
-    def _download_vae(self, repo_id: str, subfolder: str = None) -> Path:
-        vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / "models/hub", subfolder=subfolder)
-        info = ModelProbe().heuristic_probe(vae)
-        _, model_name = repo_id.split("/")
-        dest = self._model_probe_to_path(info) / self.unique_name(model_name, info)
-        vae.save_pretrained(dest, safe_serialization=True)
-        return dest
+    def _download_vae(self, repo_id: str, subfolder: str = None) -> Optional[Path]:
+        self.installer.install(repo_id)  # bug! We don't support subfolder yet.
+        ids = self.installer.wait_for_installs()
+        if key := ids.get(repo_id):
+            return self.installer.store.get_model(key).path
+        else:
+            return None
 
-    def _vae_path(self, vae: Union[str, dict]) -> Path:
-        """
-        Convert 2.3 VAE stanza to a straight path.
-        """
-        vae_path = None
+    def _vae_path(self, vae: Union[str, dict]) -> Optional[Path]:
+        """Convert 2.3 VAE stanza to a straight path."""
+        vae_path: Optional[Path] = None
 
         # First get a path
         if isinstance(vae, str):
-            vae_path = vae
+            vae_path = Path(vae)
 
         elif isinstance(vae, DictConfig):
             if p := vae.get("path"):
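
The rewritten _download_vae assumes wait_for_installs() returns a mapping from install source to model key; the walrus operator then folds the lookup and the None-check into one expression. The pattern in isolation:

    # illustrative mapping such as wait_for_installs() is assumed to return
    ids = {"stabilityai/sd-vae-ft-mse": "key-123"}

    repo_id = "stabilityai/sd-vae-ft-mse"
    if key := ids.get(repo_id):        # assignment + truthiness test in one step
        print(f"installed as {key}")
    else:
        print("no record for this source")
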
@@ -284,28 +246,21 @@ class MigrateTo3(object):
             elif repo_id := vae.get("repo_id"):
                 if repo_id == "stabilityai/sd-vae-ft-mse":  # this guy is already downloaded
                     vae_path = "models/core/convert/sd-vae-ft-mse"
-                    return vae_path
+                    return Path(vae_path)
                 else:
                     vae_path = self._download_vae(repo_id, vae.get("subfolder"))
 
-        assert vae_path is not None, "Couldn't find VAE for this model"
+        if vae_path is None:
+            return None
 
         # if the VAE is in the old models directory, then we must move it into the new
         # one. VAEs outside of this directory can stay where they are.
-        vae_path = Path(vae_path)
         if vae_path.is_relative_to(self.src_paths.models):
-            info = ModelProbe().heuristic_probe(vae_path)
-            dest = self._model_probe_to_path(info) / vae_path.name
-            if not dest.exists():
-                if vae_path.is_dir():
-                    self.copy_dir(vae_path, dest)
-                else:
-                    self.copy_file(vae_path, dest)
-            vae_path = dest
-
-        if vae_path.is_relative_to(self.dest_models):
-            rel_path = vae_path.relative_to(self.dest_models)
-            return Path("models", rel_path)
+            key = self.installer.install_path(vae_path)  # this will move the model
+            return self.installer.store.get_model(key).path
+        elif vae_path.is_relative_to(self.dest_models):
+            key = self.installer.register_path(vae_path)  # this will keep the model in place
+            return self.installer.store.get_model(key).path
         else:
             return vae_path
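
The new _vae_path branches on where the VAE already lives: install_path() moves models found under the old 2.3 models tree, register_path() records in-place models already under the destination tree, and anything elsewhere is returned untouched. The Path.is_relative_to() test (Python 3.9+) that drives the decision, with hypothetical locations:

    from pathlib import Path

    vae_path = Path("/invokeai-2.3/models/vae")
    src_models = Path("/invokeai-2.3/models")
    dest_models = Path("/invokeai-3.0/models")

    print(vae_path.is_relative_to(src_models))   # True  -> move via install_path()
    print(vae_path.is_relative_to(dest_models))  # False -> register_path() branch not taken
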
@@ -505,40 +460,24 @@ def do_migrate(config: InvokeAIAppConfig, src_directory: Path, dest_directory: P
     """
     Migrate models from src to dest InvokeAI root directories
     """
-    config_file = dest_directory / "configs" / "models.yaml.3"
     dest_models = dest_directory / "models.3"
+    mm_store = ModelRecordServiceBase.get_impl(config)
+    mm_install = ModelInstallService(config=config, store=mm_store)
 
     version_3 = (dest_directory / "models" / "core").exists()
-
-    # Here we create the destination models.yaml file.
-    # If we are writing into a version 3 directory and the
-    # file already exists, then we write into a copy of it to
-    # avoid deleting its previous customizations. Otherwise we
-    # create a new empty one.
-    if version_3:  # write into the dest directory
-        try:
-            shutil.copy(dest_directory / "configs" / "models.yaml", config_file)
-        except Exception:
-            MigrateTo3.initialize_yaml(config_file)
-        mgr = ModelRecordServiceBase.get_impl(config)
-        (dest_directory / "models").replace(dest_models)
-    else:
-        MigrateTo3.initialize_yaml(config_file)
-        mgr = ModelManager(config_file)
+    if not version_3:
+        src_directory = (dest_directory / "models").replace(src_directory / "models.orig")
+        print(f"Original models directory moved to {dest_directory}/models.orig")
 
     paths = get_legacy_embeddings(src_directory)
-    migrator = MigrateTo3(from_root=src_directory, to_models=dest_models, model_manager=mgr, src_paths=paths)
+    migrator = MigrateTo3(from_root=src_directory, to_models=dest_models, installer=mm_install, src_paths=paths)
     migrator.migrate()
     print("Migration successful.")
 
-    if not version_3:
-        (dest_directory / "models").replace(src_directory / "models.orig")
-        print(f"Original models directory moved to {dest_directory}/models.orig")
-
     (dest_directory / "configs" / "models.yaml").replace(src_directory / "configs" / "models.yaml.orig")
     print(f"Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig")
 
-    config_file.replace(config_file.with_suffix(""))
     dest_models.replace(dest_models.with_suffix(""))
 
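
do_migrate leans on Path.replace(), which renames atomically within a filesystem and (since Python 3.8) returns the new path; that return value is what lets the result be assigned back to src_directory in the hunk above. A minimal example:

    from pathlib import Path

    src = Path("models")
    src.mkdir(exist_ok=True)           # set up a directory to rename
    backup = src.replace("models.orig")
    print(backup)                      # models.orig -- the returned new Path
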
@@ -24,5 +24,4 @@ from .storage import (  # noqa F401
     ModelConfigStoreSQL,
     ModelConfigStoreYAML,
     UnknownModelException,
-    get_config_store,
 )
@@ -10,12 +10,13 @@ from typing import Optional, Tuple, Union
 import torch
 
 from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.app.services.model_record_service import ModelRecordServiceBase
 from invokeai.backend.util import InvokeAILogger, Logger, choose_precision, choose_torch_device
 
 from .cache import CacheStats, ModelCache
 from .config import BaseModelType, ModelConfigBase, ModelType, SubModelType
 from .models import MODEL_CLASSES, InvalidModelException, ModelBase
-from .storage import ModelConfigStore, get_config_store
+from .storage import ModelConfigStore
 
 
 @dataclass
@@ -112,7 +113,7 @@ class ModelLoad(ModelLoadBase):
         :param config: The app's InvokeAIAppConfig object.
         """
         self._app_config = config
-        self._store = store or get_config_store(config.root_path / config.model_config_db)
+        self._store = store or ModelRecordServiceBase.get_impl(config)
         self._logger = InvokeAILogger.get_logger()
         self._cache_keys = dict()
         device = torch.device(choose_torch_device())
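
Together with the removal of get_config_store() in the next hunk, this change funnels all store construction through ModelRecordServiceBase.get_impl(config): the app config, rather than a file suffix derived by each caller, decides the backend. A generic sketch of this factory-consolidation pattern (class and attribute names illustrative, not InvokeAI's real API):

    class RecordStore:
        @classmethod
        def get_impl(cls, config) -> "RecordStore":
            # the config, not the caller, decides which backend to build
            db = str(getattr(config, "model_config_db", "models.db"))
            return YamlStore(db) if db.endswith(".yaml") else SqlStore(db)

    class YamlStore(RecordStore):
        def __init__(self, location: str):
            self.location = location

    class SqlStore(RecordStore):
        def __init__(self, location: str):
            self.location = location
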
@@ -11,16 +11,3 @@ from .base import (  # noqa F401
 from .migrate import migrate_models_store  # noqa F401
 from .sql import ModelConfigStoreSQL  # noqa F401
 from .yaml import ModelConfigStoreYAML  # noqa F401
-
-
-def get_config_store(location: pathlib.Path) -> ModelConfigStore:
-    """Return the type of ModelConfigStore appropriate to the path."""
-    location = pathlib.Path(location)
-    if location.suffix == ".yaml":
-        return ModelConfigStoreYAML(location)
-    elif location.suffix == ".db":
-        return ModelConfigStoreSQL(location)
-    else:
-        raise Exception(
-            f"Unable to determine type of configuration file '{location}'. Type 'auto' is not supported outside the app."
-        )
@@ -142,7 +142,6 @@ dependencies = [
 "invokeai-merge" = "invokeai.frontend.merge:invokeai_merge_diffusers"
 "invokeai-ti" = "invokeai.frontend.training:invokeai_textual_inversion"
 "invokeai-model-install" = "invokeai.frontend.install.model_install:main"
-"invokeai-migrate3" = "invokeai.backend.install.migrate_to_3:main"
 "invokeai-update" = "invokeai.frontend.install.invokeai_update:main"
 "invokeai-metadata" = "invokeai.backend.image_util.invoke_metadata:main"
 "invokeai-node-cli" = "invokeai.app.cli_app:invoke_cli"