Compare commits

..

1 Commits

Author SHA1 Message Date
0b238b1ece Update probe to always use cpu for loading models 2024-04-03 16:29:38 -04:00
4 changed files with 2 additions and 43 deletions

View File

@ -1,6 +1,4 @@
import sqlite3
from contextlib import closing
from datetime import datetime
from pathlib import Path
from typing import Optional
@ -34,7 +32,6 @@ class SqliteMigrator:
self._db = db
self._logger = db.logger
self._migration_set = MigrationSet()
self._backup_path: Optional[Path] = None
def register_migration(self, migration: Migration) -> None:
"""Registers a migration."""
@ -58,18 +55,6 @@ class SqliteMigrator:
return False
self._logger.info("Database update needed")
# Make a backup of the db if it needs to be updated and is a file db
if self._db.db_path is not None:
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db"
self._logger.info(f"Backing up database to {str(self._backup_path)}")
# Use SQLite to do the backup
with closing(sqlite3.connect(self._backup_path)) as backup_conn:
self._db.conn.backup(backup_conn)
else:
self._logger.info("Using in-memory database, no backup needed")
next_migration = self._migration_set.get(from_version=self._get_current_version(cursor))
while next_migration is not None:
self._run_migration(next_migration)

View File

@ -323,7 +323,7 @@ class ModelProbe(object):
with SilenceWarnings():
if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
cls._scan_model(model_path.name, model_path)
model = torch.load(model_path)
model = torch.load(model_path, map_location="cpu")
assert isinstance(model, dict)
return model
else:

View File

@ -1 +1 @@
__version__ = "4.0.0"
__version__ = "4.0.0rc6"

View File

@ -250,32 +250,6 @@ def test_migrator_runs_all_migrations_file(logger: Logger) -> None:
db.conn.close()
def test_migrator_backs_up_db(logger: Logger) -> None:
with TemporaryDirectory() as tempdir:
original_db_path = Path(tempdir) / "invokeai.db"
db = SqliteDatabase(db_path=original_db_path, logger=logger, verbose=False)
# Write some data to the db to test for successful backup
temp_cursor = db.conn.cursor()
temp_cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
db.conn.commit()
# Set up the migrator
migrator = SqliteMigrator(db=db)
migrations = [Migration(from_version=i, to_version=i + 1, callback=create_migrate(i)) for i in range(0, 3)]
for migration in migrations:
migrator.register_migration(migration)
migrator.run_migrations()
# Must manually close else we get an error on Windows
db.conn.close()
assert original_db_path.exists()
# We should have a backup file when we migrated a file db
assert migrator._backup_path
# Check that the test table exists as a proxy for successful backup
with closing(sqlite3.connect(migrator._backup_path)) as backup_db_conn:
backup_db_cursor = backup_db_conn.cursor()
backup_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
assert backup_db_cursor.fetchone() is not None
def test_migrator_makes_no_changes_on_failed_migration(
migrator: SqliteMigrator, migration_no_op: Migration, failing_migrate_callback: MigrateCallback
) -> None: