Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Compare commits: 24 commits, ryan/peft...bugfix-for
Commits (SHA1):
0b238b1ece, 59b4a23479, 13f410478a, 25ff0bf80f, f83edcf990, a6dd50aeaf,
1badf0f32f, 3c9c58e0fa, 9a1b35fa37, 5be69f191d, 3d6d89feb4, 0ac1c0f339,
c308654442, b0ffe36d21, 6b3fdb8a93, 7639e05dd2, 6d261a5a13, 31e9cf1f06,
c5d1bd1360, 3409711ed3, 3681e34d5a, 2526ef52c5, 43bcedee10, 98cc9b963c

@@ -614,8 +614,8 @@ async def convert_model(
    The return value is the model configuration for the converted model.
    """
    model_manager = ApiDependencies.invoker.services.model_manager
    loader = model_manager.load
    logger = ApiDependencies.invoker.services.logger
    loader = ApiDependencies.invoker.services.model_manager.load
    store = ApiDependencies.invoker.services.model_manager.store
    installer = ApiDependencies.invoker.services.model_manager.install

@@ -630,7 +630,13 @@ async def convert_model(
        raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")

    # loading the model will convert it into a cached diffusers file
    model_manager.load.load_model(model_config, submodel_type=SubModelType.Scheduler)
    try:
        cc_size = loader.convert_cache.max_size
        if cc_size == 0:  # temporary set the convert cache to a positive number so that cached model is written
            loader._convert_cache.max_size = 1.0
        loader.load_model(model_config, submodel_type=SubModelType.Scheduler)
    finally:
        loader._convert_cache.max_size = cc_size

    # Get the path of the converted model from the loader
    cache_path = loader.convert_cache.cache_path(key)

@@ -348,8 +348,13 @@ class ModelInstallService(ModelInstallServiceBase):
        config: dict[str, Any] = {}
        config["name"] = model_name
        config["description"] = stanza.get("description")
        config["config_path"] = stanza.get("config")

        legacy_config_path = stanza.get("config")
        if legacy_config_path:
            # In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
            legacy_config_path: Path = self._app_config.root_path / legacy_config_path
            if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
                legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
            config["config_path"] = str(legacy_config_path)
        try:
            id = self.register_path(model_path=model_path, config=config)
            self._logger.info(f"Migrated {model_name} with id {id}")

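The new branch above migrates v3-era config paths, which were stored relative to the install root, so they are stored relative to the legacy_conf directory instead. A standalone pathlib sketch of that relativization; the directory names and file name are hypothetical, not InvokeAI's real config values:

from pathlib import Path

root_path = Path("/opt/invokeai")
legacy_conf_path = root_path / "configs" / "stable-diffusion"

# A v3 record stored the config path relative to the install root.
stanza_config = "configs/stable-diffusion/v1-inference.yaml"

candidate = root_path / stanza_config
if candidate.is_relative_to(legacy_conf_path):
    # Keep only the part below the legacy conf dir, e.g. "v1-inference.yaml".
    candidate = candidate.relative_to(legacy_conf_path)
print(str(candidate))  # -> v1-inference.yaml
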
@@ -368,11 +373,13 @@ class ModelInstallService(ModelInstallServiceBase):
    def delete(self, key: str) -> None:  # noqa D102
        """Unregister the model. Delete its files only if they are within our models directory."""
        model = self.record_store.get_model(key)
        models_dir = self.app_config.models_path
        model_path = models_dir / Path(model.path)  # handle legacy relative model paths
        if model_path.is_relative_to(models_dir):
        model_path = self.app_config.models_path / model.path

        if model_path.is_relative_to(self.app_config.models_path):
            # If the models is in the Invoke-managed models dir, we delete it
            self.unconditionally_delete(key)
        else:
            # Else we only unregister it, leaving the file in place
            self.unregister(key)

    def unconditionally_delete(self, key: str) -> None:  # noqa D102

@@ -500,9 +507,9 @@ class ModelInstallService(ModelInstallServiceBase):
    def _scan_for_missing_models(self) -> list[AnyModelConfig]:
        """Scan the models directory for missing models and return a list of them."""
        missing_models: list[AnyModelConfig] = []
        for x in self.record_store.all_models():
            if not Path(x.path).resolve().exists():
                missing_models.append(x)
        for model_config in self.record_store.all_models():
            if not (self.app_config.models_path / model_config.path).resolve().exists():
                missing_models.append(model_config)
        return missing_models

    def _register_orphaned_models(self) -> None:

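The updated check joins the configured models directory with the stored path before testing existence. Because pathlib ignores the left operand when the right one is already absolute, the same expression covers both old absolute records and new relative ones. A small illustration with made-up paths:

from pathlib import Path

models_path = Path("/opt/invokeai/models")

relative_record = "sd-1/main/my-model"
absolute_record = "/mnt/external/other-model.safetensors"

print(models_path / relative_record)   # /opt/invokeai/models/sd-1/main/my-model
print(models_path / absolute_record)   # /mnt/external/other-model.safetensors (absolute path wins)
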
@@ -512,7 +519,9 @@ class ModelInstallService(ModelInstallServiceBase):
        only situations in which we may have orphaned models in the models directory.
        """

        installed_model_paths = {Path(x.path).resolve() for x in self.record_store.all_models()}
        installed_model_paths = {
            (self._app_config.models_path / x.path).resolve() for x in self.record_store.all_models()
        }

        # The bool returned by this callback determines if the model is added to the list of models found by the search
        def on_model_found(model_path: Path) -> bool:

@@ -548,10 +557,11 @@ class ModelInstallService(ModelInstallServiceBase):
        May raise an UnknownModelException.
        """
        model = self.record_store.get_model(key)
        old_path = Path(model.path).resolve()
        models_dir = self.app_config.models_path.resolve()
        models_dir = self.app_config.models_path
        old_path = self.app_config.models_path / model.path

        if not old_path.is_relative_to(models_dir):
            # The model is not in the models directory - we don't need to move it.
            return model

        new_path = (models_dir / model.base.value / model.type.value / model.name).with_suffix(old_path.suffix)

@@ -561,7 +571,7 @@ class ModelInstallService(ModelInstallServiceBase):

        self._logger.info(f"Moving {model.name} to {new_path}.")
        new_path = self._move_model(old_path, new_path)
        model.path = new_path.as_posix()
        model.path = new_path.relative_to(models_dir).as_posix()
        self.record_store.update_model(key, ModelRecordChanges(path=model.path))
        return model

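Storing the moved model's location as relative_to(models_dir).as_posix() keeps the database record relative to the models directory and slash-normalized, so the same record works across platforms and survives a relocated install. A small sketch of that normalization with hypothetical paths:

from pathlib import PurePosixPath, PureWindowsPath

new_path = PureWindowsPath(r"C:\invokeai\models\sd-1\main\my-model")
models_dir = PureWindowsPath(r"C:\invokeai\models")

stored = new_path.relative_to(models_dir).as_posix()
print(stored)  # sd-1/main/my-model

# Reconstructing the absolute location later is just a join against the configured models dir.
print(PurePosixPath("/opt/invokeai/models") / stored)  # /opt/invokeai/models/sd-1/main/my-model
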
@@ -600,12 +610,19 @@ class ModelInstallService(ModelInstallServiceBase):

        model_path = model_path.resolve()

        # Models in the Invoke-managed models dir should use relative paths.
        if model_path.is_relative_to(self.app_config.models_path):
            model_path = model_path.relative_to(self.app_config.models_path)

        info.path = model_path.as_posix()

        # Checkpoints have a config file needed for conversion - resolve this to an absolute path
        if isinstance(info, CheckpointConfigBase):
            legacy_conf = (self.app_config.legacy_conf_path / info.config_path).resolve()
            info.config_path = legacy_conf.as_posix()
            # Checkpoints have a config file needed for conversion. Same handling as the model weights - if it's in the
            # invoke-managed legacy config dir, we use a relative path.
            legacy_config_path = self.app_config.legacy_conf_path / info.config_path
            if legacy_config_path.is_relative_to(self.app_config.legacy_conf_path):
                legacy_config_path = legacy_config_path.relative_to(self.app_config.legacy_conf_path)
            info.config_path = legacy_config_path.as_posix()
        self.record_store.add_model(info)
        return info.key

@@ -70,8 +70,18 @@ class DefaultSessionProcessor(SessionProcessorBase):
    async def _on_queue_event(self, event: FastAPIEvent) -> None:
        event_name = event[1]["event"]

        if event_name == "session_canceled" or event_name == "queue_cleared":
            # These both mean we should cancel the current session.
        if (
            event_name == "session_canceled"
            and self._queue_item
            and self._queue_item.item_id == event[1]["data"]["queue_item_id"]
        ):
            self._cancel_event.set()
            self._poll_now()
        elif (
            event_name == "queue_cleared"
            and self._queue_item
            and self._queue_item.queue_id == event[1]["data"]["queue_id"]
        ):
            self._cancel_event.set()
            self._poll_now()
        elif event_name == "batch_enqueued":

@ -111,141 +121,146 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
poll_now_event.clear()
|
||||
# Middle processor try block; any unhandled exception is a non-fatal processor error
|
||||
try:
|
||||
# If we are paused, wait for resume event
|
||||
resume_event.wait()
|
||||
|
||||
# Get the next session to process
|
||||
self._queue_item = self._invoker.services.session_queue.dequeue()
|
||||
if self._queue_item is not None and resume_event.is_set():
|
||||
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
|
||||
cancel_event.clear()
|
||||
|
||||
# If profiling is enabled, start the profiler
|
||||
if self._profiler is not None:
|
||||
self._profiler.start(profile_id=self._queue_item.session_id)
|
||||
if self._queue_item is None:
|
||||
# The queue was empty, wait for next polling interval or event to try again
|
||||
self._invoker.services.logger.debug("Waiting for next polling interval or event")
|
||||
poll_now_event.wait(self._polling_interval)
|
||||
continue
|
||||
|
||||
# Prepare invocations and take the first
|
||||
self._invocation = self._queue_item.session.next()
|
||||
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
|
||||
cancel_event.clear()
|
||||
|
||||
# Loop over invocations until the session is complete or canceled
|
||||
while self._invocation is not None and not cancel_event.is_set():
|
||||
# get the source node id to provide to clients (the prepared node id is not as useful)
|
||||
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
|
||||
# If profiling is enabled, start the profiler
|
||||
if self._profiler is not None:
|
||||
self._profiler.start(profile_id=self._queue_item.session_id)
|
||||
|
||||
# Send starting event
|
||||
self._invoker.services.events.emit_invocation_started(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session_id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
)
|
||||
# Prepare invocations and take the first
|
||||
self._invocation = self._queue_item.session.next()
|
||||
|
||||
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
|
||||
try:
|
||||
with self._invoker.services.performance_statistics.collect_stats(
|
||||
self._invocation, self._queue_item.session.id
|
||||
):
|
||||
# Build invocation context (the node-facing API)
|
||||
data = InvocationContextData(
|
||||
invocation=self._invocation,
|
||||
source_invocation_id=source_invocation_id,
|
||||
queue_item=self._queue_item,
|
||||
)
|
||||
context = build_invocation_context(
|
||||
data=data,
|
||||
services=self._invoker.services,
|
||||
cancel_event=self._cancel_event,
|
||||
)
|
||||
# Loop over invocations until the session is complete or canceled
|
||||
while self._invocation is not None and not cancel_event.is_set():
|
||||
# get the source node id to provide to clients (the prepared node id is not as useful)
|
||||
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
|
||||
|
||||
# Invoke the node
|
||||
outputs = self._invocation.invoke_internal(
|
||||
context=context, services=self._invoker.services
|
||||
)
|
||||
# Send starting event
|
||||
self._invoker.services.events.emit_invocation_started(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session_id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
)
|
||||
|
||||
# Save outputs and history
|
||||
self._queue_item.session.complete(self._invocation.id, outputs)
|
||||
|
||||
# Send complete event
|
||||
self._invoker.services.events.emit_invocation_complete(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
result=outputs.model_dump(),
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
# TODO(MM2): Create an event for this
|
||||
pass
|
||||
|
||||
except CanceledException:
|
||||
# When the user cancels the graph, we first set the cancel event. The event is checked
|
||||
# between invocations, in this loop. Some invocations are long-running, and we need to
|
||||
# be able to cancel them mid-execution.
|
||||
#
|
||||
# For example, denoising is a long-running invocation with many steps. A step callback
|
||||
# is executed after each step. This step callback checks if the canceled event is set,
|
||||
# then raises a CanceledException to stop execution immediately.
|
||||
#
|
||||
# When we get a CanceledException, we don't need to do anything - just pass and let the
|
||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
error = traceback.format_exc()
|
||||
|
||||
# Save error
|
||||
self._queue_item.session.set_node_error(self._invocation.id, error)
|
||||
self._invoker.services.logger.error(
|
||||
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
|
||||
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
|
||||
try:
|
||||
with self._invoker.services.performance_statistics.collect_stats(
|
||||
self._invocation, self._queue_item.session.id
|
||||
):
|
||||
# Build invocation context (the node-facing API)
|
||||
data = InvocationContextData(
|
||||
invocation=self._invocation,
|
||||
source_invocation_id=source_invocation_id,
|
||||
queue_item=self._queue_item,
|
||||
)
|
||||
context = build_invocation_context(
|
||||
data=data,
|
||||
services=self._invoker.services,
|
||||
cancel_event=self._cancel_event,
|
||||
)
|
||||
self._invoker.services.logger.error(error)
|
||||
|
||||
# Send error event
|
||||
self._invoker.services.events.emit_invocation_error(
|
||||
queue_batch_id=self._queue_item.session_id,
|
||||
# Invoke the node
|
||||
outputs = self._invocation.invoke_internal(
|
||||
context=context, services=self._invoker.services
|
||||
)
|
||||
|
||||
# Save outputs and history
|
||||
self._queue_item.session.complete(self._invocation.id, outputs)
|
||||
|
||||
# Send complete event
|
||||
self._invoker.services.events.emit_invocation_complete(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=error,
|
||||
result=outputs.model_dump(),
|
||||
)
|
||||
pass
|
||||
|
||||
# The session is complete if the all invocations are complete or there was an error
|
||||
if self._queue_item.session.is_complete() or cancel_event.is_set():
|
||||
# Send complete event
|
||||
self._invoker.services.events.emit_graph_execution_complete(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
except KeyboardInterrupt:
|
||||
# TODO(MM2): Create an event for this
|
||||
pass
|
||||
|
||||
except CanceledException:
|
||||
# When the user cancels the graph, we first set the cancel event. The event is checked
|
||||
# between invocations, in this loop. Some invocations are long-running, and we need to
|
||||
# be able to cancel them mid-execution.
|
||||
#
|
||||
# For example, denoising is a long-running invocation with many steps. A step callback
|
||||
# is executed after each step. This step callback checks if the canceled event is set,
|
||||
# then raises a CanceledException to stop execution immediately.
|
||||
#
|
||||
# When we get a CanceledException, we don't need to do anything - just pass and let the
|
||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
error = traceback.format_exc()
|
||||
|
||||
# Save error
|
||||
self._queue_item.session.set_node_error(self._invocation.id, error)
|
||||
self._invoker.services.logger.error(
|
||||
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
|
||||
)
|
||||
self._invoker.services.logger.error(error)
|
||||
|
||||
# Send error event
|
||||
self._invoker.services.events.emit_invocation_error(
|
||||
queue_batch_id=self._queue_item.session_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=error,
|
||||
)
|
||||
pass
|
||||
|
||||
# The session is complete if the all invocations are complete or there was an error
|
||||
if self._queue_item.session.is_complete() or cancel_event.is_set():
|
||||
# Send complete event
|
||||
self._invoker.services.events.emit_graph_execution_complete(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
)
|
||||
# If we are profiling, stop the profiler and dump the profile & stats
|
||||
if self._profiler:
|
||||
profile_path = self._profiler.stop()
|
||||
stats_path = profile_path.with_suffix(".json")
|
||||
self._invoker.services.performance_statistics.dump_stats(
|
||||
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
|
||||
)
|
||||
# If we are profiling, stop the profiler and dump the profile & stats
|
||||
if self._profiler:
|
||||
profile_path = self._profiler.stop()
|
||||
stats_path = profile_path.with_suffix(".json")
|
||||
self._invoker.services.performance_statistics.dump_stats(
|
||||
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
|
||||
)
|
||||
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
||||
# we don't care about that - suppress the error.
|
||||
with suppress(GESStatsNotFoundError):
|
||||
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
|
||||
self._invoker.services.performance_statistics.reset_stats()
|
||||
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
||||
# we don't care about that - suppress the error.
|
||||
with suppress(GESStatsNotFoundError):
|
||||
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
|
||||
self._invoker.services.performance_statistics.reset_stats()
|
||||
|
||||
# Set the invocation to None to prepare for the next session
|
||||
self._invocation = None
|
||||
else:
|
||||
# Prepare the next invocation
|
||||
self._invocation = self._queue_item.session.next()
|
||||
|
||||
# The session is complete, immediately poll for next session
|
||||
self._queue_item = None
|
||||
poll_now_event.set()
|
||||
# Set the invocation to None to prepare for the next session
|
||||
self._invocation = None
|
||||
else:
|
||||
# Prepare the next invocation
|
||||
self._invocation = self._queue_item.session.next()
|
||||
else:
|
||||
# The queue was empty, wait for next polling interval or event to try again
|
||||
self._invoker.services.logger.debug("Waiting for next polling interval or event")
|
||||
|
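The hunk above interleaves the old and new bodies of the processor loop, which makes the restructuring hard to follow: the profiler start, the "invocation started" event, and the per-invocation loop all move inside the branch that runs only when a queue item was actually dequeued, with an outer try for non-fatal processor errors and an inner try whose failures fail only the graph. Below is a self-contained, simplified sketch of that control flow; FakeSession, the plain-list queue, CanceledException, and the print calls are stand-ins, not InvokeAI APIs.

import threading
from dataclasses import dataclass, field


class CanceledException(Exception):
    pass


@dataclass
class FakeSession:
    nodes: list = field(default_factory=lambda: ["node_a", "node_b"])

    def next(self):
        return self.nodes.pop(0) if self.nodes else None

    def is_complete(self) -> bool:
        return not self.nodes


stop_event = threading.Event()
cancel_event = threading.Event()
poll_now_event = threading.Event()
queue = [FakeSession()]

while not stop_event.is_set():
    poll_now_event.clear()
    try:  # outer try: any unhandled error here is a non-fatal processor error
        session = queue.pop(0) if queue else None
        if session is None:
            # Queue empty: wait for the polling interval or a poll-now event, then try again.
            poll_now_event.wait(0.01)
            stop_event.set()  # end the demo here; the real processor keeps polling
            continue

        cancel_event.clear()
        invocation = session.next()
        while invocation is not None and not cancel_event.is_set():
            try:  # inner try: an error here fails the graph, not the processor
                print(f"running {invocation}")
            except CanceledException:
                pass  # cancellation is re-checked at the top of this loop
            if session.is_complete() or cancel_event.is_set():
                print("session complete")
                invocation = None
            else:
                invocation = session.next()

        poll_now_event.set()  # a session just finished: immediately look for the next one
    except Exception:
        pass  # swallow (the real processor logs) and keep the processor alive
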
@@ -10,6 +10,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import build_migration_8
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator


@@ -37,6 +38,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
    migrator.register_migration(build_migration_5())
    migrator.register_migration(build_migration_6())
    migrator.register_migration(build_migration_7())
    migrator.register_migration(build_migration_8(app_config=config))
    migrator.run_migrations()

    return db

@@ -11,7 +11,7 @@ class Migration7Callback:
    def _drop_old_models_tables(self, cursor: sqlite3.Cursor) -> None:
        """Drops the old model_records, model_metadata, model_tags and tags tables."""

        tables = ["model_records", "model_metadata", "model_tags", "tags"]
        tables = ["model_config", "model_metadata", "model_tags", "tags"]

        for table in tables:
            cursor.execute(f"DROP TABLE IF EXISTS {table};")

@@ -0,0 +1,91 @@
import sqlite3
from pathlib import Path

from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration


class Migration8Callback:
    def __init__(self, app_config: InvokeAIAppConfig) -> None:
        self._app_config = app_config

    def __call__(self, cursor: sqlite3.Cursor) -> None:
        self._drop_model_config_table(cursor)
        self._migrate_abs_models_to_rel(cursor)

    def _drop_model_config_table(self, cursor: sqlite3.Cursor) -> None:
        """Drops the old model_config table. This was missed in a previous migration."""

        cursor.execute("DROP TABLE IF EXISTS model_config;")

    def _migrate_abs_models_to_rel(self, cursor: sqlite3.Cursor) -> None:
        """Check all model paths & legacy config paths to determine if they are inside Invoke-managed directories. If
        they are, update the paths to be relative to the managed directories.

        This migration is a no-op for normal users (their paths will already be relative), but is necessary for users
        who have been testing the RCs with their live databases. The paths were made absolute in the initial RC, but this
        change was reverted. To smooth over the revert for our tests, we can migrate the paths back to relative.
        """

        models_path = self._app_config.models_path
        legacy_conf_path = self._app_config.legacy_conf_path
        legacy_conf_dir = self._app_config.legacy_conf_dir

        stmt = """---sql
            SELECT
                id,
                path,
                json_extract(config, '$.config_path') as config_path
            FROM models;
            """

        all_models = cursor.execute(stmt).fetchall()

        for model_id, model_path, model_config_path in all_models:
            # If the model path is inside the models directory, update it to be relative to the models directory.
            if Path(model_path).is_relative_to(models_path):
                new_path = Path(model_path).relative_to(models_path)
                cursor.execute(
                    """--sql
                    UPDATE models
                    SET config = json_set(config, '$.path', ?)
                    WHERE id = ?;
                    """,
                    (str(new_path), model_id),
                )
            # If the model has a legacy config path and it is inside the legacy conf directory, update it to be
            # relative to the legacy conf directory. This also fixes up cases in which the config path was
            # incorrectly relativized to the root directory. It will now be relativized to the legacy conf directory.
            if model_config_path:
                if Path(model_config_path).is_relative_to(legacy_conf_path):
                    new_config_path = Path(model_config_path).relative_to(legacy_conf_path)
                elif Path(model_config_path).is_relative_to(legacy_conf_dir):
                    new_config_path = Path(*Path(model_config_path).parts[1:])
                else:
                    new_config_path = None
                if new_config_path:
                    cursor.execute(
                        """--sql
                        UPDATE models
                        SET config = json_set(config, '$.config_path', ?)
                        WHERE id = ?;
                        """,
                        (str(new_config_path), model_id),
                    )


def build_migration_8(app_config: InvokeAIAppConfig) -> Migration:
    """
    Build the migration from database version 7 to 8.

    This migration does the following:
    - Removes the `model_config` table.
    - Migrates absolute model & legacy config paths to be relative to the models directory.
    """
    migration_8 = Migration(
        from_version=7,
        to_version=8,
        callback=Migration8Callback(app_config),
    )

    return migration_8

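The migration rewrites paths stored inside a JSON config column using SQLite's JSON1 functions. A self-contained demonstration of json_extract and json_set against an in-memory database; the table layout and paths are illustrative only, and the snippet assumes a SQLite build with the JSON1 functions available (true for current Python releases):

import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE models (id TEXT PRIMARY KEY, config TEXT);")
cur.execute(
    "INSERT INTO models VALUES (?, ?);",
    ("abc123", '{"path": "/opt/invokeai/models/sd-1/lora/foo.safetensors", "config_path": null}'),
)

# Read the embedded path, relativize it, and write it back in place.
row = cur.execute("SELECT json_extract(config, '$.path') FROM models WHERE id = ?;", ("abc123",)).fetchone()
new_path = row[0].removeprefix("/opt/invokeai/models/")
cur.execute("UPDATE models SET config = json_set(config, '$.path', ?) WHERE id = ?;", (new_path, "abc123"))

print(cur.execute("SELECT config FROM models WHERE id = ?;", ("abc123",)).fetchone()[0])
# {"path":"sd-1/lora/foo.safetensors","config_path":null}
conn.commit()
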
@ -33,42 +33,3 @@ __all__ = [
|
||||
"SchedulerPredictionType",
|
||||
"SubModelType",
|
||||
]
|
||||
|
||||
########## to help populate the openapi_schema with format enums for each config ###########
|
||||
# This code is no longer necessary?
|
||||
# leave it here just in case
|
||||
#
|
||||
# import inspect
|
||||
# from enum import Enum
|
||||
# from typing import Any, Iterable, Dict, get_args, Set
|
||||
# def _expand(something: Any) -> Iterable[type]:
|
||||
# if isinstance(something, type):
|
||||
# yield something
|
||||
# else:
|
||||
# for x in get_args(something):
|
||||
# for y in _expand(x):
|
||||
# yield y
|
||||
|
||||
# def _find_format(cls: type) -> Iterable[Enum]:
|
||||
# if hasattr(inspect, "get_annotations"):
|
||||
# fields = inspect.get_annotations(cls)
|
||||
# else:
|
||||
# fields = cls.__annotations__
|
||||
# if "format" in fields:
|
||||
# for x in get_args(fields["format"]):
|
||||
# yield x
|
||||
# for parent_class in cls.__bases__:
|
||||
# for x in _find_format(parent_class):
|
||||
# yield x
|
||||
# return None
|
||||
|
||||
# def get_model_config_formats() -> Dict[str, Set[Enum]]:
|
||||
# result: Dict[str, Set[Enum]] = {}
|
||||
# for model_config in _expand(AnyModelConfig):
|
||||
# for field in _find_format(model_config):
|
||||
# if field is None:
|
||||
# continue
|
||||
# if not result.get(model_config.__qualname__):
|
||||
# result[model_config.__qualname__] = set()
|
||||
# result[model_config.__qualname__].add(field)
|
||||
# return result
|
||||
|
@ -3,10 +3,10 @@
|
||||
"""Conversion script for the Stable Diffusion checkpoints."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from diffusers import AutoencoderKL
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
convert_ldm_vae_checkpoint,
|
||||
create_vae_diffusers_config,
|
||||
@ -15,11 +15,14 @@ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
)
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from . import AnyModel
|
||||
|
||||
|
||||
def convert_ldm_vae_to_diffusers(
|
||||
checkpoint: Dict[str, torch.Tensor],
|
||||
checkpoint: torch.Tensor | dict[str, torch.Tensor],
|
||||
vae_config: DictConfig,
|
||||
image_size: int,
|
||||
dump_path: Optional[Path] = None,
|
||||
precision: torch.dtype = torch.float16,
|
||||
) -> AutoencoderKL:
|
||||
"""Convert a checkpoint-style VAE into a Diffusers VAE"""
|
||||
@ -28,16 +31,21 @@ def convert_ldm_vae_to_diffusers(
|
||||
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
return vae.to(precision)
|
||||
vae.to(precision)
|
||||
|
||||
if dump_path:
|
||||
vae.save_pretrained(dump_path, safe_serialization=True)
|
||||
|
||||
return vae
|
||||
|
||||
|
||||
def convert_ckpt_to_diffusers(
|
||||
checkpoint_path: str | Path,
|
||||
dump_path: str | Path,
|
||||
dump_path: Optional[str | Path] = None,
|
||||
precision: torch.dtype = torch.float16,
|
||||
use_safetensors: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
) -> AnyModel:
|
||||
"""
|
||||
Takes all the arguments of download_from_original_stable_diffusion_ckpt(),
|
||||
and in addition a path-like object indicating the location of the desired diffusers
|
||||
@ -47,18 +55,20 @@ def convert_ckpt_to_diffusers(
|
||||
pipe = pipe.to(precision)
|
||||
|
||||
# TO DO: save correct repo variant
|
||||
pipe.save_pretrained(
|
||||
dump_path,
|
||||
safe_serialization=use_safetensors,
|
||||
)
|
||||
if dump_path:
|
||||
pipe.save_pretrained(
|
||||
dump_path,
|
||||
safe_serialization=use_safetensors,
|
||||
)
|
||||
return pipe
|
||||
|
||||
|
||||
def convert_controlnet_to_diffusers(
|
||||
checkpoint_path: Path,
|
||||
dump_path: Path,
|
||||
dump_path: Optional[Path] = None,
|
||||
precision: torch.dtype = torch.float16,
|
||||
**kwargs,
|
||||
):
|
||||
) -> AnyModel:
|
||||
"""
|
||||
Takes all the arguments of download_controlnet_from_original_ckpt(),
|
||||
and in addition a path-like object indicating the location of the desired diffusers
|
||||
@ -68,4 +78,6 @@ def convert_controlnet_to_diffusers(
|
||||
pipe = pipe.to(precision)
|
||||
|
||||
# TO DO: save correct repo variant
|
||||
pipe.save_pretrained(dump_path, safe_serialization=True)
|
||||
if dump_path:
|
||||
pipe.save_pretrained(dump_path, safe_serialization=True)
|
||||
return pipe
|
||||
|
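With dump_path now optional, these converters can hand back an in-memory diffusers model without writing anything to disk; passing a path additionally saves the converted pipeline. A hedged usage sketch based only on the signature visible in this diff - the checkpoint location is made up, and the other keyword arguments the function forwards are omitted, so treat this as an illustration rather than a verified invocation:

import torch
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers

# In-memory conversion only - nothing is written to the convert cache.
pipe = convert_ckpt_to_diffusers(
    "/tmp/v1-5-pruned-emaonly.safetensors",
    precision=torch.float16,
)

# Convert and also persist the diffusers folder for reuse.
pipe = convert_ckpt_to_diffusers(
    "/tmp/v1-5-pruned-emaonly.safetensors",
    dump_path="/tmp/converted/v1-5",
    precision=torch.float16,
    use_safetensors=True,
)
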
@@ -19,11 +19,20 @@ class ModelConvertCache(ModelConvertCacheBase):
        self._cache_path = cache_path
        self._max_size = max_size

        # adjust cache size at startup in case it has been changed
        if self._cache_path.exists():
            self.make_room(0.0)

    @property
    def max_size(self) -> float:
        """Return the maximum size of this cache directory (GB)."""
        return self._max_size

    @max_size.setter
    def max_size(self, value: float) -> None:
        """Set the maximum size of this cache directory (GB)."""
        self._max_size = value

    def cache_path(self, key: str) -> Path:
        """Return the path for a model with the indicated key."""
        return self._cache_path / key

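The max_size setter added here is what lets callers adjust the convert-cache cap at runtime; the convert_model route earlier in this diff appears to rely on it to force a converted copy to be written even when the cache is configured to 0 GB. A minimal stand-in class (not the real ModelConvertCache) showing the property/setter pair and that temporary override pattern:

from pathlib import Path


class TinyConvertCache:
    """Minimal stand-in for illustration only."""

    def __init__(self, cache_path: Path, max_size: float = 0.0) -> None:
        self._cache_path = cache_path
        self._max_size = max_size

    @property
    def max_size(self) -> float:
        """Return the maximum size of this cache directory (GB)."""
        return self._max_size

    @max_size.setter
    def max_size(self, value: float) -> None:
        """Set the maximum size of this cache directory (GB)."""
        self._max_size = value


cache = TinyConvertCache(Path("/tmp/convert_cache"), max_size=0.0)

saved = cache.max_size
try:
    if saved == 0:
        cache.max_size = 1.0  # temporarily allow a converted model to be written
    print(cache.max_size)  # 1.0 while the conversion runs
finally:
    cache.max_size = saved  # restore the configured cap
print(cache.max_size)  # 0.0 again
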
@@ -83,3 +83,15 @@ class ModelLoaderBase(ABC):
    ) -> int:
        """Return size in bytes of the model, calculated before loading."""
        pass

    @property
    @abstractmethod
    def convert_cache(self) -> ModelConvertCacheBase:
        """Return the convert cache associated with this loader."""
        pass

    @property
    @abstractmethod
    def ram_cache(self) -> ModelCacheBase[AnyModel]:
        """Return the ram cache associated with this loader."""
        pass

@ -3,14 +3,13 @@
|
||||
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
InvalidModelConfigException,
|
||||
ModelRepoVariant,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
|
||||
@ -54,51 +53,43 @@ class ModelLoader(ModelLoaderBase):
|
||||
if model_config.type is ModelType.Main and not submodel_type:
|
||||
raise InvalidModelConfigException("submodel_type is required when loading a main model")
|
||||
|
||||
model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
|
||||
model_path = self._get_model_path(model_config)
|
||||
|
||||
if not model_path.exists():
|
||||
raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")
|
||||
|
||||
model_path = self._convert_if_needed(model_config, model_path, submodel_type)
|
||||
locker = self._load_if_needed(model_config, model_path, submodel_type)
|
||||
with skip_torch_weight_init():
|
||||
locker = self._convert_and_load(model_config, model_path, submodel_type)
|
||||
return LoadedModel(config=model_config, _locker=locker)
|
||||
|
||||
def _get_model_path(
|
||||
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
|
||||
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
|
||||
@property
|
||||
def convert_cache(self) -> ModelConvertCacheBase:
|
||||
"""Return the convert cache associated with this loader."""
|
||||
return self._convert_cache
|
||||
|
||||
@property
|
||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||
"""Return the ram cache associated with this loader."""
|
||||
return self._ram_cache
|
||||
|
||||
def _get_model_path(self, config: AnyModelConfig) -> Path:
|
||||
model_base = self._app_config.models_path
|
||||
result = (model_base / config.path).resolve(), config, submodel_type
|
||||
return result
|
||||
return (model_base / config.path).resolve()
|
||||
|
||||
def _convert_if_needed(
|
||||
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
|
||||
) -> Path:
|
||||
cache_path: Path = self._convert_cache.cache_path(config.key)
|
||||
|
||||
if not self._needs_conversion(config, model_path, cache_path):
|
||||
return cache_path if cache_path.exists() else model_path
|
||||
|
||||
self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
|
||||
return self._convert_model(config, model_path, cache_path)
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
return False
|
||||
|
||||
def _load_if_needed(
|
||||
def _convert_and_load(
|
||||
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
|
||||
) -> ModelLockerBase:
|
||||
# TO DO: This is not thread safe!
|
||||
try:
|
||||
return self._ram_cache.get(config.key, submodel_type)
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
model_variant = getattr(config, "repo_variant", None)
|
||||
self._ram_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
|
||||
|
||||
# This is where the model is actually loaded!
|
||||
with skip_torch_weight_init():
|
||||
loaded_model = self._load_model(model_path, model_variant=model_variant, submodel_type=submodel_type)
|
||||
cache_path: Path = self._convert_cache.cache_path(config.key)
|
||||
if self._needs_conversion(config, model_path, cache_path):
|
||||
loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
|
||||
else:
|
||||
config.path = str(cache_path) if cache_path.exists() else str(self._get_model_path(config))
|
||||
loaded_model = self._load_model(config, submodel_type)
|
||||
|
||||
self._ram_cache.put(
|
||||
config.key,
|
||||
@ -123,15 +114,34 @@ class ModelLoader(ModelLoaderBase):
|
||||
variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None,
|
||||
)
|
||||
|
||||
def _do_convert(
|
||||
self, config: AnyModelConfig, model_path: Path, cache_path: Path, submodel_type: Optional[SubModelType] = None
|
||||
) -> AnyModel:
|
||||
self.convert_cache.make_room(calc_model_size_by_fs(model_path))
|
||||
pipeline = self._convert_model(config, model_path, cache_path if self.convert_cache.max_size > 0 else None)
|
||||
if submodel_type:
|
||||
# Proactively load the various submodels into the RAM cache so that we don't have to re-convert
|
||||
# the entire pipeline every time a new submodel is needed.
|
||||
for subtype in SubModelType:
|
||||
if subtype == submodel_type:
|
||||
continue
|
||||
if submodel := getattr(pipeline, subtype.value, None):
|
||||
self._ram_cache.put(
|
||||
config.key, submodel_type=subtype, model=submodel, size=calc_model_size_by_data(submodel)
|
||||
)
|
||||
return getattr(pipeline, submodel_type.value) if submodel_type else pipeline
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
return False
|
||||
|
||||
# This needs to be implemented in subclasses that handle checkpoints
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
|
||||
raise NotImplementedError
|
||||
|
||||
# This needs to be implemented in the subclass
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
raise NotImplementedError
|
||||
|
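This hunk folds the old _convert_if_needed and _load_if_needed steps into a single _convert_and_load: first try the RAM cache, then either convert the checkpoint into the convert cache or load the stored path directly, and finally put the result back into the RAM cache. A toy, self-contained sketch of that decision flow; the dict cache, the suffix-based rule, and the strings are stand-ins, not the real ModelCache or loader APIs:

from pathlib import Path

ram_cache: dict[str, str] = {}  # toy stand-in for the real ModelCache


def needs_conversion(model_path: Path) -> bool:
    # Toy rule: single-file checkpoints get converted, diffusers folders load directly.
    return model_path.suffix in {".safetensors", ".ckpt"}


def convert_and_load(key: str, model_path: Path, convert_cache_dir: Path) -> str:
    # 1. Serve from the RAM cache when the model (or submodel) is already loaded.
    if key in ram_cache:
        return ram_cache[key]

    # 2. Otherwise convert into the convert cache if needed, or load the stored path as-is.
    if needs_conversion(model_path):
        loaded = f"converted {model_path.name} into {convert_cache_dir / key}"
    else:
        loaded = f"loaded {model_path.name} directly"

    # 3. Remember the result so later requests skip both steps.
    ram_cache[key] = loaded
    return loaded


print(convert_and_load("sd15", Path("v1-5.safetensors"), Path("/tmp/convert_cache")))
print(convert_and_load("sd15", Path("v1-5.safetensors"), Path("/tmp/convert_cache")))  # RAM cache hit
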
@ -122,6 +122,11 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
"""Return the cap on cache size."""
|
||||
return self._max_cache_size
|
||||
|
||||
@max_cache_size.setter
|
||||
def max_cache_size(self, value: float) -> None:
|
||||
"""Set the cap on cache size."""
|
||||
self._max_cache_size = value
|
||||
|
||||
@property
|
||||
def stats(self) -> Optional[CacheStats]:
|
||||
"""Return collected CacheStats object."""
|
||||
@ -157,8 +162,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
) -> None:
|
||||
"""Store model under key and optional submodel_type."""
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
assert key not in self._cached_models
|
||||
|
||||
if key in self._cached_models:
|
||||
return
|
||||
self.make_room(size)
|
||||
cache_record = CacheRecord(key, model, size)
|
||||
self._cached_models[key] = cache_record
|
||||
self._cache_stack.append(key)
|
||||
@ -405,6 +411,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
#
|
||||
# Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
|
||||
# immediately when their reference count hits 0.
|
||||
if self.stats:
|
||||
self.stats.cleared = models_cleared
|
||||
gc.collect()
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
@ -2,8 +2,10 @@
|
||||
"""Class for ControlNet model loading in InvokeAI."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
@ -33,7 +35,7 @@ class ControlNetLoader(GenericDiffusersLoader):
|
||||
else:
|
||||
return True
|
||||
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
|
||||
assert isinstance(config, CheckpointConfigBase)
|
||||
image_size = (
|
||||
512
|
||||
@ -44,8 +46,8 @@ class ControlNetLoader(GenericDiffusersLoader):
|
||||
)
|
||||
|
||||
self._logger.info(f"Converting {model_path} to diffusers format")
|
||||
with open(self._app_config.root_path / config.config_path, "r") as config_stream:
|
||||
convert_controlnet_to_diffusers(
|
||||
with open(self._app_config.legacy_conf_path / config.config_path, "r") as config_stream:
|
||||
result = convert_controlnet_to_diffusers(
|
||||
model_path,
|
||||
output_path,
|
||||
original_config_file=config_stream,
|
||||
@ -53,4 +55,4 @@ class ControlNetLoader(GenericDiffusersLoader):
|
||||
precision=self._torch_dtype,
|
||||
from_safetensors=model_path.suffix == ".safetensors",
|
||||
)
|
||||
return output_path
|
||||
return result
|
||||
|
@ -10,13 +10,14 @@ from diffusers.models.modeling_utils import ModelMixin
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
InvalidModelConfigException,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import DiffusersConfigBase
|
||||
|
||||
from .. import ModelLoader, ModelLoaderRegistry
|
||||
|
||||
@ -28,14 +29,15 @@ class GenericDiffusersLoader(ModelLoader):
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
model_path = Path(config.path)
|
||||
model_class = self.get_hf_load_class(model_path)
|
||||
if submodel_type is not None:
|
||||
raise Exception(f"There are no submodels in models of type {model_class}")
|
||||
variant = model_variant.value if model_variant else None
|
||||
repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
|
||||
variant = repo_variant.value if repo_variant else None
|
||||
try:
|
||||
result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant)
|
||||
except OSError as e:
|
||||
|
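In this loader the variant string passed to from_pretrained now comes from config.repo_variant (when the config is a DiffusersConfigBase) rather than a separate model_variant argument. For context, this is the standard diffusers pattern the value feeds into; the model id below is a public example chosen because it ships an fp16 variant, and it is not something referenced by this PR:

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",  # the loader passes config.repo_variant.value through as this string; omit for default weights
)
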
@ -9,13 +9,14 @@ import torch
|
||||
from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
|
||||
@ -24,13 +25,13 @@ class IPAdapterInvokeAILoader(ModelLoader):
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if submodel_type is not None:
|
||||
raise ValueError("There are no submodels in an IP-Adapter model.")
|
||||
model = build_ip_adapter(
|
||||
model_path = Path(config.path)
|
||||
model: RawModel = build_ip_adapter(
|
||||
ip_adapter_ckpt_path=str(model_path / "ip_adapter.bin"),
|
||||
device=torch.device("cpu"),
|
||||
dtype=self._torch_dtype,
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
@ -12,7 +12,6 @@ from invokeai.backend.model_manager import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
@ -41,12 +40,12 @@ class LoRALoader(ModelLoader):
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if submodel_type is not None:
|
||||
raise ValueError("There are no submodels in a LoRA model.")
|
||||
model_path = Path(config.path)
|
||||
assert self._model_base is not None
|
||||
model = LoRAModelRaw.from_checkpoint(
|
||||
file_path=model_path,
|
||||
@ -56,12 +55,9 @@ class LoRALoader(ModelLoader):
|
||||
return model
|
||||
|
||||
# override
|
||||
def _get_model_path(
|
||||
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
|
||||
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
|
||||
self._model_base = (
|
||||
config.base
|
||||
) # cheating a little - we remember this variable for using in the subsequent call to _load_model()
|
||||
def _get_model_path(self, config: AnyModelConfig) -> Path:
|
||||
# cheating a little - we remember this variable for using in the subsequent call to _load_model()
|
||||
self._model_base = config.base
|
||||
|
||||
model_base_path = self._app_config.models_path
|
||||
model_path = model_base_path / config.path
|
||||
@ -73,5 +69,4 @@ class LoRALoader(ModelLoader):
|
||||
model_path = path
|
||||
break
|
||||
|
||||
result = model_path.resolve(), config, submodel_type
|
||||
return result
|
||||
return model_path.resolve()
|
||||
|
@ -7,9 +7,9 @@ from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
@ -25,18 +25,19 @@ class OnnyxDiffusersModel(GenericDiffusersLoader):
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not submodel_type is not None:
|
||||
raise Exception("A submodel type must be provided when loading onnx pipelines.")
|
||||
model_path = Path(config.path)
|
||||
load_class = self.get_hf_load_class(model_path, submodel_type)
|
||||
variant = model_variant.value if model_variant else None
|
||||
repo_variant = getattr(config, "repo_variant", None)
|
||||
variant = repo_variant.value if repo_variant else None
|
||||
model_path = model_path / submodel_type.value
|
||||
result: AnyModel = load_class.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=self._torch_dtype,
|
||||
variant=variant,
|
||||
) # type: ignore
|
||||
)
|
||||
return result
|
||||
|
@ -9,12 +9,16 @@ from invokeai.backend.model_manager import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig, ModelVariantType
|
||||
from invokeai.backend.model_manager.config import (
|
||||
CheckpointConfigBase,
|
||||
DiffusersConfigBase,
|
||||
MainCheckpointConfig,
|
||||
ModelVariantType,
|
||||
)
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
@ -41,14 +45,15 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not submodel_type is not None:
|
||||
raise Exception("A submodel type must be provided when loading main pipelines.")
|
||||
model_path = Path(config.path)
|
||||
load_class = self.get_hf_load_class(model_path, submodel_type)
|
||||
variant = model_variant.value if model_variant else None
|
||||
repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
|
||||
variant = repo_variant.value if repo_variant else None
|
||||
model_path = model_path / submodel_type.value
|
||||
try:
|
||||
result: AnyModel = load_class.from_pretrained(
|
||||
@ -78,7 +83,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
else:
|
||||
return True
|
||||
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
|
||||
assert isinstance(config, MainCheckpointConfig)
|
||||
base = config.base
|
||||
|
||||
@ -94,11 +99,11 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
|
||||
self._logger.info(f"Converting {model_path} to diffusers format")
|
||||
|
||||
convert_ckpt_to_diffusers(
|
||||
loaded_model = convert_ckpt_to_diffusers(
|
||||
model_path,
|
||||
output_path,
|
||||
model_type=self.model_base_to_model_type[base],
|
||||
original_config_file=self._app_config.root_path / config.config_path,
|
||||
original_config_file=self._app_config.legacy_conf_path / config.config_path,
|
||||
extract_ema=True,
|
||||
from_safetensors=model_path.suffix == ".safetensors",
|
||||
precision=self._torch_dtype,
|
||||
@ -108,4 +113,4 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
load_safety_checker=False,
|
||||
num_in_channels=VARIANT_TO_IN_CHANNEL_MAP[config.variant],
|
||||
)
|
||||
return output_path
|
||||
return loaded_model
|
||||
|
@ -2,14 +2,13 @@
|
||||
"""Class for TI model loading in InvokeAI."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
@ -27,22 +26,19 @@ class TextualInversionLoader(ModelLoader):
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if submodel_type is not None:
|
||||
raise ValueError("There are no submodels in a TI model.")
|
||||
model = TextualInversionModelRaw.from_checkpoint(
|
||||
file_path=model_path,
|
||||
file_path=config.path,
|
||||
dtype=self._torch_dtype,
|
||||
)
|
||||
return model
|
||||
|
||||
# override
|
||||
def _get_model_path(
|
||||
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
|
||||
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
|
||||
def _get_model_path(self, config: AnyModelConfig) -> Path:
|
||||
model_path = self._app_config.models_path / config.path
|
||||
|
||||
if config.format == ModelFormat.EmbeddingFolder:
|
||||
@ -53,4 +49,4 @@ class TextualInversionLoader(ModelLoader):
|
||||
if not path.exists():
|
||||
raise OSError(f"The embedding file at {path} was not found")
|
||||
|
||||
return path, config, submodel_type
|
||||
return path
|
||||
|
@ -2,6 +2,7 @@
|
||||
"""Class for VAE model loading in InvokeAI."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
@ -13,7 +14,7 @@ from invokeai.backend.model_manager import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import CheckpointConfigBase
|
||||
from invokeai.backend.model_manager.config import AnyModel, CheckpointConfigBase
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
@ -38,13 +39,13 @@ class VAELoader(GenericDiffusersLoader):
|
||||
else:
|
||||
return True
|
||||
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
|
||||
# TODO(MM2): check whether sdxl VAE models convert.
|
||||
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
|
||||
raise Exception(f"VAE conversion not supported for model type: {config.base}")
|
||||
else:
|
||||
assert isinstance(config, CheckpointConfigBase)
|
||||
config_file = self._app_config.root_path / config.config_path
|
||||
config_file = self._app_config.legacy_conf_path / config.config_path
|
||||
|
||||
if model_path.suffix == ".safetensors":
|
||||
checkpoint = safetensors_load_file(model_path, device="cpu")
|
||||
@ -63,6 +64,6 @@ class VAELoader(GenericDiffusersLoader):
|
||||
vae_config=ckpt_config,
|
||||
image_size=512,
|
||||
precision=self._torch_dtype,
|
||||
dump_path=output_path,
|
||||
)
|
||||
vae_model.save_pretrained(output_path, safe_serialization=True)
|
||||
return output_path
|
||||
return vae_model
|
||||
|
@@ -323,7 +323,7 @@ class ModelProbe(object):
        with SilenceWarnings():
            if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
                cls._scan_model(model_path.name, model_path)
                model = torch.load(model_path)
                model = torch.load(model_path, map_location="cpu")
                assert isinstance(model, dict)
                return model
            else:

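Passing map_location="cpu" makes torch.load materialize the checkpoint's tensors on the CPU even if they were saved from a GPU, which is what you want when merely probing a file for its keys. A minimal, self-contained illustration; the file path is arbitrary:

import torch

torch.save({"state_dict": {"w": torch.zeros(2, 2)}}, "/tmp/probe_demo.ckpt")

model = torch.load("/tmp/probe_demo.ckpt", map_location="cpu")
assert isinstance(model, dict)
print(model["state_dict"]["w"].device)  # cpu
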
@@ -28,6 +28,10 @@ def _conv_forward_asymmetric(self, input, weight, bias):

@contextmanager
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL, AutoencoderTiny], seamless_axes: List[str]):
    if not seamless_axes:
        yield
        return

    # Callable: (input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor
    to_restore: list[tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = []
    try:

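The new guard turns set_seamless into a no-op context manager when no axes are requested: it yields immediately and returns, skipping both the patching and the restore logic. The same early-yield shape in isolation, as a generic example rather than the InvokeAI function:

from contextlib import contextmanager


@contextmanager
def maybe_patch(things_to_patch: list):
    if not things_to_patch:
        # Nothing to do: hand control straight back to the with-block, no setup or teardown.
        yield
        return
    print("patching", things_to_patch)
    try:
        yield
    finally:
        print("restoring", things_to_patch)


with maybe_patch([]):
    print("inside (no-op)")

with maybe_patch(["x"]):
    print("inside (patched)")
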
@ -31,6 +31,9 @@ class ConfigMapper:
|
||||
YAML_FILENAME = "invokeai.yaml"
|
||||
DATABASE_FILENAME = "invokeai.db"
|
||||
|
||||
DEFAULT_OUTDIR = "outputs"
|
||||
DEFAULT_DB_DIR = "databases"
|
||||
|
||||
database_path = None
|
||||
database_backup_dir = None
|
||||
outputs_path = None
|
||||
@ -50,12 +53,18 @@ class ConfigMapper:
|
||||
def __load_from_root_config(self, invoke_root):
|
||||
"""Validate a yaml path exists, confirm the user wants to use it and load config."""
|
||||
yaml_path = os.path.join(invoke_root, self.YAML_FILENAME)
|
||||
if not os.path.exists(yaml_path):
|
||||
print(f"Unable to find invokeai.yaml at {yaml_path}!")
|
||||
return False
|
||||
if os.path.exists(yaml_path):
|
||||
db_dir, outdir = self.__load_paths_from_yaml_file(yaml_path)
|
||||
|
||||
if db_dir is None or outdir is None:
|
||||
print("The invokeai.yaml file was found but is missing the db_dir and/or outdir setting!")
|
||||
return False
|
||||
if db_dir is None:
|
||||
db_dir = self.DEFAULT_DB_DIR
|
||||
print(f"The invokeai.yaml file was found but is missing the db_dir setting! Defaulting to {db_dir}")
|
||||
if outdir is None:
|
||||
outdir = self.DEFAULT_OUTDIR
|
||||
print(f"The invokeai.yaml file was found but is missing the outdir setting! Defaulting to {outdir}")
|
||||
|
||||
if os.path.isabs(db_dir):
|
||||
self.database_path = os.path.join(db_dir, self.DATABASE_FILENAME)
|
||||
|
@ -43,8 +43,7 @@ def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_fil
|
||||
assert model_record is not None
|
||||
assert model_record.name == "test_embedding"
|
||||
assert model_record.type == ModelType.TextualInversion
|
||||
assert model_record.path.endswith(embedding_file.as_posix())
|
||||
assert Path(model_record.path).is_absolute()
|
||||
assert Path(model_record.path) == embedding_file
|
||||
assert Path(model_record.path).exists()
|
||||
assert model_record.base == BaseModelType("sd-1")
|
||||
assert model_record.description is not None
|
||||
@ -77,8 +76,7 @@ def test_install(
|
||||
key = mm2_installer.install_path(embedding_file)
|
||||
model_record = store.get_model(key)
|
||||
assert model_record.path.endswith("sd-1/embedding/test_embedding.safetensors")
|
||||
assert Path(model_record.path).is_absolute()
|
||||
assert Path(model_record.path).exists()
|
||||
assert (mm2_app_config.models_path / model_record.path).exists()
|
||||
assert model_record.source == embedding_file.as_posix()
|
||||
|
||||
|
||||
@ -147,10 +145,7 @@ def test_background_install(
|
||||
model_record = mm2_installer.record_store.get_model(key)
|
||||
assert model_record is not None
|
||||
assert model_record.path.endswith(destination)
|
||||
assert Path(model_record.path).is_absolute()
|
||||
assert Path(model_record.path).exists()
|
||||
assert model_record.key != "<NOKEY>"
|
||||
assert Path(model_record.path).exists()
|
||||
assert (mm2_app_config.models_path / model_record.path).exists()
|
||||
|
||||
# see if metadata was properly passed through
|
||||
assert model_record.description == description
|
||||
@ -172,7 +167,7 @@ def test_not_inplace_install(
|
||||
assert job is not None
|
||||
assert job.config_out is not None
|
||||
assert Path(job.config_out.path) != embedding_file
|
||||
assert Path(job.config_out.path).exists()
|
||||
assert (mm2_app_config.models_path / job.config_out.path).exists()
|
||||
|
||||
|
||||
def test_inplace_install(
|
||||
@ -184,16 +179,21 @@ def test_inplace_install(
|
||||
assert job is not None
|
||||
assert job.config_out is not None
|
||||
assert Path(job.config_out.path) == embedding_file
|
||||
assert Path(job.config_out.path).exists()
|
||||
|
||||
|
||||
def test_delete_install(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
|
||||
def test_delete_install(
|
||||
mm2_installer: ModelInstallServiceBase, embedding_file: Path, mm2_app_config: InvokeAIAppConfig
|
||||
) -> None:
|
||||
store = mm2_installer.record_store
|
||||
key = mm2_installer.install_path(embedding_file)
|
||||
model_record = store.get_model(key)
|
||||
assert Path(model_record.path).exists()
|
||||
assert (mm2_app_config.models_path / model_record.path).exists()
|
||||
assert embedding_file.exists() # original should still be there after installation
|
||||
mm2_installer.delete(key)
|
||||
assert not Path(model_record.path).exists() # after deletion, installed copy should not exist
|
||||
assert not (
|
||||
mm2_app_config.models_path / model_record.path
|
||||
).exists() # after deletion, installed copy should not exist
|
||||
assert embedding_file.exists() # but original should still be there
|
||||
with pytest.raises(UnknownModelException):
|
||||
store.get_model(key)
|
||||
@ -232,7 +232,7 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
|
||||
|
||||
key = job.config_out.key
|
||||
model_record = store.get_model(key)
|
||||
assert Path(model_record.path).exists()
|
||||
assert (mm2_app_config.models_path / model_record.path).exists()
|
||||
|
||||
assert len(bus.events) == 4
|
||||
event_names = [x.event_name for x in bus.events]
|
||||
@ -261,7 +261,7 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co
|
||||
|
||||
key = job.config_out.key
|
||||
model_record = store.get_model(key)
|
||||
assert Path(model_record.path).exists()
|
||||
assert (mm2_app_config.models_path / model_record.path).exists()
|
||||
assert model_record.type == ModelType.Main
|
||||
assert model_record.format == ModelFormat.Diffusers
|
||||
|
||||
|