mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor model_manager_service.py into small functional modules
This commit is contained in:
@ -5,7 +5,8 @@ from typing import Any, Optional
|
||||
from invokeai.app.models.image import ProgressImage
|
||||
from invokeai.app.services.session_queue.session_queue_common import EnqueueBatchResult, SessionQueueItem
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.backend.model_manager import ModelInfo, SubModelType
|
||||
from invokeai.backend.model_manager import SubModelType
|
||||
from invokeai.backend.model_manager.loader import ModelInfo
|
||||
from invokeai.backend.model_manager.download import DownloadJobBase
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
|
449
invokeai/app/services/model_install_service.py
Normal file
449
invokeai/app/services/model_install_service.py
Normal file
@ -0,0 +1,449 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union, Set, Literal
|
||||
|
||||
from pydantic import Field, parse_obj_as
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
from invokeai.backend import get_precision
|
||||
from invokeai.backend.model_manager.install import ModelInstallBase, ModelInstall, ModelInstallJob
|
||||
from invokeai.backend.model_manager import (
|
||||
ModelConfigBase,
|
||||
ModelSearch,
|
||||
)
|
||||
from invokeai.backend.model_manager.storage import ModelConfigStore
|
||||
from invokeai.backend.model_manager.download import DownloadJobBase
|
||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
||||
from invokeai.backend.util.logging import InvokeAILogger, Logger
|
||||
|
||||
from .config import InvokeAIAppConfig
|
||||
from .events import EventServiceBase
|
||||
from .model_record_service import ModelRecordServiceBase
|
||||
|
||||
|
||||
class ModelInstallServiceBase(ABC):
|
||||
"""Responsible for downloading, installing and deleting models."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self,
|
||||
config: InvokeAIAppConfig,
|
||||
store: Union[ModelConfigStore, ModelRecordServiceBase],
|
||||
event_bus: Optional[EventServiceBase] = None
|
||||
):
|
||||
"""
|
||||
Initialize a ModelInstallService instance.
|
||||
|
||||
:param config: InvokeAIAppConfig object
|
||||
:param store: Either a ModelRecordServiceBase object or a ModelConfigStore
|
||||
:param event_bus: Optional EventServiceBase object. If provided,
|
||||
|
||||
Installation and download events will be sent to the event bus as "model_event".
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def convert_model(
|
||||
self,
|
||||
key: str,
|
||||
convert_dest_directory: Path,
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder.
|
||||
|
||||
This will delete the cached version if there is any and delete the original
|
||||
checkpoint file if it is in the models directory.
|
||||
:param key: Unique key for the model to convert.
|
||||
:param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)
|
||||
|
||||
This will raise a ValueError unless the model is not a checkpoint. It will
|
||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||
directory already in place.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def install_model(
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
priority: int = 10,
|
||||
model_attributes: Optional[Dict[str, Any]] = None,
|
||||
) -> ModelInstallJob:
|
||||
"""Import a path, repo_id or URL. Returns an ModelInstallJob.
|
||||
|
||||
:param model_attributes: Additional attributes to supplement/override
|
||||
the model information gained from automated probing.
|
||||
:param priority: Queue priority. Lower values have higher priority.
|
||||
|
||||
Typical usage:
|
||||
job = model_manager.install(
|
||||
'stabilityai/stable-diffusion-2-1',
|
||||
model_attributes={'prediction_type": 'v-prediction'}
|
||||
)
|
||||
|
||||
The result is an ModelInstallJob object, which provides
|
||||
information on the asynchronous model download and install
|
||||
process. A series of "install_model_event" events will be emitted
|
||||
until the install is completed, cancelled or errors out.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_install_jobs(self) -> List[ModelInstallJob]:
|
||||
"""Return a series of active or enqueued ModelInstallJobs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def id_to_job(self, id: int) -> ModelInstallJob:
|
||||
"""Return the ModelInstallJob instance corresponding to the given job ID."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
|
||||
"""
|
||||
Wait for all pending installs to complete.
|
||||
|
||||
This will block until all pending downloads have
|
||||
completed, been cancelled, or errored out. It will
|
||||
block indefinitely if one or more jobs are in the
|
||||
paused state.
|
||||
|
||||
It will return a dict that maps the source model
|
||||
path, URL or repo_id to the ID of the installed model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start_job(self, job_id: int):
|
||||
"""Start the given install job if it is paused or idle."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pause_job(self, job_id: int):
|
||||
"""Pause the given install job if it is paused or idle."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_job(self, job_id: int):
|
||||
"""Cancel the given install job."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_all_jobs(self):
|
||||
"""Cancel all active jobs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def prune_jobs(self):
|
||||
"""Remove completed or errored install jobs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def change_job_priority(self, job_id: int, delta: int):
|
||||
"""
|
||||
Change an install job's priority.
|
||||
|
||||
:param job_id: Job to change
|
||||
:param delta: Value to increment or decrement priority.
|
||||
|
||||
Lower values are higher priority. The default starting value is 10.
|
||||
Thus to make this a really high priority job:
|
||||
manager.change_job_priority(-10).
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def merge_models(
|
||||
self,
|
||||
model_keys: List[str] = Field(
|
||||
default=None, min_items=2, max_items=3, description="List of model keys to merge"
|
||||
),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
|
||||
:param model_keys: List of 2-3 model unique keys to merge
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_checkpoint_configs(self) -> List[Path]:
|
||||
"""List the checkpoint config paths from ROOT/configs/stable-diffusion."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_for_models(self, directory: Path) -> Set[Path]:
|
||||
"""Return list of all models found in the designated directory."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def sync_to_config(self):
|
||||
"""
|
||||
Synchronize the in-memory models with on-disk.
|
||||
|
||||
Re-read models.yaml, rescan the models directory, and reimport models
|
||||
in the autoimport directories. Call after making changes outside the
|
||||
model manager API.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# implementation
|
||||
class ModelInstallService(ModelInstallServiceBase):
|
||||
"""Responsible for managing models on disk and in memory."""
|
||||
|
||||
_installer: ModelInstallBase = Field(description="ModelInstall object for the current process")
|
||||
_config: InvokeAIAppConfig = Field(description="App configuration object")
|
||||
_precision: Literal['float16', 'float32'] = Field(description="Floating point precision, string form")
|
||||
_event_bus: Optional[EventServiceBase] = Field(description="an event bus to send install events to", default=None)
|
||||
_logger: Logger = Field(description="logger module")
|
||||
|
||||
def __init__(self,
|
||||
config: InvokeAIAppConfig,
|
||||
store: Union[ModelConfigStore, ModelRecordServiceBase],
|
||||
event_bus: Optional[EventServiceBase] = None
|
||||
):
|
||||
"""
|
||||
Initialize a ModelInstallService instance.
|
||||
|
||||
:param config: InvokeAIAppConfig object
|
||||
:param store: Either a ModelRecordService object or a ModelConfigStore
|
||||
:param event_bus: Optional EventServiceBase object. If provided,
|
||||
|
||||
Installation and download events will be sent to the event bus as "model_event".
|
||||
"""
|
||||
self._event_bus = event_bus
|
||||
self._config = config
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if self._event_bus:
|
||||
kwargs.update(event_handlers=[self._event_bus.emit_model_event])
|
||||
self._precision = get_precision()
|
||||
self._installer = ModelInstall(store, config, **kwargs)
|
||||
self._logger = InvokeAILogger.get_logger()
|
||||
|
||||
def start(self, invoker: Any): # Because .processor is giving circular import errors, declaring invoker an 'Any'
|
||||
"""Call automatically at process start."""
|
||||
self._installer.scan_models_directory() # synchronize new/deleted models found in models directory
|
||||
|
||||
def install_model(
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
priority: int = 10,
|
||||
model_attributes: Optional[Dict[str, Any]] = None,
|
||||
) -> ModelInstallJob:
|
||||
"""
|
||||
Add a model using a path, repo_id or URL.
|
||||
|
||||
:param model_attributes: Dictionary of ModelConfigBase fields to
|
||||
attach to the model. When installing a URL or repo_id, some metadata
|
||||
values, such as `tags` will be automagically added.
|
||||
:param priority: Queue priority for this install job. Lower value jobs
|
||||
will run before higher value ones.
|
||||
"""
|
||||
self.logger.debug(f"add model {source}")
|
||||
variant = "fp16" if self._precision == "float16" else None
|
||||
job = self._installer.install(
|
||||
source,
|
||||
probe_override=model_attributes,
|
||||
variant=variant,
|
||||
priority=priority,
|
||||
)
|
||||
assert isinstance(job, ModelInstallJob)
|
||||
return job
|
||||
|
||||
def list_install_jobs(self) -> List[ModelInstallJob]:
|
||||
"""Return a series of active or enqueued ModelInstallJobs."""
|
||||
queue = self._installer.queue
|
||||
jobs: List[DownloadJobBase] = queue.list_jobs()
|
||||
return [parse_obj_as(ModelInstallJob, x) for x in jobs] # downcast to proper type
|
||||
|
||||
def id_to_job(self, id: int) -> ModelInstallJob:
|
||||
"""Return the ModelInstallJob instance corresponding to the given job ID."""
|
||||
job = self._installer.queue.id_to_job(id)
|
||||
assert isinstance(job, ModelInstallJob)
|
||||
return job
|
||||
|
||||
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
|
||||
"""
|
||||
Wait for all pending installs to complete.
|
||||
|
||||
This will block until all pending downloads have
|
||||
completed, been cancelled, or errored out. It will
|
||||
block indefinitely if one or more jobs are in the
|
||||
paused state.
|
||||
|
||||
It will return a dict that maps the source model
|
||||
path, URL or repo_id to the ID of the installed model.
|
||||
"""
|
||||
return self._installer.wait_for_installs()
|
||||
|
||||
def start_job(self, job_id: int):
|
||||
"""Start the given install job if it is paused or idle."""
|
||||
queue = self._installer.queue
|
||||
queue.start_job(queue.id_to_job(job_id))
|
||||
|
||||
def pause_job(self, job_id: int):
|
||||
"""Pause the given install job if it is paused or idle."""
|
||||
queue = self._installer.queue
|
||||
queue.pause_job(queue.id_to_job(job_id))
|
||||
|
||||
def cancel_job(self, job_id: int):
|
||||
"""Cancel the given install job."""
|
||||
queue = self._installer.queue
|
||||
queue.cancel_job(queue.id_to_job(job_id))
|
||||
|
||||
def cancel_all_jobs(self):
|
||||
"""Cancel all active install job."""
|
||||
queue = self._loader.queue
|
||||
queue.cancel_all_jobs()
|
||||
|
||||
def prune_jobs(self):
|
||||
"""Cancel all active install job."""
|
||||
queue = self._loader.queue
|
||||
queue.prune_jobs()
|
||||
|
||||
def change_job_priority(self, job_id: int, delta: int):
|
||||
"""
|
||||
Change an install job's priority.
|
||||
|
||||
:param job_id: Job to change
|
||||
:param delta: Value to increment or decrement priority.
|
||||
|
||||
Lower values are higher priority. The default starting value is 10.
|
||||
Thus to make this a really high priority job:
|
||||
manager.change_job_priority(-10).
|
||||
"""
|
||||
queue = self._installer.queue
|
||||
queue.change_priority(queue.id_to_job(job_id), delta)
|
||||
|
||||
def del_model(
|
||||
self,
|
||||
key: str,
|
||||
delete_files: bool = False,
|
||||
):
|
||||
"""
|
||||
Delete the named model from configuration.
|
||||
|
||||
If delete_files is true,
|
||||
then the underlying weight file or diffusers directory will be deleted
|
||||
as well.
|
||||
"""
|
||||
model_info = self.store.get_model(key)
|
||||
self.logger.debug(f"delete model {model_info.name}")
|
||||
self.store.del_model(key)
|
||||
if delete_files and Path(self._config.models_path / model_info.path).exists():
|
||||
path = Path(model_info.path)
|
||||
if path.is_dir():
|
||||
shutil.rmtree(path)
|
||||
else:
|
||||
path.unlink()
|
||||
|
||||
def convert_model(
|
||||
self,
|
||||
key: str,
|
||||
convert_dest_directory: Path,
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder.
|
||||
|
||||
Delete the cached
|
||||
version and delete the original checkpoint file if it is in the models
|
||||
directory.
|
||||
|
||||
:param key: Key of the model to convert
|
||||
:param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)
|
||||
|
||||
This will raise a ValueError unless the model is not a checkpoint. It will
|
||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||
directory already in place.
|
||||
"""
|
||||
model_info = self.store.get_model(key)
|
||||
self.logger.info(f"Converting model {model_info.name} into a diffusers")
|
||||
return self._installer.convert_model(key, convert_dest_directory)
|
||||
|
||||
@property
|
||||
def logger(self):
|
||||
"""Get the logger associated with this instance."""
|
||||
return self._loader.logger
|
||||
|
||||
@property
|
||||
def store(self):
|
||||
"""Get the store associated with this instance."""
|
||||
return self._installer.store
|
||||
|
||||
def merge_models(
|
||||
self,
|
||||
model_keys: List[str] = Field(
|
||||
default=None, min_items=2, max_items=3, description="List of model keys to merge"
|
||||
),
|
||||
merged_model_name: Optional[str] = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
|
||||
:param model_keys: List of 2-3 model unique keys to merge
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
"""
|
||||
merger = ModelMerger(self.store)
|
||||
try:
|
||||
if not merged_model_name:
|
||||
merged_model_name = "+".join([self.store.get_model(x).name for x in model_keys])
|
||||
raise Exception("not implemented")
|
||||
|
||||
result = merger.merge_diffusion_models_and_save(
|
||||
model_keys=model_keys,
|
||||
merged_model_name=merged_model_name,
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory=merge_dest_directory,
|
||||
)
|
||||
except AssertionError as e:
|
||||
raise ValueError(e)
|
||||
return result
|
||||
|
||||
def search_for_models(self, directory: Path) -> Set[Path]:
|
||||
"""
|
||||
Return list of all models found in the designated directory.
|
||||
|
||||
:param directory: Path to the directory to recursively search.
|
||||
returns a list of model paths
|
||||
"""
|
||||
return ModelSearch().search(directory)
|
||||
|
||||
def sync_to_config(self):
|
||||
"""
|
||||
Synchronize the model manager to the database.
|
||||
|
||||
Re-read models.yaml, rescan the models directory, and reimport models
|
||||
in the autoimport directories. Call after making changes outside the
|
||||
model manager API.
|
||||
"""
|
||||
return self._installer.sync_to_config()
|
||||
|
||||
def list_checkpoint_configs(self) -> List[Path]:
|
||||
"""List the checkpoint config paths from ROOT/configs/stable-diffusion."""
|
||||
config = self._config
|
||||
conf_path = config.legacy_conf_path
|
||||
root_path = config.root_path
|
||||
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")]
|
148
invokeai/app/services/model_loader_service.py
Normal file
148
invokeai/app/services/model_loader_service.py
Normal file
@ -0,0 +1,148 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, Union, Optional
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
from invokeai.backend.model_manager import (
|
||||
ModelConfigStore,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.cache import CacheStats
|
||||
from invokeai.backend.model_manager.loader import ModelInfo, ModelLoad
|
||||
|
||||
from .config import InvokeAIAppConfig
|
||||
from .events import EventServiceBase
|
||||
from .model_record_service import ModelRecordServiceBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
|
||||
|
||||
class ModelLoadServiceBase(ABC):
|
||||
"""Load models into memory."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self,
|
||||
config: InvokeAIAppConfig,
|
||||
store: Union[ModelConfigStore, ModelRecordServiceBase],
|
||||
event_bus: Optional[EventServiceBase] = None):
|
||||
"""
|
||||
Initialize a ModelLoadService
|
||||
|
||||
:param config: InvokeAIAppConfig object
|
||||
:param store: ModelConfigStore object for fetching configuration information
|
||||
:param event_bus: Optional EventServiceBase object. If provided,
|
||||
installation and download events will be sent to the event bus.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> ModelInfo:
|
||||
"""Retrieve the indicated model identified by key.
|
||||
|
||||
:param key: Unique key returned by the ModelConfigStore module.
|
||||
:param submodel_type: Submodel to return (required for main models)
|
||||
:param context" Optional InvocationContext, used in event reporting.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||
"""Reset model cache statistics for graph with graph_id."""
|
||||
pass
|
||||
|
||||
|
||||
# implementation
|
||||
class ModelLoadService(ModelLoadServiceBase):
|
||||
"""Responsible for managing models on disk and in memory."""
|
||||
|
||||
_loader: ModelLoad = Field(description="InvokeAIAppConfig object for the current process")
|
||||
_event_bus: Optional[EventServiceBase] = Field(description="an event bus to send install events to", default=None)
|
||||
|
||||
def __init__(self,
|
||||
config: InvokeAIAppConfig,
|
||||
store: Union[ModelConfigStore, ModelRecordServiceBase],
|
||||
event_bus: Optional[EventServiceBase] = None
|
||||
):
|
||||
"""
|
||||
Initialize a ModelManagerService.
|
||||
|
||||
:param config: InvokeAIAppConfig object
|
||||
:param store: ModelRecordServiceBase or ModelConfigStore object for fetching configuration information
|
||||
:param event_bus: Optional EventServiceBase object. If provided,
|
||||
installation and download events will be sent to the event bus.
|
||||
"""
|
||||
self._event_bus = event_bus
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if self._event_bus:
|
||||
kwargs.update(event_handlers=[self._event_bus.emit_model_event])
|
||||
self._loader = ModelLoad(config, store, **kwargs)
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> ModelInfo:
|
||||
"""
|
||||
Retrieve the indicated model.
|
||||
|
||||
The submodel is required when fetching a main model.
|
||||
"""
|
||||
model_info: ModelInfo = self._loader.get_model(key, submodel_type)
|
||||
|
||||
# we can emit model loading events if we are executing with access to the invocation context
|
||||
if context:
|
||||
self._emit_load_event(
|
||||
context=context,
|
||||
model_key=key,
|
||||
submodel=submodel_type,
|
||||
model_info=model_info,
|
||||
)
|
||||
|
||||
return model_info
|
||||
|
||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||
"""
|
||||
Reset model cache statistics. Is this used?
|
||||
"""
|
||||
self._loader.collect_cache_stats(cache_stats)
|
||||
|
||||
def _emit_load_event(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
model_key: str,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
model_info: Optional[ModelInfo] = None,
|
||||
):
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
raise CanceledException()
|
||||
|
||||
if model_info:
|
||||
context.services.events.emit_model_load_completed(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
model_key=model_key,
|
||||
submodel=submodel,
|
||||
model_info=model_info,
|
||||
)
|
||||
else:
|
||||
context.services.events.emit_model_load_started(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
model_key=model_key,
|
||||
submodel=submodel,
|
||||
)
|
301
invokeai/app/services/model_record_service.py
Normal file
301
invokeai/app/services/model_record_service.py
Normal file
@ -0,0 +1,301 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
import sqlite3
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from pydantic import Field
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from invokeai.backend.model_manager.storage import (
|
||||
AnyModelConfig,
|
||||
ModelConfigStore,
|
||||
ModelConfigStoreSQL,
|
||||
ModelConfigStoreYAML,
|
||||
)
|
||||
from invokeai.backend.model_manager import (
|
||||
ModelConfigBase,
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
DuplicateModelException,
|
||||
UnknownModelException,
|
||||
)
|
||||
|
||||
|
||||
class ModelRecordServiceBase(ABC):
|
||||
"""Responsible for managing model configuration records."""
|
||||
|
||||
@abstractmethod
|
||||
def get_model(self, key: str) -> AnyModelConfig:
|
||||
"""
|
||||
Retrieve the configuration for the indicated model.
|
||||
|
||||
:param key: Key of model config to be fetched.
|
||||
|
||||
Exceptions: UnknownModelException
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_exists(
|
||||
self,
|
||||
key: str,
|
||||
) -> bool:
|
||||
"""Return true if the model configuration identified by key exists in the database."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_info(self, key: str) -> ModelConfigBase:
|
||||
"""
|
||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||
Uses the exact format as the omegaconf stanza.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_models(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
) -> List[ModelConfigBase]:
|
||||
"""
|
||||
Return a list of ModelConfigBases that match the base, type and name criteria.
|
||||
:param base_model: Filter by the base model type.
|
||||
:param model_type: Filter by the model type.
|
||||
:param model_name: Filter by the model name.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_info_by_name(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> ModelConfigBase:
|
||||
"""
|
||||
Return information about the model using the same format as list_models().
|
||||
|
||||
If there are more than one model that match, raises a DuplicateModelException.
|
||||
If no model matches, raises an UnknownModelException
|
||||
"""
|
||||
pass
|
||||
model_configs = self.list_models(model_name=model_name, base_model=base_model, model_type=model_type)
|
||||
if len(model_configs) > 1:
|
||||
raise DuplicateModelException(
|
||||
"More than one model share the same name and type: {base_model}/{model_type}/{model_name}"
|
||||
)
|
||||
if len(model_configs) == 0:
|
||||
raise UnknownModelException("No known model with name and type: {base_model}/{model_type}/{model_name}")
|
||||
return model_configs[0]
|
||||
|
||||
@abstractmethod
|
||||
def all_models(self) -> List[ModelConfigBase]:
|
||||
"""Return a list of all the models."""
|
||||
pass
|
||||
return self.list_models()
|
||||
|
||||
@abstractmethod
|
||||
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None:
|
||||
"""
|
||||
Add a model to the database.
|
||||
|
||||
:param key: Unique key for the model
|
||||
:param config: Model configuration record, either a dict with the
|
||||
required fields or a ModelConfigBase instance.
|
||||
|
||||
Can raise DuplicateModelException and InvalidModelConfig exceptions.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_model(
|
||||
self,
|
||||
key: str,
|
||||
new_config: Union[dict, ModelConfigBase],
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes.
|
||||
|
||||
Will fail with a
|
||||
UnknownModelException if the name does not already exist.
|
||||
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model key is unknown.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def del_model(self, key: str, delete_files: bool = False):
|
||||
"""
|
||||
Delete the named model from configuration. If delete_files
|
||||
is true, then the underlying file or directory will be
|
||||
deleted as well.
|
||||
"""
|
||||
pass
|
||||
|
||||
def rename_model(
|
||||
self,
|
||||
key: str,
|
||||
new_name: str,
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Rename the indicated model.
|
||||
"""
|
||||
return self.update_model(key, {"name": new_name})
|
||||
|
||||
|
||||
# implementation base class
|
||||
class ModelRecordService(ModelRecordServiceBase):
|
||||
"""Responsible for managing models on disk and in memory."""
|
||||
|
||||
_store: ModelConfigStore = Field(description="Config record storage backend")
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self):
|
||||
"""Initialize object -- abstract method."""
|
||||
pass
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
key: str,
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Retrieve the indicated model.
|
||||
|
||||
:param key: Key of model config to be fetched.
|
||||
|
||||
Exceptions: UnknownModelException
|
||||
"""
|
||||
return self._store.get_model(key)
|
||||
|
||||
def model_exists(
|
||||
self,
|
||||
key: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Verify that a model with the given key exists.
|
||||
|
||||
Given a model key, returns True if it is a valid
|
||||
identifier.
|
||||
"""
|
||||
return self._store.exists(key)
|
||||
|
||||
def model_info(self, key: str) -> ModelConfigBase:
|
||||
"""
|
||||
Return configuration information about a model.
|
||||
|
||||
Given a model key returns the ModelConfigBase describing it.
|
||||
"""
|
||||
return self._store.get_model(key)
|
||||
|
||||
def list_models(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
) -> List[ModelConfigBase]:
|
||||
"""
|
||||
Return a ModelConfigBase object for each model in the database.
|
||||
"""
|
||||
return self._store.search_by_name(model_name=model_name, base_model=base_model, model_type=model_type)
|
||||
|
||||
def model_info_by_name(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> ModelConfigBase:
|
||||
"""
|
||||
Return information about the model using the same format as list_models().
|
||||
|
||||
If there are more than one model that match, raises a DuplicateModelException.
|
||||
If no model matches, raises an UnknownModelException
|
||||
"""
|
||||
model_configs = self.list_models(model_name=model_name, base_model=base_model, model_type=model_type)
|
||||
if len(model_configs) > 1:
|
||||
raise DuplicateModelException(
|
||||
"More than one model share the same name and type: {base_model}/{model_type}/{model_name}"
|
||||
)
|
||||
if len(model_configs) == 0:
|
||||
raise UnknownModelException("No known model with name and type: {base_model}/{model_type}/{model_name}")
|
||||
return model_configs[0]
|
||||
|
||||
def all_models(self) -> List[ModelConfigBase]:
|
||||
"""Return a list of all the models."""
|
||||
return self.list_models()
|
||||
|
||||
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None:
|
||||
"""
|
||||
Add a model to the database.
|
||||
|
||||
:param key: Unique key for the model
|
||||
:param config: Model configuration record, either a dict with the
|
||||
required fields or a ModelConfigBase instance.
|
||||
|
||||
Can raise DuplicateModelException and InvalidModelConfig exceptions.
|
||||
"""
|
||||
self._store.add_model(key, config)
|
||||
|
||||
def update_model(
|
||||
self,
|
||||
key: str,
|
||||
new_config: Union[dict, ModelConfigBase],
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes.
|
||||
|
||||
Will fail with a
|
||||
UnknownModelException if the name does not already exist.
|
||||
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model key is unknown.
|
||||
"""
|
||||
new_info = self._store.update_model(key, new_config)
|
||||
print('FIX ME!!! need to call sync_model_path() somewhere... maybe router?')
|
||||
# self._loader.installer.sync_model_path(new_info.key) Maybe this goes into the router call?
|
||||
return new_info
|
||||
|
||||
def del_model(
|
||||
self,
|
||||
key: str,
|
||||
):
|
||||
"""
|
||||
Delete the named model from configuration.
|
||||
"""
|
||||
model_info = self.model_info(key)
|
||||
self._store.del_model(key)
|
||||
|
||||
def rename_model(
|
||||
self,
|
||||
key: str,
|
||||
new_name: str,
|
||||
):
|
||||
"""
|
||||
Rename the indicated model to the new name.
|
||||
|
||||
:param key: Unique key for the model.
|
||||
:param new_name: New name for the model
|
||||
"""
|
||||
return self.update_model(key, {"name": new_name})
|
||||
|
||||
|
||||
class ModelRecordServiceSQL(ModelRecordService):
|
||||
"""ModelRecordService that uses Sqlite for its backend."""
|
||||
|
||||
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock):
|
||||
"""
|
||||
Initialize a ModelRecordService that uses a SQLITE3 database backend.
|
||||
|
||||
:param conn: sqlite3 Connection object
|
||||
:param lock: Thread lock object
|
||||
"""
|
||||
self._store = ModelConfigStoreSQL(conn, lock)
|
||||
|
||||
|
||||
class ModelRecordServiceFile(ModelRecordService):
|
||||
"""ModelRecordService that uses a YAML file for its backend."""
|
||||
|
||||
def __init__(self, models_file: Path):
|
||||
"""
|
||||
Initialize a ModelRecordService that uses a YAML file as the backend.
|
||||
|
||||
:param models_file: Path to the YAML file backend.
|
||||
"""
|
||||
self._store = ModelConfigStoreYAML(models_file)
|
@ -6,11 +6,12 @@ from .model_manager import ( # noqa F401
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
ModelConfigStore,
|
||||
ModelInstall,
|
||||
ModelLoad,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
SilenceWarnings,
|
||||
SubModelType,
|
||||
)
|
||||
from .model_manager.install import ModelInstall # noqa F401
|
||||
from .model_manager.loader import ModelLoad # noqa F401
|
||||
from .util.devices import get_precision # noqa F401
|
||||
|
@ -11,9 +11,9 @@ from .config import ( # noqa F401
|
||||
SilenceWarnings,
|
||||
SubModelType,
|
||||
)
|
||||
from .install import ModelInstall, ModelInstallJob # noqa F401
|
||||
from .loader import ModelInfo, ModelLoad # noqa F401
|
||||
from .lora import ModelPatcher, ONNXModelPatcher # noqa F401
|
||||
# from .install import ModelInstall, ModelInstallJob # noqa F401
|
||||
# from .loader import ModelInfo, ModelLoad # noqa F401
|
||||
# from .lora import ModelPatcher, ONNXModelPatcher # noqa F401
|
||||
from .models import OPENAPI_MODEL_CONFIGS, InvalidModelException, read_checkpoint_meta # noqa F401
|
||||
from .probe import ModelProbe, ModelProbeInfo # noqa F401
|
||||
from .search import ModelSearch # noqa F401
|
||||
|
@ -60,6 +60,7 @@ from pydantic import Field
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_record_service import ModelRecordServiceBase
|
||||
from invokeai.backend.util import Chdir, InvokeAILogger, Logger
|
||||
|
||||
from .config import (
|
||||
@ -319,6 +320,11 @@ class ModelInstallBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def sync_to_config(self):
|
||||
"""Synchronize models on disk to those in memory."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def scan_models_directory(self):
|
||||
"""
|
||||
@ -367,7 +373,7 @@ class ModelInstall(ModelInstallBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
store: Optional[ModelConfigStore] = None,
|
||||
store: Optional[Union[ModelConfigStore, ModelRecordServiceBase]] = None,
|
||||
config: Optional[InvokeAIAppConfig] = None,
|
||||
logger: Optional[Logger] = None,
|
||||
download: Optional[DownloadQueueBase] = None,
|
||||
@ -381,10 +387,6 @@ class ModelInstall(ModelInstallBase):
|
||||
self._installed = set()
|
||||
self._tmpdir = None
|
||||
|
||||
# this step synchronizes the `models` directory with the models db
|
||||
# do NOT do this automatically, but only on app startup
|
||||
# self.scan_models_directory()
|
||||
|
||||
@property
|
||||
def queue(self) -> DownloadQueueBase:
|
||||
"""Return the queue."""
|
||||
@ -609,7 +611,7 @@ class ModelInstall(ModelInstallBase):
|
||||
variant: Optional[str] = None,
|
||||
access_token: Optional[str] = None,
|
||||
priority: Optional[int] = 10,
|
||||
) -> DownloadJobBase:
|
||||
) -> ModelInstallJob:
|
||||
# Clean up a common source of error. Doesn't work with Paths.
|
||||
if isinstance(source, str):
|
||||
source = source.strip()
|
||||
@ -747,6 +749,10 @@ class ModelInstall(ModelInstallBase):
|
||||
pass
|
||||
return True
|
||||
|
||||
def sync_to_config(self):
|
||||
"""Synchronize models on disk to those in memory."""
|
||||
self.scan_models_directory()
|
||||
|
||||
def scan_models_directory(self):
|
||||
"""
|
||||
Scan the models directory for new and missing models.
|
||||
|
@ -11,11 +11,10 @@ import torch
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.util import InvokeAILogger, Logger, choose_precision, choose_torch_device
|
||||
from invokeai.app.services.model_record_service import ModelRecordServiceBase
|
||||
|
||||
from .cache import CacheStats, ModelCache
|
||||
from .config import BaseModelType, ModelConfigBase, ModelType, SubModelType
|
||||
from .download import DownloadEventHandler, DownloadQueueBase
|
||||
from .install import ModelInstall, ModelInstallBase
|
||||
from .models import MODEL_CLASSES, InvalidModelException, ModelBase
|
||||
from .storage import ConfigFileVersionMismatchException, ModelConfigStore, get_config_store, migrate_models_store
|
||||
|
||||
@ -65,12 +64,6 @@ class ModelLoadBase(ABC):
|
||||
"""Return the ModelConfigStore object that supports this loader."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def installer(self) -> ModelInstallBase:
|
||||
"""Return the ModelInstallBase object that supports this loader."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def logger(self) -> Logger:
|
||||
@ -80,13 +73,7 @@ class ModelLoadBase(ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
def config(self) -> InvokeAIAppConfig:
|
||||
"""Return the config object used by this installer."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def queue(self) -> DownloadQueueBase:
|
||||
"""Return the download queue object used by this object."""
|
||||
"""Return the config object used by the loader."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -105,18 +92,12 @@ class ModelLoadBase(ABC):
|
||||
"""Return torch.fp16 or torch.fp32."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def sync_to_config(self):
|
||||
"""Reinitialize the store to sync in-memory and in-disk versions."""
|
||||
pass
|
||||
|
||||
|
||||
class ModelLoad(ModelLoadBase):
|
||||
"""Implementation of ModelLoadBase."""
|
||||
|
||||
_app_config: InvokeAIAppConfig
|
||||
_store: ModelConfigStore
|
||||
_installer: ModelInstallBase
|
||||
_store: Union[ModelConfigStore, ModelRecordServiceBase]
|
||||
_cache: ModelCache
|
||||
_logger: Logger
|
||||
_cache_keys: dict
|
||||
@ -126,7 +107,6 @@ class ModelLoad(ModelLoadBase):
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
store: Optional[ModelConfigStore] = None,
|
||||
event_handlers: List[DownloadEventHandler] = [],
|
||||
):
|
||||
"""
|
||||
Initialize ModelLoad object.
|
||||
@ -138,12 +118,6 @@ class ModelLoad(ModelLoadBase):
|
||||
self._app_config = config
|
||||
self._store = store
|
||||
self._logger = InvokeAILogger.get_logger()
|
||||
self._installer = ModelInstall(
|
||||
store=self._store,
|
||||
logger=self._logger,
|
||||
config=self._app_config,
|
||||
event_handlers=event_handlers,
|
||||
)
|
||||
self._cache_keys = dict()
|
||||
self._models_file = config.model_conf_path
|
||||
device = torch.device(choose_torch_device())
|
||||
@ -190,11 +164,6 @@ class ModelLoad(ModelLoadBase):
|
||||
"""Return torch.fp16 or torch.fp32."""
|
||||
return self._cache.precision
|
||||
|
||||
@property
|
||||
def installer(self) -> ModelInstallBase:
|
||||
"""Return the ModelInstallBase instance used by this class."""
|
||||
return self._installer
|
||||
|
||||
@property
|
||||
def logger(self) -> Logger:
|
||||
"""Return the current logger."""
|
||||
@ -202,14 +171,9 @@ class ModelLoad(ModelLoadBase):
|
||||
|
||||
@property
|
||||
def config(self) -> InvokeAIAppConfig:
|
||||
"""Return the config object used by the installer."""
|
||||
"""Return the config object."""
|
||||
return self._app_config
|
||||
|
||||
@property
|
||||
def queue(self) -> DownloadQueueBase:
|
||||
"""Return the download queue object used by this object."""
|
||||
return self._installer.queue
|
||||
|
||||
def get_model(self, key: str, submodel_type: Optional[SubModelType] = None) -> ModelInfo:
|
||||
"""
|
||||
Get the ModelInfo corresponding to the model with key "key".
|
||||
@ -303,7 +267,3 @@ class ModelLoad(ModelLoadBase):
|
||||
model_path = self.resolve_model_path(model_path)
|
||||
return model_path, is_submodel_override
|
||||
|
||||
def sync_to_config(self):
|
||||
"""Synchronize models on disk to those in memory."""
|
||||
self._store = get_config_store(self._models_file)
|
||||
self.installer.scan_models_directory()
|
||||
|
@ -84,9 +84,9 @@ class ModelMerger(object):
|
||||
self,
|
||||
model_keys: List[str],
|
||||
merged_model_name: str,
|
||||
alpha: float = 0.5,
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: bool = False,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
**kwargs,
|
||||
) -> ModelConfigBase:
|
||||
|
@ -10,7 +10,7 @@ from .base import ( # noqa F401
|
||||
from .migrate import migrate_models_store # noqa F401
|
||||
from .sql import ModelConfigStoreSQL # noqa F401
|
||||
from .yaml import ModelConfigStoreYAML # noqa F401
|
||||
|
||||
from ..config import AnyModelConfig # noqa F401
|
||||
|
||||
def get_config_store(location: pathlib.Path) -> ModelConfigStore:
|
||||
"""Return the type of ModelConfigStore appropriate to the path."""
|
||||
|
@ -39,7 +39,7 @@ class ModelConfigStore(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> None:
|
||||
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> ModelConfigBase:
|
||||
"""
|
||||
Add a model to the database.
|
||||
|
||||
|
@ -53,21 +53,18 @@ from .base import CONFIG_FILE_VERSION, DuplicateModelException, ModelConfigStore
|
||||
class ModelConfigStoreSQL(ModelConfigStore):
|
||||
"""Implementation of the ModelConfigStore ABC using a YAML file."""
|
||||
|
||||
_filename: Path
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, filename: Path):
|
||||
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock):
|
||||
"""Initialize ModelConfigStore object with a sqlite3 database."""
|
||||
super().__init__()
|
||||
self._filename = Path(filename).absolute() # don't let chdir mess us up!
|
||||
self._filename.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||
self._conn = conn
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
self._lock = threading.Lock()
|
||||
self._lock = lock
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
@ -174,7 +171,7 @@ class ModelConfigStoreSQL(ModelConfigStore):
|
||||
("version", CONFIG_FILE_VERSION),
|
||||
)
|
||||
|
||||
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None:
|
||||
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
|
||||
"""
|
||||
Add a model to the database.
|
||||
|
||||
@ -225,6 +222,8 @@ class ModelConfigStoreSQL(ModelConfigStore):
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
return self.get_model(key)
|
||||
|
||||
@property
|
||||
def version(self) -> str:
|
||||
"""Return the version of the database schema."""
|
||||
|
@ -104,7 +104,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
|
||||
"""Return version of this config file/database."""
|
||||
return self._config.__metadata__.get("version")
|
||||
|
||||
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None:
|
||||
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
|
||||
"""
|
||||
Add a model to the database.
|
||||
|
||||
@ -127,6 +127,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
|
||||
self._commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get_model(key)
|
||||
|
||||
def _fix_enums(self, original: dict) -> dict:
|
||||
"""In python 3.9, omegaconf stores incorrectly stringified enums."""
|
||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import platform
|
||||
from contextlib import nullcontext
|
||||
from typing import Union
|
||||
from typing import Union, Literal
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
@ -41,6 +41,11 @@ def choose_precision(device: torch.device) -> str:
|
||||
return "float16"
|
||||
return "float32"
|
||||
|
||||
def get_precision() -> Literal['float16', 'float32']:
|
||||
device = torch.device(choose_torch_device())
|
||||
precision = choose_precision(device) if config.precision == "auto" else config.precision
|
||||
assert precision in ['float16', 'float32']
|
||||
return precision
|
||||
|
||||
def torch_dtype(device: torch.device) -> torch.dtype:
|
||||
if config.full_precision:
|
||||
|
Reference in New Issue
Block a user