Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Merge branch 'main' into main

This commit is contained in: commit 0259114d9c
@@ -614,8 +614,8 @@ async def convert_model(

The return value is the model configuration for the converted model.
"""
model_manager = ApiDependencies.invoker.services.model_manager
loader = model_manager.load
logger = ApiDependencies.invoker.services.logger
loader = ApiDependencies.invoker.services.model_manager.load
store = ApiDependencies.invoker.services.model_manager.store
installer = ApiDependencies.invoker.services.model_manager.install

@@ -630,7 +630,13 @@ async def convert_model(
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")

# loading the model will convert it into a cached diffusers file
model_manager.load.load_model(model_config, submodel_type=SubModelType.Scheduler)
try:
    cc_size = loader.convert_cache.max_size
    if cc_size == 0:  # temporary set the convert cache to a positive number so that cached model is written
        loader._convert_cache.max_size = 1.0
    loader.load_model(model_config, submodel_type=SubModelType.Scheduler)
finally:
    loader._convert_cache.max_size = cc_size

# Get the path of the converted model from the loader
cache_path = loader.convert_cache.cache_path(key)
@@ -3,6 +3,7 @@ Invoke-managed custom node loader. See README.md for more information.
"""

import sys
import traceback
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path

@@ -41,11 +42,15 @@ for d in Path(__file__).parent.iterdir():

logger.info(f"Loading node pack {module_name}")

module = module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)
try:
    module = module_from_spec(spec)
    sys.modules[spec.name] = module
    spec.loader.exec_module(module)

    loaded_count += 1
loaded_count += 1
except Exception:
    full_error = traceback.format_exc()
    logger.error(f"Failed to load node pack {module_name}:\n{full_error}")

del init, module_name
@@ -373,13 +373,16 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
if k == "conf_path":
    parsed_config_dict["legacy_models_yaml_path"] = v
if k == "legacy_conf_dir":
    # The old default for this was "configs/stable-diffusion". If if the incoming config has that as the value, we won't set it.
    # Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
    # Else we do not attempt to migrate this setting
    if v != "configs/stable-diffusion":
        parsed_config_dict["legacy_conf_dir"] = v
    # The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
    if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
        # If if the incoming config has the default value, skip
        continue
    elif Path(v).name == "stable-diffusion":
        # Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
        parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
    else:
        # Else we do not attempt to migrate this setting
        parsed_config_dict["legacy_conf_dir"] = v
elif k in InvokeAIAppConfig.model_fields:
    # skip unknown fields
    parsed_config_dict[k] = v
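For reference, a minimal standalone sketch of the legacy_conf_dir rule above (illustrative only, not part of this commit; the helper name migrate_legacy_conf_dir is hypothetical):

from pathlib import Path
from typing import Optional

def migrate_legacy_conf_dir(v: str) -> Optional[str]:
    # Old default value ("configs\stable-diffusion" on Windows): nothing to migrate.
    if v in ("configs/stable-diffusion", "configs\\stable-diffusion"):
        return None
    # Path ends in "stable-diffusion": assume the parent is the new correct path.
    if Path(v).name == "stable-diffusion":
        return str(Path(v).parent)
    # Otherwise carry the setting over unchanged.
    return v

assert migrate_legacy_conf_dir("configs/stable-diffusion") is None
assert migrate_legacy_conf_dir("my_configs/stable-diffusion") == "my_configs"
assert migrate_legacy_conf_dir("custom_configs") == "custom_configs"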
@@ -348,8 +348,13 @@ class ModelInstallService(ModelInstallServiceBase):
config: dict[str, Any] = {}
config["name"] = model_name
config["description"] = stanza.get("description")
config["config_path"] = stanza.get("config")

legacy_config_path = stanza.get("config")
if legacy_config_path:
    # In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
    legacy_config_path: Path = self._app_config.root_path / legacy_config_path
    if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
        legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
    config["config_path"] = str(legacy_config_path)
try:
    id = self.register_path(model_path=model_path, config=config)
    self._logger.info(f"Migrated {model_name} with id {id}")
@@ -368,11 +373,13 @@ class ModelInstallService(ModelInstallServiceBase):
def delete(self, key: str) -> None:  # noqa D102
    """Unregister the model. Delete its files only if they are within our models directory."""
    model = self.record_store.get_model(key)
    models_dir = self.app_config.models_path
    model_path = models_dir / Path(model.path)  # handle legacy relative model paths
    if model_path.is_relative_to(models_dir):
    model_path = self.app_config.models_path / model.path

    if model_path.is_relative_to(self.app_config.models_path):
        # If the models is in the Invoke-managed models dir, we delete it
        self.unconditionally_delete(key)
    else:
        # Else we only unregister it, leaving the file in place
        self.unregister(key)

def unconditionally_delete(self, key: str) -> None:  # noqa D102
@@ -500,9 +507,9 @@ class ModelInstallService(ModelInstallServiceBase):
def _scan_for_missing_models(self) -> list[AnyModelConfig]:
    """Scan the models directory for missing models and return a list of them."""
    missing_models: list[AnyModelConfig] = []
    for x in self.record_store.all_models():
        if not Path(x.path).resolve().exists():
            missing_models.append(x)
    for model_config in self.record_store.all_models():
        if not (self.app_config.models_path / model_config.path).resolve().exists():
            missing_models.append(model_config)
    return missing_models

def _register_orphaned_models(self) -> None:
@@ -512,7 +519,9 @@ class ModelInstallService(ModelInstallServiceBase):
only situations in which we may have orphaned models in the models directory.
"""

installed_model_paths = {Path(x.path).resolve() for x in self.record_store.all_models()}
installed_model_paths = {
    (self._app_config.models_path / x.path).resolve() for x in self.record_store.all_models()
}

# The bool returned by this callback determines if the model is added to the list of models found by the search
def on_model_found(model_path: Path) -> bool:
@@ -548,10 +557,11 @@ class ModelInstallService(ModelInstallServiceBase):
May raise an UnknownModelException.
"""
model = self.record_store.get_model(key)
old_path = Path(model.path).resolve()
models_dir = self.app_config.models_path.resolve()
models_dir = self.app_config.models_path
old_path = self.app_config.models_path / model.path

if not old_path.is_relative_to(models_dir):
    # The model is not in the models directory - we don't need to move it.
    return model

new_path = (models_dir / model.base.value / model.type.value / model.name).with_suffix(old_path.suffix)
@@ -561,7 +571,7 @@ class ModelInstallService(ModelInstallServiceBase):

self._logger.info(f"Moving {model.name} to {new_path}.")
new_path = self._move_model(old_path, new_path)
model.path = new_path.as_posix()
model.path = new_path.relative_to(models_dir).as_posix()
self.record_store.update_model(key, ModelRecordChanges(path=model.path))
return model

@@ -600,12 +610,19 @@ class ModelInstallService(ModelInstallServiceBase):

model_path = model_path.resolve()

# Models in the Invoke-managed models dir should use relative paths.
if model_path.is_relative_to(self.app_config.models_path):
    model_path = model_path.relative_to(self.app_config.models_path)

info.path = model_path.as_posix()

# Checkpoints have a config file needed for conversion - resolve this to an absolute path
if isinstance(info, CheckpointConfigBase):
    legacy_conf = (self.app_config.legacy_conf_path / info.config_path).resolve()
    info.config_path = legacy_conf.as_posix()
    # Checkpoints have a config file needed for conversion. Same handling as the model weights - if it's in the
    # invoke-managed legacy config dir, we use a relative path.
    legacy_config_path = self.app_config.legacy_conf_path / info.config_path
    if legacy_config_path.is_relative_to(self.app_config.legacy_conf_path):
        legacy_config_path = legacy_config_path.relative_to(self.app_config.legacy_conf_path)
    info.config_path = legacy_config_path.as_posix()
self.record_store.add_model(info)
return info.key
@ -121,141 +121,146 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
poll_now_event.clear()
|
||||
# Middle processor try block; any unhandled exception is a non-fatal processor error
|
||||
try:
|
||||
# If we are paused, wait for resume event
|
||||
resume_event.wait()
|
||||
|
||||
# Get the next session to process
|
||||
self._queue_item = self._invoker.services.session_queue.dequeue()
|
||||
if self._queue_item is not None and resume_event.is_set():
|
||||
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
|
||||
cancel_event.clear()
|
||||
|
||||
# If profiling is enabled, start the profiler
|
||||
if self._profiler is not None:
|
||||
self._profiler.start(profile_id=self._queue_item.session_id)
|
||||
if self._queue_item is None:
|
||||
# The queue was empty, wait for next polling interval or event to try again
|
||||
self._invoker.services.logger.debug("Waiting for next polling interval or event")
|
||||
poll_now_event.wait(self._polling_interval)
|
||||
continue
|
||||
|
||||
# Prepare invocations and take the first
|
||||
self._invocation = self._queue_item.session.next()
|
||||
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
|
||||
cancel_event.clear()
|
||||
|
||||
# Loop over invocations until the session is complete or canceled
|
||||
while self._invocation is not None and not cancel_event.is_set():
|
||||
# get the source node id to provide to clients (the prepared node id is not as useful)
|
||||
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
|
||||
# If profiling is enabled, start the profiler
|
||||
if self._profiler is not None:
|
||||
self._profiler.start(profile_id=self._queue_item.session_id)
|
||||
|
||||
# Send starting event
|
||||
self._invoker.services.events.emit_invocation_started(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session_id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
)
|
||||
# Prepare invocations and take the first
|
||||
self._invocation = self._queue_item.session.next()
|
||||
|
||||
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
|
||||
try:
|
||||
with self._invoker.services.performance_statistics.collect_stats(
|
||||
self._invocation, self._queue_item.session.id
|
||||
):
|
||||
# Build invocation context (the node-facing API)
|
||||
data = InvocationContextData(
|
||||
invocation=self._invocation,
|
||||
source_invocation_id=source_invocation_id,
|
||||
queue_item=self._queue_item,
|
||||
)
|
||||
context = build_invocation_context(
|
||||
data=data,
|
||||
services=self._invoker.services,
|
||||
cancel_event=self._cancel_event,
|
||||
)
|
||||
# Loop over invocations until the session is complete or canceled
|
||||
while self._invocation is not None and not cancel_event.is_set():
|
||||
# get the source node id to provide to clients (the prepared node id is not as useful)
|
||||
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
|
||||
|
||||
# Invoke the node
|
||||
outputs = self._invocation.invoke_internal(
|
||||
context=context, services=self._invoker.services
|
||||
)
|
||||
# Send starting event
|
||||
self._invoker.services.events.emit_invocation_started(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session_id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
)
|
||||
|
||||
# Save outputs and history
|
||||
self._queue_item.session.complete(self._invocation.id, outputs)
|
||||
|
||||
# Send complete event
|
||||
self._invoker.services.events.emit_invocation_complete(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
result=outputs.model_dump(),
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
# TODO(MM2): Create an event for this
|
||||
pass
|
||||
|
||||
except CanceledException:
|
||||
# When the user cancels the graph, we first set the cancel event. The event is checked
|
||||
# between invocations, in this loop. Some invocations are long-running, and we need to
|
||||
# be able to cancel them mid-execution.
|
||||
#
|
||||
# For example, denoising is a long-running invocation with many steps. A step callback
|
||||
# is executed after each step. This step callback checks if the canceled event is set,
|
||||
# then raises a CanceledException to stop execution immediately.
|
||||
#
|
||||
# When we get a CanceledException, we don't need to do anything - just pass and let the
|
||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
error = traceback.format_exc()
|
||||
|
||||
# Save error
|
||||
self._queue_item.session.set_node_error(self._invocation.id, error)
|
||||
self._invoker.services.logger.error(
|
||||
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
|
||||
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
|
||||
try:
|
||||
with self._invoker.services.performance_statistics.collect_stats(
|
||||
self._invocation, self._queue_item.session.id
|
||||
):
|
||||
# Build invocation context (the node-facing API)
|
||||
data = InvocationContextData(
|
||||
invocation=self._invocation,
|
||||
source_invocation_id=source_invocation_id,
|
||||
queue_item=self._queue_item,
|
||||
)
|
||||
context = build_invocation_context(
|
||||
data=data,
|
||||
services=self._invoker.services,
|
||||
cancel_event=self._cancel_event,
|
||||
)
|
||||
self._invoker.services.logger.error(error)
|
||||
|
||||
# Send error event
|
||||
self._invoker.services.events.emit_invocation_error(
|
||||
queue_batch_id=self._queue_item.session_id,
|
||||
# Invoke the node
|
||||
outputs = self._invocation.invoke_internal(
|
||||
context=context, services=self._invoker.services
|
||||
)
|
||||
|
||||
# Save outputs and history
|
||||
self._queue_item.session.complete(self._invocation.id, outputs)
|
||||
|
||||
# Send complete event
|
||||
self._invoker.services.events.emit_invocation_complete(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=error,
|
||||
result=outputs.model_dump(),
|
||||
)
|
||||
pass
|
||||
|
||||
# The session is complete if the all invocations are complete or there was an error
|
||||
if self._queue_item.session.is_complete() or cancel_event.is_set():
|
||||
# Send complete event
|
||||
self._invoker.services.events.emit_graph_execution_complete(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
except KeyboardInterrupt:
|
||||
# TODO(MM2): Create an event for this
|
||||
pass
|
||||
|
||||
except CanceledException:
|
||||
# When the user cancels the graph, we first set the cancel event. The event is checked
|
||||
# between invocations, in this loop. Some invocations are long-running, and we need to
|
||||
# be able to cancel them mid-execution.
|
||||
#
|
||||
# For example, denoising is a long-running invocation with many steps. A step callback
|
||||
# is executed after each step. This step callback checks if the canceled event is set,
|
||||
# then raises a CanceledException to stop execution immediately.
|
||||
#
|
||||
# When we get a CanceledException, we don't need to do anything - just pass and let the
|
||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
error = traceback.format_exc()
|
||||
|
||||
# Save error
|
||||
self._queue_item.session.set_node_error(self._invocation.id, error)
|
||||
self._invoker.services.logger.error(
|
||||
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
|
||||
)
|
||||
self._invoker.services.logger.error(error)
|
||||
|
||||
# Send error event
|
||||
self._invoker.services.events.emit_invocation_error(
|
||||
queue_batch_id=self._queue_item.session_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=error,
|
||||
)
|
||||
pass
|
||||
|
||||
# The session is complete if the all invocations are complete or there was an error
|
||||
if self._queue_item.session.is_complete() or cancel_event.is_set():
|
||||
# Send complete event
|
||||
self._invoker.services.events.emit_graph_execution_complete(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
)
|
||||
# If we are profiling, stop the profiler and dump the profile & stats
|
||||
if self._profiler:
|
||||
profile_path = self._profiler.stop()
|
||||
stats_path = profile_path.with_suffix(".json")
|
||||
self._invoker.services.performance_statistics.dump_stats(
|
||||
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
|
||||
)
|
||||
# If we are profiling, stop the profiler and dump the profile & stats
|
||||
if self._profiler:
|
||||
profile_path = self._profiler.stop()
|
||||
stats_path = profile_path.with_suffix(".json")
|
||||
self._invoker.services.performance_statistics.dump_stats(
|
||||
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
|
||||
)
|
||||
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
||||
# we don't care about that - suppress the error.
|
||||
with suppress(GESStatsNotFoundError):
|
||||
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
|
||||
self._invoker.services.performance_statistics.reset_stats()
|
||||
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
||||
# we don't care about that - suppress the error.
|
||||
with suppress(GESStatsNotFoundError):
|
||||
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
|
||||
self._invoker.services.performance_statistics.reset_stats()
|
||||
|
||||
# Set the invocation to None to prepare for the next session
|
||||
self._invocation = None
|
||||
else:
|
||||
# Prepare the next invocation
|
||||
self._invocation = self._queue_item.session.next()
|
||||
|
||||
# The session is complete, immediately poll for next session
|
||||
self._queue_item = None
|
||||
poll_now_event.set()
|
||||
# Set the invocation to None to prepare for the next session
|
||||
self._invocation = None
|
||||
else:
|
||||
# Prepare the next invocation
|
||||
self._invocation = self._queue_item.session.next()
|
||||
else:
|
||||
# The queue was empty, wait for next polling interval or event to try again
|
||||
self._invoker.services.logger.debug("Waiting for next polling interval or event")
|
||||
|
@@ -10,6 +10,8 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import build_migration_8
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator


@@ -37,6 +39,8 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_5())
migrator.register_migration(build_migration_6())
migrator.register_migration(build_migration_7())
migrator.register_migration(build_migration_8(app_config=config))
migrator.register_migration(build_migration_9())
migrator.run_migrations()

return db
@@ -0,0 +1,91 @@
import sqlite3
from pathlib import Path

from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration


class Migration8Callback:
    def __init__(self, app_config: InvokeAIAppConfig) -> None:
        self._app_config = app_config

    def __call__(self, cursor: sqlite3.Cursor) -> None:
        self._drop_model_config_table(cursor)
        self._migrate_abs_models_to_rel(cursor)

    def _drop_model_config_table(self, cursor: sqlite3.Cursor) -> None:
        """Drops the old model_config table. This was missed in a previous migration."""

        cursor.execute("DROP TABLE IF EXISTS model_config;")

    def _migrate_abs_models_to_rel(self, cursor: sqlite3.Cursor) -> None:
        """Check all model paths & legacy config paths to determine if they are inside Invoke-managed directories. If
        they are, update the paths to be relative to the managed directories.

        This migration is a no-op for normal users (their paths will already be relative), but is necessary for users
        who have been testing the RCs with their live databases. The paths were made absolute in the initial RC, but this
        change was reverted. To smooth over the revert for our tests, we can migrate the paths back to relative.
        """

        models_path = self._app_config.models_path
        legacy_conf_path = self._app_config.legacy_conf_path
        legacy_conf_dir = self._app_config.legacy_conf_dir

        stmt = """---sql
            SELECT
                id,
                path,
                json_extract(config, '$.config_path') as config_path
            FROM models;
            """

        all_models = cursor.execute(stmt).fetchall()

        for model_id, model_path, model_config_path in all_models:
            # If the model path is inside the models directory, update it to be relative to the models directory.
            if Path(model_path).is_relative_to(models_path):
                new_path = Path(model_path).relative_to(models_path)
                cursor.execute(
                    """--sql
                    UPDATE models
                    SET config = json_set(config, '$.path', ?)
                    WHERE id = ?;
                    """,
                    (str(new_path), model_id),
                )
            # If the model has a legacy config path and it is inside the legacy conf directory, update it to be
            # relative to the legacy conf directory. This also fixes up cases in which the config path was
            # incorrectly relativized to the root directory. It will now be relativized to the legacy conf directory.
            if model_config_path:
                if Path(model_config_path).is_relative_to(legacy_conf_path):
                    new_config_path = Path(model_config_path).relative_to(legacy_conf_path)
                elif Path(model_config_path).is_relative_to(legacy_conf_dir):
                    new_config_path = Path(*Path(model_config_path).parts[1:])
                else:
                    new_config_path = None
                if new_config_path:
                    cursor.execute(
                        """--sql
                        UPDATE models
                        SET config = json_set(config, '$.config_path', ?)
                        WHERE id = ?;
                        """,
                        (str(new_config_path), model_id),
                    )


def build_migration_8(app_config: InvokeAIAppConfig) -> Migration:
    """
    Build the migration from database version 7 to 8.

    This migration does the following:
    - Removes the `model_config` table.
    - Migrates absolute model & legacy config paths to be relative to the models directory.
    """
    migration_8 = Migration(
        from_version=7,
        to_version=8,
        callback=Migration8Callback(app_config),
    )

    return migration_8
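The json_set()/json_extract() calls above can be exercised against a throwaway in-memory database. A minimal sketch, not part of the commit: the table layout is simplified and it assumes a SQLite build with the JSON1 functions available.

import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE models (id TEXT PRIMARY KEY, path TEXT, config TEXT);")
cur.execute(
    "INSERT INTO models (id, path, config) VALUES (?, ?, ?);",
    ("abc", "/data/models/sd-1/main/foo", '{"path": "/data/models/sd-1/main/foo"}'),
)
# Same shape of UPDATE as the migration: rewrite the JSON path field to a relative path.
cur.execute(
    "UPDATE models SET config = json_set(config, '$.path', ?) WHERE id = ?;",
    ("sd-1/main/foo", "abc"),
)
print(cur.execute("SELECT json_extract(config, '$.path') FROM models;").fetchone())
# -> ('sd-1/main/foo',)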
@@ -0,0 +1,29 @@
import sqlite3

from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration


class Migration9Callback:
    def __call__(self, cursor: sqlite3.Cursor) -> None:
        self._empty_session_queue(cursor)

    def _empty_session_queue(self, cursor: sqlite3.Cursor) -> None:
        """Empties the session queue. This is done to prevent any lingering session queue items from causing pydantic errors due to changed schemas."""

        cursor.execute("DELETE FROM session_queue;")


def build_migration_9() -> Migration:
    """
    Build the migration from database version 8 to 9.

    This migration does the following:
    - Empties the session queue. This is done to prevent any lingering session queue items from causing pydantic errors due to changed schemas.
    """
    migration_9 = Migration(
        from_version=8,
        to_version=9,
        callback=Migration9Callback(),
    )

    return migration_9
@@ -1,4 +1,6 @@
import sqlite3
from contextlib import closing
from datetime import datetime
from pathlib import Path
from typing import Optional

@@ -32,6 +34,7 @@ class SqliteMigrator:
self._db = db
self._logger = db.logger
self._migration_set = MigrationSet()
self._backup_path: Optional[Path] = None

def register_migration(self, migration: Migration) -> None:
    """Registers a migration."""
@@ -55,6 +58,18 @@ class SqliteMigrator:
    return False

self._logger.info("Database update needed")

# Make a backup of the db if it needs to be updated and is a file db
if self._db.db_path is not None:
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db"
    self._logger.info(f"Backing up database to {str(self._backup_path)}")
    # Use SQLite to do the backup
    with closing(sqlite3.connect(self._backup_path)) as backup_conn:
        self._db.conn.backup(backup_conn)
else:
    self._logger.info("Using in-memory database, no backup needed")

next_migration = self._migration_set.get(from_version=self._get_current_version(cursor))
while next_migration is not None:
    self._run_migration(next_migration)
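The backup above relies on the standard-library sqlite3.Connection.backup() API (Python 3.7+). A minimal sketch outside the migrator, not part of the commit; the file name backup_example.db is just an example:

import sqlite3
from contextlib import closing

src = sqlite3.connect(":memory:")
src.execute("CREATE TABLE t (x INTEGER);")
src.execute("INSERT INTO t VALUES (1);")
src.commit()

# Copy the live database into a new file-backed connection, as the migrator does
# with self._db.conn.backup(backup_conn).
with closing(sqlite3.connect("backup_example.db")) as backup_conn:
    src.backup(backup_conn)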
@@ -33,42 +33,3 @@ __all__ = [
    "SchedulerPredictionType",
    "SubModelType",
]

########## to help populate the openapi_schema with format enums for each config ###########
# This code is no longer necessary?
# leave it here just in case
#
# import inspect
# from enum import Enum
# from typing import Any, Iterable, Dict, get_args, Set
# def _expand(something: Any) -> Iterable[type]:
#     if isinstance(something, type):
#         yield something
#     else:
#         for x in get_args(something):
#             for y in _expand(x):
#                 yield y

# def _find_format(cls: type) -> Iterable[Enum]:
#     if hasattr(inspect, "get_annotations"):
#         fields = inspect.get_annotations(cls)
#     else:
#         fields = cls.__annotations__
#     if "format" in fields:
#         for x in get_args(fields["format"]):
#             yield x
#     for parent_class in cls.__bases__:
#         for x in _find_format(parent_class):
#             yield x
#     return None

# def get_model_config_formats() -> Dict[str, Set[Enum]]:
#     result: Dict[str, Set[Enum]] = {}
#     for model_config in _expand(AnyModelConfig):
#         for field in _find_format(model_config):
#             if field is None:
#                 continue
#             if not result.get(model_config.__qualname__):
#                 result[model_config.__qualname__] = set()
#             result[model_config.__qualname__].add(field)
#     return result
@@ -3,10 +3,10 @@
"""Conversion script for the Stable Diffusion checkpoints."""

from pathlib import Path
from typing import Dict
from typing import Optional

import torch
from diffusers import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    convert_ldm_vae_checkpoint,
    create_vae_diffusers_config,
@@ -15,11 +15,14 @@ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
)
from omegaconf import DictConfig

from . import AnyModel


def convert_ldm_vae_to_diffusers(
    checkpoint: Dict[str, torch.Tensor],
    checkpoint: torch.Tensor | dict[str, torch.Tensor],
    vae_config: DictConfig,
    image_size: int,
    dump_path: Optional[Path] = None,
    precision: torch.dtype = torch.float16,
) -> AutoencoderKL:
    """Convert a checkpoint-style VAE into a Diffusers VAE"""
@@ -28,16 +31,21 @@ def convert_ldm_vae_to_diffusers(

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(converted_vae_checkpoint)
    return vae.to(precision)
    vae.to(precision)

    if dump_path:
        vae.save_pretrained(dump_path, safe_serialization=True)

    return vae


def convert_ckpt_to_diffusers(
    checkpoint_path: str | Path,
    dump_path: str | Path,
    dump_path: Optional[str | Path] = None,
    precision: torch.dtype = torch.float16,
    use_safetensors: bool = True,
    **kwargs,
):
) -> AnyModel:
    """
    Takes all the arguments of download_from_original_stable_diffusion_ckpt(),
    and in addition a path-like object indicating the location of the desired diffusers
@@ -47,18 +55,20 @@ def convert_ckpt_to_diffusers(
    pipe = pipe.to(precision)

    # TO DO: save correct repo variant
    pipe.save_pretrained(
        dump_path,
        safe_serialization=use_safetensors,
    )
    if dump_path:
        pipe.save_pretrained(
            dump_path,
            safe_serialization=use_safetensors,
        )
    return pipe


def convert_controlnet_to_diffusers(
    checkpoint_path: Path,
    dump_path: Path,
    dump_path: Optional[Path] = None,
    precision: torch.dtype = torch.float16,
    **kwargs,
):
) -> AnyModel:
    """
    Takes all the arguments of download_controlnet_from_original_ckpt(),
    and in addition a path-like object indicating the location of the desired diffusers
@@ -68,4 +78,6 @@ def convert_controlnet_to_diffusers(
    pipe = pipe.to(precision)

    # TO DO: save correct repo variant
    pipe.save_pretrained(dump_path, safe_serialization=True)
    if dump_path:
        pipe.save_pretrained(dump_path, safe_serialization=True)
    return pipe
@@ -19,11 +19,20 @@ class ModelConvertCache(ModelConvertCacheBase):
    self._cache_path = cache_path
    self._max_size = max_size

    # adjust cache size at startup in case it has been changed
    if self._cache_path.exists():
        self.make_room(0.0)

@property
def max_size(self) -> float:
    """Return the maximum size of this cache directory (GB)."""
    return self._max_size

@max_size.setter
def max_size(self, value: float) -> None:
    """Set the maximum size of this cache directory (GB)."""
    self._max_size = value

def cache_path(self, key: str) -> Path:
    """Return the path for a model with the indicated key."""
    return self._cache_path / key
@@ -83,3 +83,15 @@ class ModelLoaderBase(ABC):
) -> int:
    """Return size in bytes of the model, calculated before loading."""
    pass

@property
@abstractmethod
def convert_cache(self) -> ModelConvertCacheBase:
    """Return the convert cache associated with this loader."""
    pass

@property
@abstractmethod
def ram_cache(self) -> ModelCacheBase[AnyModel]:
    """Return the ram cache associated with this loader."""
    pass
@ -3,14 +3,13 @@
|
||||
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
InvalidModelConfigException,
|
||||
ModelRepoVariant,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
|
||||
@ -54,51 +53,43 @@ class ModelLoader(ModelLoaderBase):
|
||||
if model_config.type is ModelType.Main and not submodel_type:
|
||||
raise InvalidModelConfigException("submodel_type is required when loading a main model")
|
||||
|
||||
model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
|
||||
model_path = self._get_model_path(model_config)
|
||||
|
||||
if not model_path.exists():
|
||||
raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")
|
||||
|
||||
model_path = self._convert_if_needed(model_config, model_path, submodel_type)
|
||||
locker = self._load_if_needed(model_config, model_path, submodel_type)
|
||||
with skip_torch_weight_init():
|
||||
locker = self._convert_and_load(model_config, model_path, submodel_type)
|
||||
return LoadedModel(config=model_config, _locker=locker)
|
||||
|
||||
def _get_model_path(
|
||||
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
|
||||
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
|
||||
@property
|
||||
def convert_cache(self) -> ModelConvertCacheBase:
|
||||
"""Return the convert cache associated with this loader."""
|
||||
return self._convert_cache
|
||||
|
||||
@property
|
||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||
"""Return the ram cache associated with this loader."""
|
||||
return self._ram_cache
|
||||
|
||||
def _get_model_path(self, config: AnyModelConfig) -> Path:
|
||||
model_base = self._app_config.models_path
|
||||
result = (model_base / config.path).resolve(), config, submodel_type
|
||||
return result
|
||||
return (model_base / config.path).resolve()
|
||||
|
||||
def _convert_if_needed(
|
||||
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
|
||||
) -> Path:
|
||||
cache_path: Path = self._convert_cache.cache_path(config.key)
|
||||
|
||||
if not self._needs_conversion(config, model_path, cache_path):
|
||||
return cache_path if cache_path.exists() else model_path
|
||||
|
||||
self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
|
||||
return self._convert_model(config, model_path, cache_path)
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
return False
|
||||
|
||||
def _load_if_needed(
|
||||
def _convert_and_load(
|
||||
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
|
||||
) -> ModelLockerBase:
|
||||
# TO DO: This is not thread safe!
|
||||
try:
|
||||
return self._ram_cache.get(config.key, submodel_type)
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
model_variant = getattr(config, "repo_variant", None)
|
||||
self._ram_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
|
||||
|
||||
# This is where the model is actually loaded!
|
||||
with skip_torch_weight_init():
|
||||
loaded_model = self._load_model(model_path, model_variant=model_variant, submodel_type=submodel_type)
|
||||
cache_path: Path = self._convert_cache.cache_path(config.key)
|
||||
if self._needs_conversion(config, model_path, cache_path):
|
||||
loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
|
||||
else:
|
||||
config.path = str(cache_path) if cache_path.exists() else str(self._get_model_path(config))
|
||||
loaded_model = self._load_model(config, submodel_type)
|
||||
|
||||
self._ram_cache.put(
|
||||
config.key,
|
||||
@ -123,15 +114,34 @@ class ModelLoader(ModelLoaderBase):
|
||||
variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None,
|
||||
)
|
||||
|
||||
def _do_convert(
|
||||
self, config: AnyModelConfig, model_path: Path, cache_path: Path, submodel_type: Optional[SubModelType] = None
|
||||
) -> AnyModel:
|
||||
self.convert_cache.make_room(calc_model_size_by_fs(model_path))
|
||||
pipeline = self._convert_model(config, model_path, cache_path if self.convert_cache.max_size > 0 else None)
|
||||
if submodel_type:
|
||||
# Proactively load the various submodels into the RAM cache so that we don't have to re-convert
|
||||
# the entire pipeline every time a new submodel is needed.
|
||||
for subtype in SubModelType:
|
||||
if subtype == submodel_type:
|
||||
continue
|
||||
if submodel := getattr(pipeline, subtype.value, None):
|
||||
self._ram_cache.put(
|
||||
config.key, submodel_type=subtype, model=submodel, size=calc_model_size_by_data(submodel)
|
||||
)
|
||||
return getattr(pipeline, submodel_type.value) if submodel_type else pipeline
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
return False
|
||||
|
||||
# This needs to be implemented in subclasses that handle checkpoints
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
|
||||
raise NotImplementedError
|
||||
|
||||
# This needs to be implemented in the subclass
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
raise NotImplementedError
|
||||
|
@ -122,6 +122,11 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
"""Return the cap on cache size."""
|
||||
return self._max_cache_size
|
||||
|
||||
@max_cache_size.setter
|
||||
def max_cache_size(self, value: float) -> None:
|
||||
"""Set the cap on cache size."""
|
||||
self._max_cache_size = value
|
||||
|
||||
@property
|
||||
def stats(self) -> Optional[CacheStats]:
|
||||
"""Return collected CacheStats object."""
|
||||
@ -157,8 +162,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
) -> None:
|
||||
"""Store model under key and optional submodel_type."""
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
assert key not in self._cached_models
|
||||
|
||||
if key in self._cached_models:
|
||||
return
|
||||
self.make_room(size)
|
||||
cache_record = CacheRecord(key, model, size)
|
||||
self._cached_models[key] = cache_record
|
||||
self._cache_stack.append(key)
|
||||
@ -405,6 +411,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
#
|
||||
# Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
|
||||
# immediately when their reference count hits 0.
|
||||
if self.stats:
|
||||
self.stats.cleared = models_cleared
|
||||
gc.collect()
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
@ -2,8 +2,10 @@
|
||||
"""Class for ControlNet model loading in InvokeAI."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
@ -33,7 +35,7 @@ class ControlNetLoader(GenericDiffusersLoader):
|
||||
else:
|
||||
return True
|
||||
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
|
||||
assert isinstance(config, CheckpointConfigBase)
|
||||
image_size = (
|
||||
512
|
||||
@ -44,8 +46,8 @@ class ControlNetLoader(GenericDiffusersLoader):
|
||||
)
|
||||
|
||||
self._logger.info(f"Converting {model_path} to diffusers format")
|
||||
with open(self._app_config.root_path / config.config_path, "r") as config_stream:
|
||||
convert_controlnet_to_diffusers(
|
||||
with open(self._app_config.legacy_conf_path / config.config_path, "r") as config_stream:
|
||||
result = convert_controlnet_to_diffusers(
|
||||
model_path,
|
||||
output_path,
|
||||
original_config_file=config_stream,
|
||||
@ -53,4 +55,4 @@ class ControlNetLoader(GenericDiffusersLoader):
|
||||
precision=self._torch_dtype,
|
||||
from_safetensors=model_path.suffix == ".safetensors",
|
||||
)
|
||||
return output_path
|
||||
return result
|
||||
|
@ -10,13 +10,14 @@ from diffusers.models.modeling_utils import ModelMixin
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
InvalidModelConfigException,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import DiffusersConfigBase
|
||||
|
||||
from .. import ModelLoader, ModelLoaderRegistry
|
||||
|
||||
@ -28,14 +29,15 @@ class GenericDiffusersLoader(ModelLoader):
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
model_path = Path(config.path)
|
||||
model_class = self.get_hf_load_class(model_path)
|
||||
if submodel_type is not None:
|
||||
raise Exception(f"There are no submodels in models of type {model_class}")
|
||||
variant = model_variant.value if model_variant else None
|
||||
repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
|
||||
variant = repo_variant.value if repo_variant else None
|
||||
try:
|
||||
result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant)
|
||||
except OSError as e:
|
||||
|
@ -9,13 +9,14 @@ import torch
|
||||
from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
|
||||
@ -24,13 +25,13 @@ class IPAdapterInvokeAILoader(ModelLoader):
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if submodel_type is not None:
|
||||
raise ValueError("There are no submodels in an IP-Adapter model.")
|
||||
model = build_ip_adapter(
|
||||
model_path = Path(config.path)
|
||||
model: RawModel = build_ip_adapter(
|
||||
ip_adapter_ckpt_path=str(model_path / "ip_adapter.bin"),
|
||||
device=torch.device("cpu"),
|
||||
dtype=self._torch_dtype,
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
@ -12,7 +12,6 @@ from invokeai.backend.model_manager import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
@ -41,12 +40,12 @@ class LoRALoader(ModelLoader):
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if submodel_type is not None:
|
||||
raise ValueError("There are no submodels in a LoRA model.")
|
||||
model_path = Path(config.path)
|
||||
assert self._model_base is not None
|
||||
model = LoRAModelRaw.from_checkpoint(
|
||||
file_path=model_path,
|
||||
@ -56,12 +55,9 @@ class LoRALoader(ModelLoader):
|
||||
return model
|
||||
|
||||
# override
|
||||
def _get_model_path(
|
||||
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
|
||||
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
|
||||
self._model_base = (
|
||||
config.base
|
||||
) # cheating a little - we remember this variable for using in the subsequent call to _load_model()
|
||||
def _get_model_path(self, config: AnyModelConfig) -> Path:
|
||||
# cheating a little - we remember this variable for using in the subsequent call to _load_model()
|
||||
self._model_base = config.base
|
||||
|
||||
model_base_path = self._app_config.models_path
|
||||
model_path = model_base_path / config.path
|
||||
@ -73,5 +69,4 @@ class LoRALoader(ModelLoader):
|
||||
model_path = path
|
||||
break
|
||||
|
||||
result = model_path.resolve(), config, submodel_type
|
||||
return result
|
||||
return model_path.resolve()
|
||||
|
@ -7,9 +7,9 @@ from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
@ -25,18 +25,19 @@ class OnnyxDiffusersModel(GenericDiffusersLoader):
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not submodel_type is not None:
|
||||
raise Exception("A submodel type must be provided when loading onnx pipelines.")
|
||||
model_path = Path(config.path)
|
||||
load_class = self.get_hf_load_class(model_path, submodel_type)
|
||||
variant = model_variant.value if model_variant else None
|
||||
repo_variant = getattr(config, "repo_variant", None)
|
||||
variant = repo_variant.value if repo_variant else None
|
||||
model_path = model_path / submodel_type.value
|
||||
result: AnyModel = load_class.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=self._torch_dtype,
|
||||
variant=variant,
|
||||
) # type: ignore
|
||||
)
|
||||
return result
|
||||
|
@ -9,12 +9,16 @@ from invokeai.backend.model_manager import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig, ModelVariantType
|
||||
from invokeai.backend.model_manager.config import (
|
||||
CheckpointConfigBase,
|
||||
DiffusersConfigBase,
|
||||
MainCheckpointConfig,
|
||||
ModelVariantType,
|
||||
)
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
@ -41,14 +45,15 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not submodel_type is not None:
|
||||
raise Exception("A submodel type must be provided when loading main pipelines.")
|
||||
model_path = Path(config.path)
|
||||
load_class = self.get_hf_load_class(model_path, submodel_type)
|
||||
variant = model_variant.value if model_variant else None
|
||||
repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
|
||||
variant = repo_variant.value if repo_variant else None
|
||||
model_path = model_path / submodel_type.value
|
||||
try:
|
||||
result: AnyModel = load_class.from_pretrained(
|
||||
@ -78,7 +83,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
else:
|
||||
return True
|
||||
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
|
||||
assert isinstance(config, MainCheckpointConfig)
|
||||
base = config.base
|
||||
|
||||
@ -94,11 +99,11 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
|
||||
self._logger.info(f"Converting {model_path} to diffusers format")
|
||||
|
||||
convert_ckpt_to_diffusers(
|
||||
loaded_model = convert_ckpt_to_diffusers(
|
||||
model_path,
|
||||
output_path,
|
||||
model_type=self.model_base_to_model_type[base],
|
||||
original_config_file=self._app_config.root_path / config.config_path,
|
||||
original_config_file=self._app_config.legacy_conf_path / config.config_path,
|
||||
extract_ema=True,
|
||||
from_safetensors=model_path.suffix == ".safetensors",
|
||||
precision=self._torch_dtype,
|
||||
@ -108,4 +113,4 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
load_safety_checker=False,
|
||||
num_in_channels=VARIANT_TO_IN_CHANNEL_MAP[config.variant],
|
||||
)
|
||||
return output_path
|
||||
return loaded_model
|
||||
|
@ -2,14 +2,13 @@
|
||||
"""Class for TI model loading in InvokeAI."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
@ -27,22 +26,19 @@ class TextualInversionLoader(ModelLoader):
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if submodel_type is not None:
|
||||
raise ValueError("There are no submodels in a TI model.")
|
||||
model = TextualInversionModelRaw.from_checkpoint(
|
||||
file_path=model_path,
|
||||
file_path=config.path,
|
||||
dtype=self._torch_dtype,
|
||||
)
|
||||
return model
|
||||
|
||||
# override
|
||||
def _get_model_path(
|
||||
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
|
||||
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
|
||||
def _get_model_path(self, config: AnyModelConfig) -> Path:
|
||||
model_path = self._app_config.models_path / config.path
|
||||
|
||||
if config.format == ModelFormat.EmbeddingFolder:
|
||||
@ -53,4 +49,4 @@ class TextualInversionLoader(ModelLoader):
|
||||
if not path.exists():
|
||||
raise OSError(f"The embedding file at {path} was not found")
|
||||
|
||||
return path, config, submodel_type
	return path


@ -2,6 +2,7 @@
"""Class for VAE model loading in InvokeAI."""

from pathlib import Path
from typing import Optional

import torch
from omegaconf import DictConfig, OmegaConf
@ -13,7 +14,7 @@ from invokeai.backend.model_manager import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.config import CheckpointConfigBase
from invokeai.backend.model_manager.config import AnyModel, CheckpointConfigBase
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers

from .. import ModelLoaderRegistry
@ -38,13 +39,13 @@ class VAELoader(GenericDiffusersLoader):
else:
return True

def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
# TODO(MM2): check whether sdxl VAE models convert.
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"VAE conversion not supported for model type: {config.base}")
else:
assert isinstance(config, CheckpointConfigBase)
config_file = self._app_config.root_path / config.config_path
config_file = self._app_config.legacy_conf_path / config.config_path

if model_path.suffix == ".safetensors":
checkpoint = safetensors_load_file(model_path, device="cpu")
@ -63,6 +64,6 @@ class VAELoader(GenericDiffusersLoader):
vae_config=ckpt_config,
image_size=512,
precision=self._torch_dtype,
dump_path=output_path,
)
vae_model.save_pretrained(output_path, safe_serialization=True)
return output_path
return vae_model

@ -94,6 +94,7 @@
"reactflow": "^11.10.4",
"redux-dynamic-middlewares": "^2.2.0",
"redux-remember": "^5.1.0",
"rfdc": "^1.3.1",
"roarr": "^7.21.1",
"serialize-error": "^11.0.3",
"socket.io-client": "^4.7.5",

@ -137,6 +137,9 @@ dependencies:
redux-remember:
specifier: ^5.1.0
version: 5.1.0(redux@5.0.1)
rfdc:
specifier: ^1.3.1
version: 1.3.1
roarr:
specifier: ^7.21.1
version: 7.21.1
@ -12128,6 +12131,10 @@ packages:
resolution: {integrity: sha512-/x8uIPdTafBqakK0TmPNJzgkLP+3H+yxpUJhCQHsLBg1rYEVNR2D8BRYNWQhVBjyOd7oo1dZRVzIkwMY2oqfYQ==}
dev: true

/rfdc@1.3.1:
resolution: {integrity: sha512-r5a3l5HzYlIC68TpmYKlxWjmOP6wiPJ1vWv2HeLhNsRZMrCkxeqxiHlQ21oXmQ4F3SiryXBHhAD7JZqvOJjFmg==}
dev: false

/rimraf@2.6.3:
resolution: {integrity: sha512-mwqeW5XsA2qAejG46gYdENaxXjx9onRNCfn7L0duuP4hCuTIi/QO7PDK07KJfp1d+izWPrzEJDcSqBa0OZQriA==}
hasBin: true

@ -4,7 +4,7 @@
"reportBugLabel": "Fehler melden",
"settingsLabel": "Einstellungen",
"img2img": "Bild zu Bild",
"nodes": "Knoten Editor",
"nodes": "Arbeitsabläufe",
"upload": "Hochladen",
"load": "Laden",
"statusDisconnected": "Getrennt",
@ -74,7 +74,8 @@
"updated": "Aktualisiert",
"copy": "Kopieren",
"aboutHeading": "Nutzen Sie Ihre kreative Energie",
"toResolve": "Lösen"
"toResolve": "Lösen",
"add": "Hinzufügen"
},
"gallery": {
"galleryImageSize": "Bildgröße",
@ -104,11 +105,16 @@
"dropToUpload": "$t(gallery.drop) zum hochladen",
"dropOrUpload": "$t(gallery.drop) oder hochladen",
"drop": "Ablegen",
"problemDeletingImages": "Problem beim Löschen der Bilder"
"problemDeletingImages": "Problem beim Löschen der Bilder",
"bulkDownloadRequested": "Download vorbereiten",
"bulkDownloadRequestedDesc": "Dein Download wird vorbereitet. Dies kann ein paar Momente dauern.",
"bulkDownloadRequestFailed": "Problem beim Download vorbereiten",
"bulkDownloadFailed": "Download fehlgeschlagen",
"alwaysShowImageSizeBadge": "Zeige immer Bilder Größe Abzeichen"
},
"hotkeys": {
"keyboardShortcuts": "Tastenkürzel",
"appHotkeys": "App-Tastenkombinationen",
"appHotkeys": "App",
"generalHotkeys": "Allgemein",
"galleryHotkeys": "Galerie",
"unifiedCanvasHotkeys": "Leinwand",
@ -757,7 +763,9 @@
"scheduler": "Planer",
"noRecallParameters": "Es wurden keine Parameter zum Abrufen gefunden",
"recallParameters": "Parameter wiederherstellen",
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)"
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
"allPrompts": "Alle Prompts",
"imageDimensions": "Bilder Auslösungen"
},
"popovers": {
"noiseUseCPU": {
@ -1068,5 +1076,10 @@
},
"dynamicPrompts": {
"showDynamicPrompts": "Dynamische Prompts anzeigen"
},
"prompt": {
"noMatchingTriggers": "Keine passenden Auslöser",
"addPromptTrigger": "Auslöse Text hinzufügen",
"compatibleEmbeddings": "Kompatible Einbettungen"
}
}

@ -73,7 +73,8 @@
"ai": "ia",
"file": "File",
"toResolve": "Da risolvere",
"add": "Aggiungi"
"add": "Aggiungi",
"loglevel": "Livello di log"
},
"gallery": {
"galleryImageSize": "Dimensione dell'immagine",
@ -934,7 +935,9 @@
"base": "Base",
"lineart": "Linea",
"controlnet": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.controlNet))",
"mediapipeFace": "Mediapipe Volto"
"mediapipeFace": "Mediapipe Volto",
"ip_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.ipAdapter))",
"t2i_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.t2iAdapter))"
},
"queue": {
"queueFront": "Aggiungi all'inizio della coda",
@ -1490,7 +1493,8 @@
"title": "Generazione"
},
"advanced": {
"title": "Avanzate"
"title": "Avanzate",
"options": "Opzioni $t(accordions.advanced.title)"
},
"image": {
"title": "Immagine"

@ -75,7 +75,8 @@
"copy": "Копировать",
"localSystem": "Локальная система",
"aboutDesc": "Используя Invoke для работы? Проверьте это:",
"add": "Добавить"
"add": "Добавить",
"loglevel": "Уровень логов"
},
"gallery": {
"galleryImageSize": "Размер изображений",
@ -1505,7 +1506,8 @@
"title": "Генерация"
},
"advanced": {
"title": "Расширенные"
"title": "Расширенные",
"options": "Опции $t(accordions.advanced.title)"
},
"image": {
"title": "Изображение"

@ -1,7 +1,7 @@
import type { UnknownAction } from '@reduxjs/toolkit';
import { deepClone } from 'common/util/deepClone';
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
import { cloneDeep } from 'lodash-es';
import { appInfoApi } from 'services/api/endpoints/appInfo';
import type { Graph } from 'services/api/types';
import { socketGeneratorProgress } from 'services/events/actions';
@ -33,7 +33,7 @@ export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
}

if (socketGeneratorProgress.match(action)) {
const sanitized = cloneDeep(action);
const sanitized = deepClone(action);
if (sanitized.payload.data.progress_image) {
sanitized.payload.data.progress_image.dataURL = '<Progress image omitted>';
}

@ -1,4 +1,5 @@
import { cloneDeep, merge } from 'lodash-es';
import { deepClone } from 'common/util/deepClone';
import { merge } from 'lodash-es';
import { ClickScrollPlugin, OverlayScrollbars } from 'overlayscrollbars';
import type { UseOverlayScrollbarsParams } from 'overlayscrollbars-react';

@ -22,7 +23,7 @@ export const getOverlayScrollbarsParams = (
overflowX: 'hidden' | 'scroll' = 'hidden',
overflowY: 'hidden' | 'scroll' = 'scroll'
) => {
const params = cloneDeep(overlayScrollbarsParams);
const params = deepClone(overlayScrollbarsParams);
merge(params, { options: { overflow: { y: overflowY, x: overflowX } } });
return params;
};
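
Illustration only (not part of the diff): a call to getOverlayScrollbarsParams with the two overflow arguments visible in the hunk above; the resulting object is assumed to be handed to an OverlayScrollbars initializer elsewhere.

// Hypothetical call site; argument order (overflowX, overflowY) follows the signature shown above.
const params = getOverlayScrollbarsParams('scroll', 'hidden');
// params.options.overflow is now { x: 'scroll', y: 'hidden' }, merged onto a deep copy of the shared defaults.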

invokeai/frontend/web/src/common/util/deepClone.ts (new file, 15 lines)
@ -0,0 +1,15 @@
import rfdc from 'rfdc';
const _rfdc = rfdc();

/**
* Deep-clones an object using Really Fast Deep Clone.
* This is the fastest deep clone library on Chrome, but not the fastest on FF. Still, it's much faster than lodash
* and structuredClone, so it's the best all-around choice.
*
* Simple Benchmark: https://www.measurethat.net/Benchmarks/Show/30358/0/lodash-clonedeep-vs-jsonparsejsonstringify-vs-recursive
* Repo: https://github.com/davidmarkclements/rfdc
*
* @param obj The object to deep-clone
* @returns The cloned object
*/
export const deepClone = <T>(obj: T): T => _rfdc(obj);
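
Illustration only (not part of the commit): a minimal sketch of how the new helper stands in for lodash's cloneDeep at call sites; the state shape below is invented for the example.

import { deepClone } from 'common/util/deepClone';

// Hypothetical state object for illustration.
const layerState = { objects: [{ kind: 'image' }] };

// Before: const snapshot = cloneDeep(layerState);
const snapshot = deepClone(layerState);

snapshot.objects.push({ kind: 'fillRect' });
// layerState.objects.length is still 1; the copy is fully independent of the original.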

@ -1,6 +1,7 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMultiple';
import calculateCoordinates from 'features/canvas/util/calculateCoordinates';
import calculateScale from 'features/canvas/util/calculateScale';
@ -13,7 +14,7 @@ import { modelChanged } from 'features/parameters/store/generationSlice';
import type { PayloadActionWithOptimalDimension } from 'features/parameters/store/types';
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import type { IRect, Vector2d } from 'konva/lib/types';
import { clamp, cloneDeep } from 'lodash-es';
import { clamp } from 'lodash-es';
import type { RgbaColor } from 'react-colorful';
import { queueApi } from 'services/api/endpoints/queue';
import type { ImageDTO } from 'services/api/types';
@ -36,7 +37,7 @@ import { CANVAS_GRID_SIZE_FINE } from './constants';
/**
* The maximum history length to keep in the past/future layer states.
*/
const MAX_HISTORY = 128;
const MAX_HISTORY = 100;

const initialLayerState: CanvasLayerState = {
objects: [],
@ -121,7 +122,7 @@ export const canvasSlice = createSlice({
state.brushSize = action.payload;
},
clearMask: (state) => {
state.pastLayerStates.push(cloneDeep(state.layerState));
pushToPrevLayerStates(state);
state.layerState.objects = state.layerState.objects.filter((obj) => !isCanvasMaskLine(obj));
state.futureLayerStates = [];
state.shouldPreserveMaskedArea = false;
@ -163,10 +164,10 @@ export const canvasSlice = createSlice({
state.boundingBoxDimensions = newBoundingBoxDimensions;
state.boundingBoxCoordinates = newBoundingBoxCoordinates;

state.pastLayerStates.push(cloneDeep(state.layerState));
pushToPrevLayerStates(state);

state.layerState = {
...cloneDeep(initialLayerState),
...deepClone(initialLayerState),
objects: [
{
kind: 'image',
@ -261,11 +262,7 @@ export const canvasSlice = createSlice({
return;
}

state.pastLayerStates.push(cloneDeep(state.layerState));

if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates.shift();
}
pushToPrevLayerStates(state);

state.layerState.stagingArea.images.push({
kind: 'image',
@ -279,13 +276,9 @@ export const canvasSlice = createSlice({
state.futureLayerStates = [];
},
discardStagedImages: (state) => {
state.pastLayerStates.push(cloneDeep(state.layerState));
pushToPrevLayerStates(state);

if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates.shift();
}

state.layerState.stagingArea = cloneDeep(cloneDeep(initialLayerState)).stagingArea;
state.layerState.stagingArea = deepClone(initialLayerState.stagingArea);

state.futureLayerStates = [];
state.shouldShowStagingOutline = true;
@ -294,11 +287,7 @@ export const canvasSlice = createSlice({
},
discardStagedImage: (state) => {
const { images, selectedImageIndex } = state.layerState.stagingArea;
state.pastLayerStates.push(cloneDeep(state.layerState));

if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates.shift();
}
pushToPrevLayerStates(state);

if (!images.length) {
return;
@ -320,11 +309,7 @@ export const canvasSlice = createSlice({
addFillRect: (state) => {
const { boundingBoxCoordinates, boundingBoxDimensions, brushColor } = state;

state.pastLayerStates.push(cloneDeep(state.layerState));

if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates.shift();
}
pushToPrevLayerStates(state);

state.layerState.objects.push({
kind: 'fillRect',
@ -339,11 +324,7 @@ export const canvasSlice = createSlice({
addEraseRect: (state) => {
const { boundingBoxCoordinates, boundingBoxDimensions } = state;

state.pastLayerStates.push(cloneDeep(state.layerState));

if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates.shift();
}
pushToPrevLayerStates(state);

state.layerState.objects.push({
kind: 'eraseRect',
@ -367,11 +348,7 @@ export const canvasSlice = createSlice({
// set & then spread this to only conditionally add the "color" key
const newColor = layer === 'base' && tool === 'brush' ? { color: brushColor } : {};

state.pastLayerStates.push(cloneDeep(state.layerState));

if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates.shift();
}
pushToPrevLayerStates(state);

const newLine: CanvasMaskLine | CanvasBaseLine = {
kind: 'line',
@ -409,11 +386,7 @@ export const canvasSlice = createSlice({
return;
}

state.futureLayerStates.unshift(cloneDeep(state.layerState));

if (state.futureLayerStates.length > MAX_HISTORY) {
state.futureLayerStates.pop();
}
pushToFutureLayerStates(state);

state.layerState = targetState;
},
@ -424,11 +397,7 @@ export const canvasSlice = createSlice({
return;
}

state.pastLayerStates.push(cloneDeep(state.layerState));

if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates.shift();
}
pushToPrevLayerStates(state);

state.layerState = targetState;
},
@ -445,8 +414,8 @@ export const canvasSlice = createSlice({
state.shouldShowIntermediates = action.payload;
},
resetCanvas: (state) => {
state.pastLayerStates.push(cloneDeep(state.layerState));
state.layerState = cloneDeep(initialLayerState);
pushToPrevLayerStates(state);
state.layerState = deepClone(initialLayerState);
state.futureLayerStates = [];
state.batchIds = [];
state.boundingBoxCoordinates = {
@ -540,11 +509,7 @@ export const canvasSlice = createSlice({

const { images, selectedImageIndex } = state.layerState.stagingArea;

state.pastLayerStates.push(cloneDeep(state.layerState));

if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates.shift();
}
pushToPrevLayerStates(state);

const imageToCommit = images[selectedImageIndex];

@ -553,7 +518,7 @@ export const canvasSlice = createSlice({
...imageToCommit,
});
}
state.layerState.stagingArea = cloneDeep(initialLayerState).stagingArea;
state.layerState.stagingArea = deepClone(initialLayerState.stagingArea);

state.futureLayerStates = [];
state.shouldShowStagingOutline = true;
@ -623,7 +588,7 @@ export const canvasSlice = createSlice({
};
},
setMergedCanvas: (state, action: PayloadAction<CanvasImage>) => {
state.pastLayerStates.push(cloneDeep(state.layerState));
pushToPrevLayerStates(state);

state.futureLayerStates = [];

@ -743,3 +708,17 @@ export const canvasPersistConfig: PersistConfig<CanvasState> = {
migrate: migrateCanvasState,
persistDenylist: [],
};

const pushToPrevLayerStates = (state: CanvasState) => {
state.pastLayerStates.push(deepClone(state.layerState));
if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates = state.pastLayerStates.slice(-MAX_HISTORY);
}
};

const pushToFutureLayerStates = (state: CanvasState) => {
state.futureLayerStates.unshift(deepClone(state.layerState));
if (state.futureLayerStates.length > MAX_HISTORY) {
state.futureLayerStates = state.futureLayerStates.slice(0, MAX_HISTORY);
}
};
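
For context only (not from the diff): the two helpers above centralize the snapshot-and-trim pattern the reducers previously repeated inline. A simplified, self-contained sketch of the same idea, with an invented state type, looks like this.

// Hypothetical mirror of pushToPrevLayerStates for illustration.
type HistoryState<T> = { layerState: T; pastLayerStates: T[] };
const MAX_HISTORY = 100;

const pushSnapshot = <T>(state: HistoryState<T>, clone: (v: T) => T): void => {
  state.pastLayerStates.push(clone(state.layerState));
  if (state.pastLayerStates.length > MAX_HISTORY) {
    // drop the oldest snapshots, keeping only the most recent MAX_HISTORY entries
    state.pastLayerStates = state.pastLayerStates.slice(-MAX_HISTORY);
  }
};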

@ -2,10 +2,11 @@ import type { PayloadAction, Update } from '@reduxjs/toolkit';
import { createEntityAdapter, createSlice, isAnyOf } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter';
import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { cloneDeep, merge, uniq } from 'lodash-es';
import { merge, uniq } from 'lodash-es';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { socketInvocationError } from 'services/events/actions';
import { v4 as uuidv4 } from 'uuid';
@ -114,7 +115,7 @@ export const controlAdaptersSlice = createSlice({
if (!controlAdapter) {
return;
}
const newControlAdapter = merge(cloneDeep(controlAdapter), {
const newControlAdapter = merge(deepClone(controlAdapter), {
id: newId,
isEnabled: true,
});
@ -270,7 +271,7 @@ export const controlAdaptersSlice = createSlice({
return;
}

const processorNode = merge(cloneDeep(cn.processorNode), params);
const processorNode = merge(deepClone(cn.processorNode), params);

caAdapter.updateOne(state, {
id,
@ -293,7 +294,7 @@ export const controlAdaptersSlice = createSlice({
return;
}

const processorNode = cloneDeep(
const processorNode = deepClone(
CONTROLNET_PROCESSORS[processorType].buildDefaults(cn.model?.base)
) as RequiredControlAdapterProcessorNode;

@ -333,7 +334,7 @@ export const controlAdaptersSlice = createSlice({
caAdapter.updateOne(state, update);
},
controlAdaptersReset: () => {
return cloneDeep(initialControlAdaptersState);
return deepClone(initialControlAdaptersState);
},
pendingControlImagesCleared: (state) => {
state.pendingControlImages = [];
@ -406,7 +407,7 @@ const migrateControlAdaptersState = (state: any): any => {
state._version = 1;
}
if (state._version === 1) {
state = cloneDeep(initialControlAdaptersState);
state = deepClone(initialControlAdaptersState);
}
return state;
};

@ -1,3 +1,4 @@
import { deepClone } from 'common/util/deepClone';
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import type {
ControlAdapterConfig,
@ -7,7 +8,7 @@ import type {
RequiredCannyImageProcessorInvocation,
T2IAdapterConfig,
} from 'features/controlAdapters/store/types';
import { cloneDeep, merge } from 'lodash-es';
import { merge } from 'lodash-es';

export const initialControlNet: Omit<ControlNetConfig, 'id'> = {
type: 'controlnet',
@ -57,11 +58,11 @@ export const buildControlAdapter = (
): ControlAdapterConfig => {
switch (type) {
case 'controlnet':
return merge(cloneDeep(initialControlNet), { id, ...overrides });
return merge(deepClone(initialControlNet), { id, ...overrides });
case 't2i_adapter':
return merge(cloneDeep(initialT2IAdapter), { id, ...overrides });
return merge(deepClone(initialT2IAdapter), { id, ...overrides });
case 'ip_adapter':
return merge(cloneDeep(initialIPAdapter), { id, ...overrides });
return merge(deepClone(initialIPAdapter), { id, ...overrides });
default:
throw new Error(`Unknown control adapter type: ${type}`);
}
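
Schematic illustration only (the values are invented, not from the diff): each case above produces a fresh copy of the per-type defaults with the caller's fields layered on top.

// lodash merge mutates and returns its first argument, so the shared defaults object itself is never touched.
const adapter = merge(deepClone(initialControlNet), { id: 'ca-1', isEnabled: false });
// adapter carries every default from initialControlNet plus the overridden id and isEnabled.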

@ -1,9 +1,9 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
import { cloneDeep } from 'lodash-es';
import type { LoRAModelConfig } from 'services/api/types';

export type LoRA = {
@ -58,7 +58,7 @@ export const loraSlice = createSlice({
}
lora.isEnabled = isEnabled;
},
lorasReset: () => cloneDeep(initialLoraState),
lorasReset: () => deepClone(initialLoraState),
},
});

@ -74,7 +74,7 @@ const migrateLoRAState = (state: any): any => {
}
if (state._version === 1) {
// Model type has changed, so we need to reset the state - too risky to migrate
state = cloneDeep(initialLoraState);
state = deepClone(initialLoraState);
}
return state;
};

@ -1,6 +1,7 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { workflowLoaded } from 'features/nodes/store/actions';
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
import type {
@ -44,7 +45,7 @@ import {
} from 'features/nodes/types/field';
import type { AnyNode, InvocationTemplate, NodeExecutionState } from 'features/nodes/types/invocation';
import { isInvocationNode, isNotesNode, zNodeStatus } from 'features/nodes/types/invocation';
import { cloneDeep, forEach } from 'lodash-es';
import { forEach } from 'lodash-es';
import type {
Connection,
Edge,
@ -571,8 +572,23 @@ export const nodesSlice = createSlice({
);
},
selectionCopied: (state) => {
state.nodesToCopy = state.nodes.filter((n) => n.selected).map(cloneDeep);
state.edgesToCopy = state.edges.filter((e) => e.selected).map(cloneDeep);
const nodesToCopy: AnyNode[] = [];
const edgesToCopy: Edge[] = [];

for (const node of state.nodes) {
if (node.selected) {
nodesToCopy.push(deepClone(node));
}
}

for (const edge of state.edges) {
if (edge.selected) {
edgesToCopy.push(deepClone(edge));
}
}

state.nodesToCopy = nodesToCopy;
state.edgesToCopy = edgesToCopy;

if (state.nodesToCopy.length > 0) {
const averagePosition = { x: 0, y: 0 };
@ -594,11 +610,21 @@ export const nodesSlice = createSlice({
},
selectionPasted: (state, action: PayloadAction<{ cursorPosition?: XYPosition }>) => {
const { cursorPosition } = action.payload;
const newNodes = state.nodesToCopy.map(cloneDeep);
const newNodes: AnyNode[] = [];

for (const node of state.nodesToCopy) {
newNodes.push(deepClone(node));
}

const oldNodeIds = newNodes.map((n) => n.data.id);
const newEdges = state.edgesToCopy
.filter((e) => oldNodeIds.includes(e.source) && oldNodeIds.includes(e.target))
.map(cloneDeep);

const newEdges: Edge[] = [];

for (const edge of state.edgesToCopy) {
if (oldNodeIds.includes(edge.source) && oldNodeIds.includes(edge.target)) {
newEdges.push(deepClone(edge));
}
}

newEdges.forEach((e) => (e.selected = true));

@ -1,6 +1,7 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { workflowLoaded } from 'features/nodes/store/actions';
import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesChanged, nodesDeleted } from 'features/nodes/store/nodesSlice';
import type {
@ -11,7 +12,7 @@ import type {
import type { FieldIdentifier } from 'features/nodes/types/field';
import { isInvocationNode } from 'features/nodes/types/invocation';
import type { WorkflowCategory, WorkflowV3 } from 'features/nodes/types/workflow';
import { cloneDeep, isEqual, omit, uniqBy } from 'lodash-es';
import { isEqual, omit, uniqBy } from 'lodash-es';

const blankWorkflow: Omit<WorkflowV3, 'nodes' | 'edges'> = {
name: '',
@ -131,8 +132,8 @@ export const workflowSlice = createSlice({
});

return {
...cloneDeep(initialWorkflowState),
...cloneDeep(workflowExtra),
...deepClone(initialWorkflowState),
...deepClone(workflowExtra),
originalExposedFieldValues,
mode: state.mode,
};
@ -144,7 +145,7 @@ export const workflowSlice = createSlice({
});
});

builder.addCase(nodeEditorReset, () => cloneDeep(initialWorkflowState));
builder.addCase(nodeEditorReset, () => deepClone(initialWorkflowState));

builder.addCase(nodesChanged, (state, action) => {
// Not all changes to nodes should result in the workflow being marked touched

@ -1,8 +1,9 @@
import { deepClone } from 'common/util/deepClone';
import { satisfies } from 'compare-versions';
import { NodeUpdateError } from 'features/nodes/types/error';
import type { InvocationNode, InvocationTemplate } from 'features/nodes/types/invocation';
import { zParsedSemver } from 'features/nodes/types/semver';
import { cloneDeep, defaultsDeep, keys, pick } from 'lodash-es';
import { defaultsDeep, keys, pick } from 'lodash-es';

import { buildInvocationNode } from './buildInvocationNode';

@ -50,7 +51,7 @@ export const updateNode = (node: InvocationNode, template: InvocationTemplate):
// The updateability of a node, via semver comparison, relies on this kind of recursive merge
// being valid. We rely on the template's major version to be majorly incremented if this kind of
// merge would result in an invalid node.
const clone = cloneDeep(node);
const clone = deepClone(node);
clone.data.version = template.version;
defaultsDeep(clone, defaults); // mutates!

@ -1,11 +1,12 @@
import { logger } from 'app/logging/logger';
import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize';
import type { NodesState, WorkflowsState } from 'features/nodes/store/types';
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { zWorkflowV3 } from 'features/nodes/types/workflow';
import i18n from 'i18n';
import { cloneDeep, pick } from 'lodash-es';
import { pick } from 'lodash-es';
import { fromZodError } from 'zod-validation-error';

export type BuildWorkflowArg = {
@ -30,7 +31,7 @@ const workflowKeys = [
type BuildWorkflowFunction = (arg: BuildWorkflowArg) => WorkflowV3;

export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV3 => {
const clonedWorkflow = pick(cloneDeep(workflow), workflowKeys);
const clonedWorkflow = pick(deepClone(workflow), workflowKeys);

const newWorkflow: WorkflowV3 = {
...clonedWorkflow,
@ -43,14 +44,14 @@ export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflo
newWorkflow.nodes.push({
id: node.id,
type: node.type,
data: cloneDeep(node.data),
data: deepClone(node.data),
position: { ...node.position },
});
} else if (isNotesNode(node) && node.type) {
newWorkflow.nodes.push({
id: node.id,
type: node.type,
data: cloneDeep(node.data),
data: deepClone(node.data),
position: { ...node.position },
});
}

@ -1,4 +1,5 @@
import { $store } from 'app/store/nanostores/store';
import { deepClone } from 'common/util/deepClone';
import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error';
import type { FieldType } from 'features/nodes/types/field';
import type { InvocationNodeData } from 'features/nodes/types/invocation';
@ -11,7 +12,7 @@ import { zWorkflowV2 } from 'features/nodes/types/v2/workflow';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { zWorkflowV3 } from 'features/nodes/types/workflow';
import { t } from 'i18next';
import { cloneDeep, forEach } from 'lodash-es';
import { forEach } from 'lodash-es';
import { z } from 'zod';

/**
@ -89,7 +90,7 @@ export const parseAndMigrateWorkflow = (data: unknown): WorkflowV3 => {
throw new WorkflowVersionError(t('nodes.unableToGetWorkflowVersion'));
}

let workflow = cloneDeep(data) as WorkflowV1 | WorkflowV2 | WorkflowV3;
let workflow = deepClone(data) as WorkflowV1 | WorkflowV2 | WorkflowV3;

if (workflow.meta.version === '1.0.0') {
const v1 = zWorkflowV1.parse(workflow);

@ -280,6 +280,7 @@ const migrateGenerationState = (state: any): GenerationState => {
// The signature of the model has changed, so we need to reset it
state._version = 2;
state.model = null;
state.canvasCoherenceMode = initialGenerationState.canvasCoherenceMode;
}
return state;
};

File diff suppressed because one or more lines are too long

@ -1 +1 @@
__version__ = "4.0.0rc6"
__version__ = "4.0.1"

@ -43,8 +43,7 @@ def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_fil
assert model_record is not None
assert model_record.name == "test_embedding"
assert model_record.type == ModelType.TextualInversion
assert model_record.path.endswith(embedding_file.as_posix())
assert Path(model_record.path).is_absolute()
assert Path(model_record.path) == embedding_file
assert Path(model_record.path).exists()
assert model_record.base == BaseModelType("sd-1")
assert model_record.description is not None
@ -77,8 +76,7 @@ def test_install(
key = mm2_installer.install_path(embedding_file)
model_record = store.get_model(key)
assert model_record.path.endswith("sd-1/embedding/test_embedding.safetensors")
assert Path(model_record.path).is_absolute()
assert Path(model_record.path).exists()
assert (mm2_app_config.models_path / model_record.path).exists()
assert model_record.source == embedding_file.as_posix()

@ -147,10 +145,7 @@ def test_background_install(
model_record = mm2_installer.record_store.get_model(key)
assert model_record is not None
assert model_record.path.endswith(destination)
assert Path(model_record.path).is_absolute()
assert Path(model_record.path).exists()
assert model_record.key != "<NOKEY>"
assert Path(model_record.path).exists()
assert (mm2_app_config.models_path / model_record.path).exists()

# see if metadata was properly passed through
assert model_record.description == description
@ -172,7 +167,7 @@ def test_not_inplace_install(
assert job is not None
assert job.config_out is not None
assert Path(job.config_out.path) != embedding_file
assert Path(job.config_out.path).exists()
assert (mm2_app_config.models_path / job.config_out.path).exists()

def test_inplace_install(
@ -184,16 +179,21 @@ def test_inplace_install(
assert job is not None
assert job.config_out is not None
assert Path(job.config_out.path) == embedding_file
assert Path(job.config_out.path).exists()

def test_delete_install(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
def test_delete_install(
mm2_installer: ModelInstallServiceBase, embedding_file: Path, mm2_app_config: InvokeAIAppConfig
) -> None:
store = mm2_installer.record_store
key = mm2_installer.install_path(embedding_file)
model_record = store.get_model(key)
assert Path(model_record.path).exists()
assert (mm2_app_config.models_path / model_record.path).exists()
assert embedding_file.exists()  # original should still be there after installation
mm2_installer.delete(key)
assert not Path(model_record.path).exists()  # after deletion, installed copy should not exist
assert not (
mm2_app_config.models_path / model_record.path
).exists()  # after deletion, installed copy should not exist
assert embedding_file.exists()  # but original should still be there
with pytest.raises(UnknownModelException):
store.get_model(key)
@ -232,7 +232,7 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:

key = job.config_out.key
model_record = store.get_model(key)
assert Path(model_record.path).exists()
assert (mm2_app_config.models_path / model_record.path).exists()

assert len(bus.events) == 4
event_names = [x.event_name for x in bus.events]
@ -261,7 +261,7 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co

key = job.config_out.key
model_record = store.get_model(key)
assert Path(model_record.path).exists()
assert (mm2_app_config.models_path / model_record.path).exists()
assert model_record.type == ModelType.Main
assert model_record.format == ModelFormat.Diffusers

@ -98,6 +98,32 @@ def test_migrate_v3_config_from_file(tmp_path: Path, patch_rootdir: None):
assert not hasattr(config, "esrgan")

@pytest.mark.parametrize(
"legacy_conf_dir,expected_value,expected_is_set",
[
# not set, expected value is the default value
("configs/stable-diffusion", Path("configs"), False),
# not set, expected value is the default value
("configs\\stable-diffusion", Path("configs"), False),
# set, best-effort resolution of the path
("partial_custom_path/stable-diffusion", Path("partial_custom_path"), True),
# set, exact path
("full/custom/path", Path("full/custom/path"), True),
],
)
def test_migrate_v3_legacy_conf_dir_defaults(
tmp_path: Path, patch_rootdir: None, legacy_conf_dir: str, expected_value: Path, expected_is_set: bool
):
"""Test reading configuration from a file."""
config_content = f"InvokeAI:\n Paths:\n legacy_conf_dir: {legacy_conf_dir}"
temp_config_file = tmp_path / "temp_invokeai.yaml"
temp_config_file.write_text(config_content)

config = load_and_migrate_config(temp_config_file)
assert config.legacy_conf_dir == expected_value
assert ("legacy_conf_dir" in config.model_fields_set) is expected_is_set

def test_migrate_v3_backup(tmp_path: Path, patch_rootdir: None):
"""Test the backup of the config file."""
temp_config_file = tmp_path / "temp_invokeai.yaml"

@ -250,6 +250,32 @@ def test_migrator_runs_all_migrations_file(logger: Logger) -> None:
db.conn.close()

def test_migrator_backs_up_db(logger: Logger) -> None:
with TemporaryDirectory() as tempdir:
original_db_path = Path(tempdir) / "invokeai.db"
db = SqliteDatabase(db_path=original_db_path, logger=logger, verbose=False)
# Write some data to the db to test for successful backup
temp_cursor = db.conn.cursor()
temp_cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
db.conn.commit()
# Set up the migrator
migrator = SqliteMigrator(db=db)
migrations = [Migration(from_version=i, to_version=i + 1, callback=create_migrate(i)) for i in range(0, 3)]
for migration in migrations:
migrator.register_migration(migration)
migrator.run_migrations()
# Must manually close else we get an error on Windows
db.conn.close()
assert original_db_path.exists()
# We should have a backup file when we migrated a file db
assert migrator._backup_path
# Check that the test table exists as a proxy for successful backup
with closing(sqlite3.connect(migrator._backup_path)) as backup_db_conn:
backup_db_cursor = backup_db_conn.cursor()
backup_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
assert backup_db_cursor.fetchone() is not None

def test_migrator_makes_no_changes_on_failed_migration(
migrator: SqliteMigrator, migration_no_op: Migration, failing_migrate_callback: MigrateCallback
) -> None: