Compare commits


4 Commits

SHA1 Message Date
e04c25eba7 Merge branch 'main' into fix/diffusers-embeddings 2023-08-01 22:14:43 +10:00
be61ffdbf6 Merge branch 'main' into fix/diffusers-embeddings 2023-08-01 00:21:50 -04:00
823b879329 Merge branch 'main' into fix/diffusers-embeddings 2023-07-31 21:04:08 -04:00
17c901aaf7 fix diffusers-style textual embeddings
- also fix a couple places where the wrong base was used for relative model paths
2023-07-31 21:00:12 -04:00
15 changed files with 135 additions and 890 deletions

View File

@@ -34,10 +34,6 @@
     cudaPackages.cudnn
     cudaPackages.cuda_nvrtc
     cudatoolkit
-    pkgconfig
-    libconfig
-    cmake
-    blas
     freeglut
     glib
     gperf
@@ -46,12 +42,6 @@
     libGLU
     linuxPackages.nvidia_x11
     python
-    (opencv4.override {
-      enableGtk3 = true;
-      enableFfmpeg = true;
-      enableCuda = true;
-      enableUnfree = true;
-    })
     stdenv.cc
     stdenv.cc.cc.lib
     xorg.libX11

View File

@@ -30,8 +30,6 @@ from ..services.invoker import Invoker
 from ..services.processor import DefaultInvocationProcessor
 from ..services.sqlite import SqliteItemStorage
 from ..services.model_manager_service import ModelManagerService
-from ..services.batch_manager import BatchManager
-from ..services.batch_manager_storage import SqliteBatchProcessStorage

 from .events import FastAPIEventService
@@ -118,15 +116,11 @@ class ApiDependencies:
             )
         )

-        batch_manager_storage = SqliteBatchProcessStorage(db_location)
-        batch_manager = BatchManager(batch_manager_storage)
-
         services = InvocationServices(
             model_manager=ModelManagerService(config, logger),
             events=events,
             latents=latents,
             images=images,
-            batch_manager=batch_manager,
             boards=boards,
             board_images=board_images,
             queue=MemoryInvocationQueue(),

View File

@@ -15,7 +15,6 @@ from ...services.graph import (
     GraphExecutionState,
     NodeAlreadyExecutedError,
 )
-from ...services.batch_manager import Batch, BatchProcess
 from ...services.item_storage import PaginatedResults

 from ..dependencies import ApiDependencies
@@ -38,37 +37,6 @@ async def create_session(
     return session


-@session_router.post(
-    "/batch",
-    operation_id="create_batch",
-    responses={
-        200: {"model": BatchProcess},
-        400: {"description": "Invalid json"},
-    },
-)
-async def create_batch(
-    graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
-    batches: list[Batch] = Body(description="Batch config to apply to the given graph"),
-) -> BatchProcess:
-    """Creates and starts a new new batch process"""
-    batch_id = ApiDependencies.invoker.services.batch_manager.create_batch_process(batches, graph)
-    ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_id)
-    return {"batch_id":batch_id}
-
-
-@session_router.delete(
-    "{batch_process_id}/batch",
-    operation_id="cancel_batch",
-    responses={202: {"description": "The batch is canceled"}},
-)
-async def cancel_batch(
-    batch_process_id: str = Path(description="The id of the batch process to cancel"),
-) -> Response:
-    """Creates and starts a new new batch process"""
-    ApiDependencies.invoker.services.batch_manager.cancel_batch_process(batch_process_id)
-    return Response(status_code=202)
-
-
 @session_router.get(
     "/",
     operation_id="list_sessions",

View File

@@ -37,8 +37,6 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
 from invokeai.app.services.images import ImageService, ImageServiceDependencies
 from invokeai.app.services.resource_name import SimpleNameService
 from invokeai.app.services.urls import LocalUrlService
-from invokeai.app.services.batch_manager import BatchManager
-from invokeai.app.services.batch_manager_storage import SqliteBatchProcessStorage

 from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
 from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
@@ -302,16 +300,12 @@ def invoke_cli():
         )
     )

-    batch_manager_storage = SqliteBatchProcessStorage(db_location)
-    batch_manager = BatchManager(batch_manager_storage)
-
     services = InvocationServices(
         model_manager=model_manager,
         events=events,
         latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
         images=images,
         boards=boards,
-        batch_manager=batch_manager,
         board_images=board_images,
         queue=MemoryInvocationQueue(),
         graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),

View File

@@ -1,139 +0,0 @@
-import networkx as nx
-import copy
-
-from abc import ABC, abstractmethod
-from itertools import product
-from pydantic import BaseModel, Field
-from fastapi_events.handlers.local import local_handler
-from fastapi_events.typing import Event
-
-from invokeai.app.services.events import EventServiceBase
-from invokeai.app.services.graph import Graph, GraphExecutionState
-from invokeai.app.services.invoker import Invoker
-from invokeai.app.services.batch_manager_storage import (
-    BatchProcessStorageBase,
-    BatchSessionNotFoundException,
-    Batch,
-    BatchProcess,
-    BatchSession,
-    BatchSessionChanges,
-)
-
-
-class BatchManagerBase(ABC):
-    @abstractmethod
-    def start(self, invoker: Invoker):
-        pass
-
-    @abstractmethod
-    def create_batch_process(self, batches: list[Batch], graph: Graph) -> str:
-        pass
-
-    @abstractmethod
-    def run_batch_process(self, batch_id: str):
-        pass
-
-    @abstractmethod
-    def cancel_batch_process(self, batch_process_id: str):
-        pass
-
-
-class BatchManager(BatchManagerBase):
-    """Responsible for managing currently running and scheduled batch jobs"""
-
-    __invoker: Invoker
-    __batches: list[BatchProcess]
-    __batch_process_storage: BatchProcessStorageBase
-
-    def __init__(self, batch_process_storage: BatchProcessStorageBase) -> None:
-        super().__init__()
-        self.__batch_process_storage = batch_process_storage
-
-    def start(self, invoker: Invoker) -> None:
-        # if we do want multithreading at some point, we could make this configurable
-        self.__invoker = invoker
-        self.__batches = list()
-        local_handler.register(event_name=EventServiceBase.session_event, _func=self.on_event)
-
-    async def on_event(self, event: Event):
-        event_name = event[1]["event"]
-
-        match event_name:
-            case "graph_execution_state_complete":
-                await self.process(event, False)
-            case "invocation_error":
-                await self.process(event, True)
-
-        return event
-
-    async def process(self, event: Event, err: bool):
-        data = event[1]["data"]
-        batch_session = self.__batch_process_storage.get_session(data["graph_execution_state_id"])
-        if not batch_session:
-            return
-        updateSession = BatchSessionChanges(
-            state='error' if err else 'completed'
-        )
-        batch_session = self.__batch_process_storage.update_session_state(
-            batch_session.batch_id,
-            batch_session.session_id,
-            updateSession,
-        )
-        self.run_batch_process(batch_session.batch_id)
-
-    def _create_batch_session(self, batch_process: BatchProcess, batch_indices: list[int]) -> GraphExecutionState:
-        graph = copy.deepcopy(batch_process.graph)
-        batches = batch_process.batches
-        g = graph.nx_graph_flat()
-        sorted_nodes = nx.topological_sort(g)
-        for npath in sorted_nodes:
-            node = graph.get_node(npath)
-            (index, batch) = next(((i, b) for i, b in enumerate(batches) if b.node_id in node.id), (None, None))
-            if batch:
-                batch_index = batch_indices[index]
-                datum = batch.data[batch_index]
-                for key in datum:
-                    node.__dict__[key] = datum[key]
-                graph.update_node(npath, node)
-        return GraphExecutionState(graph=graph)
-
-    def run_batch_process(self, batch_id: str):
-        try:
-            created_session = self.__batch_process_storage.get_created_session(batch_id)
-        except BatchSessionNotFoundException:
-            return
-        ges = self.__invoker.services.graph_execution_manager.get(created_session.session_id)
-        self.__invoker.invoke(ges, invoke_all=True)
-
-    def _valid_batch_config(self, batch_process: BatchProcess) -> bool:
-        return True
-
-    def create_batch_process(self, batches: list[Batch], graph: Graph) -> str:
-        batch_process = BatchProcess(
-            batches=batches,
-            graph=graph,
-        )
-        if not self._valid_batch_config(batch_process):
-            return None
-        batch_process = self.__batch_process_storage.save(batch_process)
-        self._create_sessions(batch_process)
-        return batch_process.batch_id
-
-    def _create_sessions(self, batch_process: BatchProcess):
-        batch_indices = list()
-        for batch in batch_process.batches:
-            batch_indices.append(list(range(len(batch.data))))
-        all_batch_indices = product(*batch_indices)
-        for bi in all_batch_indices:
-            ges = self._create_batch_session(batch_process, bi)
-            self.__invoker.services.graph_execution_manager.set(ges)
-            batch_session = BatchSession(
-                batch_id=batch_process.batch_id,
-                session_id=ges.id,
-                state="created"
-            )
-            self.__batch_process_storage.create_session(batch_session)
-
-    def cancel_batch_process(self, batch_process_id: str):
-        self.__batches = [batch for batch in self.__batches if batch.id != batch_process_id]

View File

@@ -1,505 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import cast
-import uuid
-import sqlite3
-import threading
-from typing import (
-    Any,
-    List,
-    Literal,
-    Optional,
-    Union,
-)
-import json
-
-from invokeai.app.invocations.baseinvocation import (
-    BaseInvocation,
-)
-from invokeai.app.services.graph import Graph
-from invokeai.app.models.image import ImageField
-from pydantic import BaseModel, Field, Extra, parse_raw_as
-
-invocations = BaseInvocation.get_invocations()
-InvocationsUnion = Union[invocations]  # type: ignore
-
-BatchDataType = Union[str, int, float, ImageField]
-
-
-class Batch(BaseModel):
-    data: list[dict[str, BatchDataType]] = Field(description="Mapping of node field to data value")
-    node_id: str = Field(description="ID of the node to batch")
-
-
-class BatchSession(BaseModel):
-    batch_id: str = Field(description="Identifier for which batch this Index belongs to")
-    session_id: str = Field(description="Session ID Created for this Batch Index")
-    state: Literal["created", "completed", "inprogress", "error"] = Field(
-        description="Is this session created, completed, in progress, or errored?"
-    )
-
-
-def uuid_string():
-    res = uuid.uuid4()
-    return str(res)
-
-
-class BatchProcess(BaseModel):
-    batch_id: Optional[str] = Field(default_factory=uuid_string, description="Identifier for this batch")
-    batches: List[Batch] = Field(
-        description="List of batch configs to apply to this session",
-        default_factory=list,
-    )
-    graph: Graph = Field(description="The graph being executed")
-
-
-class BatchSessionChanges(BaseModel, extra=Extra.forbid):
-    state: Literal["created", "completed", "inprogress", "error"] = Field(
-        description="Is this session created, completed, in progress, or errored?"
-    )
-
-
-class BatchProcessNotFoundException(Exception):
-    """Raised when an Batch Process record is not found."""
-
-    def __init__(self, message="BatchProcess record not found"):
-        super().__init__(message)
-
-
-class BatchProcessSaveException(Exception):
-    """Raised when an Batch Process record cannot be saved."""
-
-    def __init__(self, message="BatchProcess record not saved"):
-        super().__init__(message)
-
-
-class BatchProcessDeleteException(Exception):
-    """Raised when an Batch Process record cannot be deleted."""
-
-    def __init__(self, message="BatchProcess record not deleted"):
-        super().__init__(message)
-
-
-class BatchSessionNotFoundException(Exception):
-    """Raised when an Batch Session record is not found."""
-
-    def __init__(self, message="BatchSession record not found"):
-        super().__init__(message)
-
-
-class BatchSessionSaveException(Exception):
-    """Raised when an Batch Session record cannot be saved."""
-
-    def __init__(self, message="BatchSession record not saved"):
-        super().__init__(message)
-
-
-class BatchSessionDeleteException(Exception):
-    """Raised when an Batch Session record cannot be deleted."""
-
-    def __init__(self, message="BatchSession record not deleted"):
-        super().__init__(message)
-
-
-class BatchProcessStorageBase(ABC):
-    """Low-level service responsible for interfacing with the Batch Process record store."""
-
-    @abstractmethod
-    def delete(self, batch_id: str) -> None:
-        """Deletes a Batch Process record."""
-        pass
-
-    @abstractmethod
-    def save(
-        self,
-        batch_process: BatchProcess,
-    ) -> BatchProcess:
-        """Saves a Batch Process record."""
-        pass
-
-    @abstractmethod
-    def get(
-        self,
-        batch_id: str,
-    ) -> BatchProcess:
-        """Gets a Batch Process record."""
-        pass
-
-    @abstractmethod
-    def create_session(
-        self,
-        session: BatchSession,
-    ) -> BatchSession:
-        """Creates a Batch Session attached to a Batch Process."""
-        pass
-
-    @abstractmethod
-    def get_session(
-        self,
-        session_id: str
-    ) -> BatchSession:
-        """Gets session by session_id"""
-        pass
-
-    @abstractmethod
-    def get_created_session(
-        self,
-        batch_id: str
-    ) -> BatchSession:
-        """Gets all created Batch Sessions for a given Batch Process id."""
-        pass
-
-    @abstractmethod
-    def get_created_sessions(
-        self,
-        batch_id: str
-    ) -> List[BatchSession]:
-        """Gets all created Batch Sessions for a given Batch Process id."""
-        pass
-
-    @abstractmethod
-    def update_session_state(
-        self,
-        batch_id: str,
-        session_id: str,
-        changes: BatchSessionChanges,
-    ) -> BatchSession:
-        """Updates the state of a Batch Session record."""
-        pass
-
-
-class SqliteBatchProcessStorage(BatchProcessStorageBase):
-    _filename: str
-    _conn: sqlite3.Connection
-    _cursor: sqlite3.Cursor
-    _lock: threading.Lock
-
-    def __init__(self, filename: str) -> None:
-        super().__init__()
-        self._filename = filename
-        self._conn = sqlite3.connect(filename, check_same_thread=False)
-        # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
-        self._conn.row_factory = sqlite3.Row
-        self._cursor = self._conn.cursor()
-        self._lock = threading.Lock()
-
-        try:
-            self._lock.acquire()
-            # Enable foreign keys
-            self._conn.execute("PRAGMA foreign_keys = ON;")
-            self._create_tables()
-            self._conn.commit()
-        finally:
-            self._lock.release()
-
-    def _create_tables(self) -> None:
-        """Creates the `batch_process` table and `batch_session` junction table."""
-
-        # Create the `batch_process` table.
-        self._cursor.execute(
-            """--sql
-            CREATE TABLE IF NOT EXISTS batch_process (
-                batch_id TEXT NOT NULL PRIMARY KEY,
-                batches TEXT NOT NULL,
-                graph TEXT NOT NULL,
-                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-                -- Updated via trigger
-                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-                -- Soft delete, currently unused
-                deleted_at DATETIME
-            );
-            """
-        )
-
-        self._cursor.execute(
-            """--sql
-            CREATE INDEX IF NOT EXISTS idx_batch_process_created_at ON batch_process (created_at);
-            """
-        )
-
-        # Add trigger for `updated_at`.
-        self._cursor.execute(
-            """--sql
-            CREATE TRIGGER IF NOT EXISTS tg_batch_process_updated_at
-            AFTER UPDATE
-            ON batch_process FOR EACH ROW
-            BEGIN
-                UPDATE batch_process SET updated_at = current_timestamp
-                    WHERE batch_id = old.batch_id;
-            END;
-            """
-        )
-
-        # Create the `batch_session` junction table.
-        self._cursor.execute(
-            """--sql
-            CREATE TABLE IF NOT EXISTS batch_session (
-                batch_id TEXT NOT NULL,
-                session_id TEXT NOT NULL,
-                state TEXT NOT NULL,
-                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-                -- updated via trigger
-                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-                -- Soft delete, currently unused
-                deleted_at DATETIME,
-                -- enforce one-to-many relationship between batch_process and batch_session using PK
-                -- (we can extend this to many-to-many later)
-                PRIMARY KEY (batch_id,session_id),
-                FOREIGN KEY (batch_id) REFERENCES batch_process (batch_id) ON DELETE CASCADE
-            );
-            """
-        )
-
-        # Add index for batch id
-        self._cursor.execute(
-            """--sql
-            CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id ON batch_session (batch_id);
-            """
-        )
-
-        # Add index for batch id, sorted by created_at
-        self._cursor.execute(
-            """--sql
-            CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id_created_at ON batch_session (batch_id,created_at);
-            """
-        )
-
-        # Add trigger for `updated_at`.
-        self._cursor.execute(
-            """--sql
-            CREATE TRIGGER IF NOT EXISTS tg_batch_session_updated_at
-            AFTER UPDATE
-            ON batch_session FOR EACH ROW
-            BEGIN
-                UPDATE batch_session SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
-                    WHERE batch_id = old.batch_id AND session_id = old.session_id;
-            END;
-            """
-        )
-
-    def delete(self, batch_id: str) -> None:
-        try:
-            self._lock.acquire()
-            self._cursor.execute(
-                """--sql
-                DELETE FROM batch_process
-                WHERE batch_id = ?;
-                """,
-                (batch_id,),
-            )
-            self._conn.commit()
-        except sqlite3.Error as e:
-            self._conn.rollback()
-            raise BatchProcessDeleteException from e
-        except Exception as e:
-            self._conn.rollback()
-            raise BatchProcessDeleteException from e
-        finally:
-            self._lock.release()
-
-    def save(
-        self,
-        batch_process: BatchProcess,
-    ) -> BatchProcess:
-        try:
-            self._lock.acquire()
-            batches = [batch.json() for batch in batch_process.batches]
-            self._cursor.execute(
-                """--sql
-                INSERT OR IGNORE INTO batch_process (batch_id, batches, graph)
-                VALUES (?, ?, ?);
-                """,
-                (batch_process.batch_id, json.dumps(batches), batch_process.graph.json()),
-            )
-            self._conn.commit()
-        except sqlite3.Error as e:
-            self._conn.rollback()
-            raise BatchProcessSaveException from e
-        finally:
-            self._lock.release()
-        return self.get(batch_process.batch_id)
-
-    def _deserialize_batch_process(self, session_dict: dict) -> BatchProcess:
-        """Deserializes a batch session."""
-
-        # Retrieve all the values, setting "reasonable" defaults if they are not present.
-        batch_id = session_dict.get("batch_id", "unknown")
-        batches_raw = session_dict.get("batches", "unknown")
-        graph_raw = session_dict.get("graph", "unknown")
-        batches = json.loads(batches_raw)
-        batches = [parse_raw_as(Batch, batch) for batch in batches]
-        return BatchProcess(
-            batch_id=batch_id,
-            batches=batches,
-            graph=parse_raw_as(Graph, graph_raw),
-        )
-
-    def get(
-        self,
-        batch_id: str,
-    ) -> BatchProcess:
-        try:
-            self._lock.acquire()
-            self._cursor.execute(
-                """--sql
-                SELECT *
-                FROM batch_process
-                WHERE batch_id = ?;
-                """,
-                (batch_id,)
-            )
-            result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
-        except sqlite3.Error as e:
-            self._conn.rollback()
-            raise BatchProcessNotFoundException from e
-        finally:
-            self._lock.release()
-        if result is None:
-            raise BatchProcessNotFoundException
-        return self._deserialize_batch_process(dict(result))
-
-    def create_session(
-        self,
-        session: BatchSession,
-    ) -> BatchSession:
-        try:
-            self._lock.acquire()
-            self._cursor.execute(
-                """--sql
-                INSERT OR IGNORE INTO batch_session (batch_id, session_id, state)
-                VALUES (?, ?, ?);
-                """,
-                (session.batch_id, session.session_id, session.state),
-            )
-            self._conn.commit()
-        except sqlite3.Error as e:
-            self._conn.rollback()
-            raise BatchSessionSaveException from e
-        finally:
-            self._lock.release()
-        return self.get_session(session.session_id)
-
-    def get_session(
-        self,
-        session_id: str
-    ) -> BatchSession:
-        try:
-            self._lock.acquire()
-            self._cursor.execute(
-                """--sql
-                SELECT *
-                FROM batch_session
-                WHERE session_id= ?;
-                """,
-                (session_id,),
-            )
-            result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
-        except sqlite3.Error as e:
-            self._conn.rollback()
-            raise BatchSessionNotFoundException from e
-        finally:
-            self._lock.release()
-        if result is None:
-            raise BatchSessionNotFoundException
-        return self._deserialize_batch_session(dict(result))
-
-    def _deserialize_batch_session(self, session_dict: dict) -> BatchSession:
-        """Deserializes a batch session."""
-
-        # Retrieve all the values, setting "reasonable" defaults if they are not present.
-        batch_id = session_dict.get("batch_id", "unknown")
-        session_id = session_dict.get("session_id", "unknown")
-        state = session_dict.get("state", "unknown")
-        return BatchSession(
-            batch_id=batch_id,
-            session_id=session_id,
-            state=state,
-        )
-
-    def get_created_session(
-        self,
-        batch_id: str
-    ) -> BatchSession:
-        try:
-            self._lock.acquire()
-            self._cursor.execute(
-                """--sql
-                SELECT *
-                FROM batch_session
-                WHERE batch_id = ? AND state = 'created';
-                """,
-                (batch_id,),
-            )
-            result = cast(list[sqlite3.Row], self._cursor.fetchone())
-        except sqlite3.Error as e:
-            self._conn.rollback()
-            raise BatchSessionNotFoundException from e
-        finally:
-            self._lock.release()
-        if result is None:
-            raise BatchSessionNotFoundException
-        session = self._deserialize_batch_session(dict(result))
-        return session
-
-    def get_created_sessions(
-        self,
-        batch_id: str
-    ) -> List[BatchSession]:
-        try:
-            self._lock.acquire()
-            self._cursor.execute(
-                """--sql
-                SELECT *
-                FROM batch_session
-                WHERE batch_id = ? AND state = created;
-                """,
-                (batch_id,),
-            )
-            result = cast(list[sqlite3.Row], self._cursor.fetchall())
-        except sqlite3.Error as e:
-            self._conn.rollback()
-            raise BatchSessionNotFoundException from e
-        finally:
-            self._lock.release()
-        if result is None:
-            raise BatchSessionNotFoundException
-        sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
-        return sessions
-
-    def update_session_state(
-        self,
-        batch_id: str,
-        session_id: str,
-        changes: BatchSessionChanges,
-    ) -> BatchSession:
-        try:
-            self._lock.acquire()
-
-            # Change the state of a batch session
-            if changes.state is not None:
-                self._cursor.execute(
-                    f"""--sql
-                    UPDATE batch_session
-                    SET state = ?
-                    WHERE batch_id = ? AND session_id = ?;
-                    """,
-                    (changes.state, batch_id, session_id),
-                )
-
-            self._conn.commit()
-        except sqlite3.Error as e:
-            self._conn.rollback()
-            raise BatchSessionSaveException from e
-        finally:
-            self._lock.release()
-        return self.get_session(session_id)

View File

@@ -4,7 +4,6 @@ from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from logging import Logger

-    from invokeai.app.services.batch_manager import BatchManagerBase
     from invokeai.app.services.board_images import BoardImagesServiceABC
     from invokeai.app.services.boards import BoardServiceABC
     from invokeai.app.services.images import ImageServiceABC
@@ -22,7 +21,6 @@ class InvocationServices:
     """Services that can be used by invocations"""

     # TODO: Just forward-declared everything due to circular dependencies. Fix structure.
-    batch_manager: "BatchManagerBase"
     board_images: "BoardImagesServiceABC"
     boards: "BoardServiceABC"
     configuration: "InvokeAIAppConfig"
@@ -38,7 +36,6 @@
     def __init__(
         self,
-        batch_manager: "BatchManagerBase",
         board_images: "BoardImagesServiceABC",
         boards: "BoardServiceABC",
         configuration: "InvokeAIAppConfig",
@@ -52,7 +49,6 @@
         processor: "InvocationProcessorABC",
         queue: "InvocationQueueABC",
     ):
-        self.batch_manager = batch_manager
         self.board_images = board_images
         self.boards = boards
         self.boards = boards

View File

@@ -29,7 +29,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
         self._conn = sqlite3.connect(
             self._filename, check_same_thread=False
         )  # TODO: figure out a better threading solution
-        self._conn.execute('pragma journal_mode=wal')
         self._cursor = self._conn.cursor()
         self._create_table()
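For context, the removed pragma is what had switched this SQLite item store to write-ahead logging. A minimal standalone illustration (not part of the diff) of what that one line did:

import sqlite3

conn = sqlite3.connect("example.db")
# The removed call: switch the journal to write-ahead logging, which lets
# readers proceed concurrently with a single writer.
conn.execute("pragma journal_mode=wal")
# The mode persists with the database file once set; query it back:
print(conn.execute("pragma journal_mode").fetchone())  # ('wal',)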

View File

@@ -651,7 +651,10 @@
         file_path = Path(file_path)

         result = cls()  # TODO:
-        result.name = file_path.stem  # TODO:
+        if file_path.name == "learned_embeds.bin":
+            result.name = file_path.parent.name
+        else:
+            result.name = file_path.stem

         if file_path.suffix == ".safetensors":
             state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
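This hunk is the embedding fix named in the commit message. Diffusers-style textual inversion embeddings ship as a learned_embeds.bin file inside a folder named for the concept, so taking the file stem would name every such embedding "learned_embeds". A standalone sketch of the fixed naming logic:

from pathlib import Path

def embedding_name(file_path: Path) -> str:
    # Diffusers layout: <concept-name>/learned_embeds.bin -> use the folder name.
    if file_path.name == "learned_embeds.bin":
        return file_path.parent.name
    # Single-file embeddings (.pt / .bin / .safetensors) -> use the file stem.
    return file_path.stem

assert embedding_name(Path("embeddings/cat-toy/learned_embeds.bin")) == "cat-toy"
assert embedding_name(Path("embeddings/cat-toy.safetensors")) == "cat-toy"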

View File

@@ -188,7 +188,7 @@ class ModelCache(object):
         cache_entry = self._cached_models.get(key, None)
         if cache_entry is None:
             self.logger.info(
-                f"Loading model {model_path}, type {base_model.value}:{model_type.value}:{submodel.value if submodel else ''}"
+                f"Loading model {model_path}, type {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
             )

             # this will remove older cached models until

View File

@@ -472,7 +472,7 @@ class ModelManager(object):
         if submodel_type is not None and hasattr(model_config, submodel_type):
             override_path = getattr(model_config, submodel_type)
             if override_path:
-                model_path = self.app_config.root_path / override_path
+                model_path = self.resolve_path(override_path)
                 model_type = submodel_type
                 submodel_type = None
         model_class = MODEL_CLASSES[base_model][model_type]
@@ -670,7 +670,7 @@ class ModelManager(object):
         # TODO: if path changed and old_model.path inside models folder should we delete this too?

         # remove conversion cache as config changed
-        old_model_path = self.app_config.root_path / old_model.path
+        old_model_path = self.resolve_model_path(old_model.path)
         old_model_cache = self._get_model_cache_path(old_model_path)
         if old_model_cache.exists():
             if old_model_cache.is_dir():
@@ -780,7 +780,7 @@ class ModelManager(object):
                 model_type,
                 **submodel,
             )
-            checkpoint_path = self.app_config.root_path / info["path"]
+            checkpoint_path = self.resolve_model_path(info["path"])
             old_diffusers_path = self.resolve_model_path(model.location)
             new_diffusers_path = (
                 dest_directory or self.app_config.models_path / base_model.value / model_type.value
@@ -992,7 +992,7 @@ class ModelManager(object):
             model_manager=self,
             prediction_type_helper=ask_user_for_prediction_type,
         )
-        known_paths = {config.root_path / x["path"] for x in self.list_models()}
+        known_paths = {self.resolve_model_path(x["path"]) for x in self.list_models()}
         directories = {
             config.root_path / x
             for x in [
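These four hunks are the "wrong base was used for relative model paths" half of the commit message: relative paths from model configs were joined directly onto the install root (app_config.root_path) instead of going through the manager's resolver. A minimal sketch of the difference, assuming resolve_model_path anchors relative paths at the models directory (the resolver's body is not shown in this diff, so the exact base is an assumption):

from pathlib import Path

# Hypothetical stand-ins for app_config.root_path and app_config.models_path.
ROOT_PATH = Path("/opt/invokeai")
MODELS_PATH = ROOT_PATH / "models"

def resolve_model_path(path: str) -> Path:
    """Resolve a possibly-relative model path against the models directory."""
    p = Path(path)
    return p if p.is_absolute() else MODELS_PATH / p

rel = "sd-1/embedding/cat-toy"
print(ROOT_PATH / rel)          # old behavior: /opt/invokeai/sd-1/embedding/cat-toy
print(resolve_model_path(rel))  # fixed:        /opt/invokeai/models/sd-1/embedding/cat-toy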

View File

@@ -1,4 +1,4 @@
-import { ButtonGroup, Flex, Spinner, Text } from '@chakra-ui/react';
+import { ButtonGroup, Flex, Text } from '@chakra-ui/react';
 import { EntityState } from '@reduxjs/toolkit';
 import IAIButton from 'common/components/IAIButton';
 import IAIInput from 'common/components/IAIInput';
@@ -6,23 +6,23 @@ import { forEach } from 'lodash-es';
 import type { ChangeEvent, PropsWithChildren } from 'react';
 import { useCallback, useState } from 'react';
 import { useTranslation } from 'react-i18next';
-import { ALL_BASE_MODELS } from 'services/api/constants';
 import {
-  LoRAModelConfigEntity,
   MainModelConfigEntity,
   OnnxModelConfigEntity,
-  useGetLoRAModelsQuery,
   useGetMainModelsQuery,
   useGetOnnxModelsQuery,
+  useGetLoRAModelsQuery,
+  LoRAModelConfigEntity,
 } from 'services/api/endpoints/models';

 import ModelListItem from './ModelListItem';
+import { ALL_BASE_MODELS } from 'services/api/constants';

 type ModelListProps = {
   selectedModelId: string | undefined;
   setSelectedModelId: (name: string | undefined) => void;
 };

-type ModelFormat = 'all' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx';
+type ModelFormat = 'images' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx';

 type ModelType = 'main' | 'lora' | 'onnx';
@@ -33,63 +33,47 @@ const ModelList = (props: ModelListProps) => {
   const { t } = useTranslation();
   const [nameFilter, setNameFilter] = useState<string>('');
   const [modelFormatFilter, setModelFormatFilter] =
-    useState<CombinedModelFormat>('all');
+    useState<CombinedModelFormat>('images');

-  const { filteredDiffusersModels, isLoadingDiffusersModels } =
-    useGetMainModelsQuery(ALL_BASE_MODELS, {
-      selectFromResult: ({ data, isLoading }) => ({
-        filteredDiffusersModels: modelsFilter(
-          data,
-          'main',
-          'diffusers',
-          nameFilter
-        ),
-        isLoadingDiffusersModels: isLoading,
-      }),
-    });
+  const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
+    selectFromResult: ({ data }) => ({
+      filteredDiffusersModels: modelsFilter(
+        data,
+        'main',
+        'diffusers',
+        nameFilter
+      ),
+    }),
+  });

-  const { filteredCheckpointModels, isLoadingCheckpointModels } =
-    useGetMainModelsQuery(ALL_BASE_MODELS, {
-      selectFromResult: ({ data, isLoading }) => ({
-        filteredCheckpointModels: modelsFilter(
-          data,
-          'main',
-          'checkpoint',
-          nameFilter
-        ),
-        isLoadingCheckpointModels: isLoading,
-      }),
-    });
+  const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
+    selectFromResult: ({ data }) => ({
+      filteredCheckpointModels: modelsFilter(
+        data,
+        'main',
+        'checkpoint',
+        nameFilter
+      ),
+    }),
+  });

-  const { filteredLoraModels, isLoadingLoraModels } = useGetLoRAModelsQuery(
-    undefined,
-    {
-      selectFromResult: ({ data, isLoading }) => ({
-        filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter),
-        isLoadingLoraModels: isLoading,
-      }),
-    }
-  );
+  const { filteredLoraModels } = useGetLoRAModelsQuery(undefined, {
+    selectFromResult: ({ data }) => ({
+      filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter),
+    }),
+  });

-  const { filteredOnnxModels, isLoadingOnnxModels } = useGetOnnxModelsQuery(
-    ALL_BASE_MODELS,
-    {
-      selectFromResult: ({ data, isLoading }) => ({
-        filteredOnnxModels: modelsFilter(data, 'onnx', 'onnx', nameFilter),
-        isLoadingOnnxModels: isLoading,
-      }),
-    }
-  );
+  const { filteredOnnxModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, {
+    selectFromResult: ({ data }) => ({
+      filteredOnnxModels: modelsFilter(data, 'onnx', 'onnx', nameFilter),
+    }),
+  });

-  const { filteredOliveModels, isLoadingOliveModels } = useGetOnnxModelsQuery(
-    ALL_BASE_MODELS,
-    {
-      selectFromResult: ({ data, isLoading }) => ({
-        filteredOliveModels: modelsFilter(data, 'onnx', 'olive', nameFilter),
-        isLoadingOliveModels: isLoading,
-      }),
-    }
-  );
+  const { filteredOliveModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, {
+    selectFromResult: ({ data }) => ({
+      filteredOliveModels: modelsFilter(data, 'onnx', 'olive', nameFilter),
+    }),
+  });

   const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
     setNameFilter(e.target.value);
@@ -100,8 +84,8 @@ const ModelList = (props: ModelListProps) => {
     <Flex flexDirection="column" gap={4} paddingInlineEnd={4}>
       <ButtonGroup isAttached>
         <IAIButton
-          onClick={() => setModelFormatFilter('all')}
-          isChecked={modelFormatFilter === 'all'}
+          onClick={() => setModelFormatFilter('images')}
+          isChecked={modelFormatFilter === 'images'}
           size="sm"
         >
          {t('modelManager.allModels')}
@@ -155,76 +139,95 @@ const ModelList = (props: ModelListProps) => {
           maxHeight={window.innerHeight - 280}
           overflow="scroll"
         >
-          {/* Diffusers List */}
-          {isLoadingDiffusersModels && (
-            <FetchingModelsLoader loadingMessage="Loading Diffusers..." />
-          )}
-          {['all', 'diffusers'].includes(modelFormatFilter) &&
-            !isLoadingDiffusersModels &&
+          {['images', 'diffusers'].includes(modelFormatFilter) &&
             filteredDiffusersModels.length > 0 && (
-              <ModelListWrapper
-                title="Diffusers"
-                modelList={filteredDiffusersModels}
-                selected={{ selectedModelId, setSelectedModelId }}
-                key="diffusers"
-              />
+              <StyledModelContainer>
+                <Flex sx={{ gap: 2, flexDir: 'column' }}>
+                  <Text variant="subtext" fontSize="sm">
+                    Diffusers
+                  </Text>
+                  {filteredDiffusersModels.map((model) => (
+                    <ModelListItem
+                      key={model.id}
+                      model={model}
+                      isSelected={selectedModelId === model.id}
+                      setSelectedModelId={setSelectedModelId}
+                    />
+                  ))}
+                </Flex>
+              </StyledModelContainer>
             )}
-          {/* Checkpoints List */}
-          {isLoadingCheckpointModels && (
-            <FetchingModelsLoader loadingMessage="Loading Checkpoints..." />
-          )}
-          {['all', 'checkpoint'].includes(modelFormatFilter) &&
-            !isLoadingCheckpointModels &&
+          {['images', 'checkpoint'].includes(modelFormatFilter) &&
             filteredCheckpointModels.length > 0 && (
-              <ModelListWrapper
-                title="Checkpoints"
-                modelList={filteredCheckpointModels}
-                selected={{ selectedModelId, setSelectedModelId }}
-                key="checkpoints"
-              />
+              <StyledModelContainer>
+                <Flex sx={{ gap: 2, flexDir: 'column' }}>
+                  <Text variant="subtext" fontSize="sm">
+                    Checkpoints
+                  </Text>
+                  {filteredCheckpointModels.map((model) => (
+                    <ModelListItem
+                      key={model.id}
+                      model={model}
+                      isSelected={selectedModelId === model.id}
+                      setSelectedModelId={setSelectedModelId}
+                    />
+                  ))}
+                </Flex>
+              </StyledModelContainer>
             )}
-          {/* LoRAs List */}
-          {isLoadingLoraModels && (
-            <FetchingModelsLoader loadingMessage="Loading LoRAs..." />
-          )}
-          {['all', 'lora'].includes(modelFormatFilter) &&
-            !isLoadingLoraModels &&
-            filteredLoraModels.length > 0 && (
-              <ModelListWrapper
-                title="LoRAs"
-                modelList={filteredLoraModels}
-                selected={{ selectedModelId, setSelectedModelId }}
-                key="loras"
-              />
-            )}
-          {/* Olive List */}
-          {isLoadingOliveModels && (
-            <FetchingModelsLoader loadingMessage="Loading Olives..." />
-          )}
-          {['all', 'olive'].includes(modelFormatFilter) &&
-            !isLoadingOliveModels &&
+          {['images', 'olive'].includes(modelFormatFilter) &&
             filteredOliveModels.length > 0 && (
-              <ModelListWrapper
-                title="Olives"
-                modelList={filteredOliveModels}
-                selected={{ selectedModelId, setSelectedModelId }}
-                key="olive"
-              />
+              <StyledModelContainer>
+                <Flex sx={{ gap: 2, flexDir: 'column' }}>
+                  <Text variant="subtext" fontSize="sm">
+                    Olives
+                  </Text>
+                  {filteredOliveModels.map((model) => (
+                    <ModelListItem
+                      key={model.id}
+                      model={model}
+                      isSelected={selectedModelId === model.id}
+                      setSelectedModelId={setSelectedModelId}
+                    />
+                  ))}
+                </Flex>
+              </StyledModelContainer>
             )}
-          {/* Onnx List */}
-          {isLoadingOnnxModels && (
-            <FetchingModelsLoader loadingMessage="Loading ONNX..." />
-          )}
-          {['all', 'onnx'].includes(modelFormatFilter) &&
-            !isLoadingOnnxModels &&
+          {['images', 'onnx'].includes(modelFormatFilter) &&
             filteredOnnxModels.length > 0 && (
-              <ModelListWrapper
-                title="ONNX"
-                modelList={filteredOnnxModels}
-                selected={{ selectedModelId, setSelectedModelId }}
-                key="onnx"
-              />
+              <StyledModelContainer>
+                <Flex sx={{ gap: 2, flexDir: 'column' }}>
+                  <Text variant="subtext" fontSize="sm">
+                    Onnx
+                  </Text>
+                  {filteredOnnxModels.map((model) => (
+                    <ModelListItem
+                      key={model.id}
+                      model={model}
+                      isSelected={selectedModelId === model.id}
+                      setSelectedModelId={setSelectedModelId}
+                    />
+                  ))}
+                </Flex>
+              </StyledModelContainer>
+            )}
+          {['images', 'lora'].includes(modelFormatFilter) &&
+            filteredLoraModels.length > 0 && (
+              <StyledModelContainer>
+                <Flex sx={{ gap: 2, flexDir: 'column' }}>
+                  <Text variant="subtext" fontSize="sm">
+                    LoRAs
+                  </Text>
+                  {filteredLoraModels.map((model) => (
+                    <ModelListItem
+                      key={model.id}
+                      model={model}
+                      isSelected={selectedModelId === model.id}
+                      setSelectedModelId={setSelectedModelId}
+                    />
+                  ))}
+                </Flex>
+              </StyledModelContainer>
            )}
         </Flex>
       </Flex>
@@ -284,52 +287,3 @@ const StyledModelContainer = (props: PropsWithChildren) => {
     </Flex>
   );
 };
-
-type ModelListWrapperProps = {
-  title: string;
-  modelList:
-    | MainModelConfigEntity[]
-    | LoRAModelConfigEntity[]
-    | OnnxModelConfigEntity[];
-  selected: ModelListProps;
-};
-
-function ModelListWrapper(props: ModelListWrapperProps) {
-  const { title, modelList, selected } = props;
-  return (
-    <StyledModelContainer>
-      <Flex sx={{ gap: 2, flexDir: 'column' }}>
-        <Text variant="subtext" fontSize="sm">
-          {title}
-        </Text>
-        {modelList.map((model) => (
-          <ModelListItem
-            key={model.id}
-            model={model}
-            isSelected={selected.selectedModelId === model.id}
-            setSelectedModelId={selected.setSelectedModelId}
-          />
-        ))}
-      </Flex>
-    </StyledModelContainer>
-  );
-}
-
-function FetchingModelsLoader({ loadingMessage }: { loadingMessage?: string }) {
-  return (
-    <StyledModelContainer>
-      <Flex
-        justifyContent="center"
-        alignItems="center"
-        flexDirection="column"
-        p={4}
-        gap={8}
-      >
-        <Spinner />
-        <Text variant="subtext">
-          {loadingMessage ? loadingMessage : 'Fetching...'}
-        </Text>
-      </Flex>
-    </StyledModelContainer>
-  );
-}

View File

@@ -181,7 +181,7 @@ export const modelsApi = api.injectEndpoints({
       },
       providesTags: (result, error, arg) => {
         const tags: ApiFullTagDescription[] = [
-          { type: 'OnnxModel', id: LIST_TAG },
+          { id: 'OnnxModel', type: LIST_TAG },
         ];

         if (result) {
@@ -266,7 +266,6 @@
       invalidatesTags: [
         { type: 'MainModel', id: LIST_TAG },
         { type: 'SDXLRefinerModel', id: LIST_TAG },
-        { type: 'OnnxModel', id: LIST_TAG },
       ],
     }),
     importMainModels: build.mutation<
@@ -283,7 +282,6 @@
       invalidatesTags: [
        { type: 'MainModel', id: LIST_TAG },
        { type: 'SDXLRefinerModel', id: LIST_TAG },
-        { type: 'OnnxModel', id: LIST_TAG },
       ],
     }),
     addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
@@ -297,7 +295,6 @@
       invalidatesTags: [
        { type: 'MainModel', id: LIST_TAG },
        { type: 'SDXLRefinerModel', id: LIST_TAG },
-        { type: 'OnnxModel', id: LIST_TAG },
       ],
     }),
     deleteMainModels: build.mutation<
@@ -313,7 +310,6 @@
       invalidatesTags: [
        { type: 'MainModel', id: LIST_TAG },
        { type: 'SDXLRefinerModel', id: LIST_TAG },
-        { type: 'OnnxModel', id: LIST_TAG },
       ],
     }),
     convertMainModels: build.mutation<
@@ -330,7 +326,6 @@
       invalidatesTags: [
        { type: 'MainModel', id: LIST_TAG },
        { type: 'SDXLRefinerModel', id: LIST_TAG },
-        { type: 'OnnxModel', id: LIST_TAG },
       ],
     }),
     mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
@@ -344,7 +339,6 @@
       invalidatesTags: [
        { type: 'MainModel', id: LIST_TAG },
        { type: 'SDXLRefinerModel', id: LIST_TAG },
-        { type: 'OnnxModel', id: LIST_TAG },
       ],
     }),
     syncModels: build.mutation<SyncModelsResponse, void>({
@@ -357,7 +351,6 @@
       invalidatesTags: [
        { type: 'MainModel', id: LIST_TAG },
        { type: 'SDXLRefinerModel', id: LIST_TAG },
-        { type: 'OnnxModel', id: LIST_TAG },
       ],
     }),
     getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({

View File

@@ -48,7 +48,6 @@ def mock_services() -> InvocationServices:
         images=None,  # type: ignore
         latents=None,  # type: ignore
         boards=None,  # type: ignore
-        batch_manager=None,  # type: ignore
         board_images=None,  # type: ignore
         queue=MemoryInvocationQueue(),
         graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"),

View File

@@ -40,7 +40,6 @@ def mock_services() -> InvocationServices:
         logger=None,  # type: ignore
         images=None,  # type: ignore
         latents=None,  # type: ignore
-        batch_manager=None,  # type: ignore
         boards=None,  # type: ignore
         board_images=None,  # type: ignore
         queue=MemoryInvocationQueue(),