Compare commits

..

4 Commits

Author SHA1 Message Date
e04c25eba7 Merge branch 'main' into fix/diffusers-embeddings 2023-08-01 22:14:43 +10:00
be61ffdbf6 Merge branch 'main' into fix/diffusers-embeddings 2023-08-01 00:21:50 -04:00
823b879329 Merge branch 'main' into fix/diffusers-embeddings 2023-07-31 21:04:08 -04:00
17c901aaf7 fix diffusers-style textual embeddings
- also fix a couple places where the wrong base was used for relative model paths
2023-07-31 21:00:12 -04:00
15 changed files with 135 additions and 890 deletions

View File

@ -34,10 +34,6 @@
cudaPackages.cudnn
cudaPackages.cuda_nvrtc
cudatoolkit
pkgconfig
libconfig
cmake
blas
freeglut
glib
gperf
@ -46,12 +42,6 @@
libGLU
linuxPackages.nvidia_x11
python
(opencv4.override {
enableGtk3 = true;
enableFfmpeg = true;
enableCuda = true;
enableUnfree = true;
})
stdenv.cc
stdenv.cc.cc.lib
xorg.libX11

View File

@ -30,8 +30,6 @@ from ..services.invoker import Invoker
from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage
from ..services.model_manager_service import ModelManagerService
from ..services.batch_manager import BatchManager
from ..services.batch_manager_storage import SqliteBatchProcessStorage
from .events import FastAPIEventService
@ -118,15 +116,11 @@ class ApiDependencies:
)
)
batch_manager_storage = SqliteBatchProcessStorage(db_location)
batch_manager = BatchManager(batch_manager_storage)
services = InvocationServices(
model_manager=ModelManagerService(config, logger),
events=events,
latents=latents,
images=images,
batch_manager=batch_manager,
boards=boards,
board_images=board_images,
queue=MemoryInvocationQueue(),

View File

@ -15,7 +15,6 @@ from ...services.graph import (
GraphExecutionState,
NodeAlreadyExecutedError,
)
from ...services.batch_manager import Batch, BatchProcess
from ...services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies
@ -38,37 +37,6 @@ async def create_session(
return session
@session_router.post(
"/batch",
operation_id="create_batch",
responses={
200: {"model": BatchProcess},
400: {"description": "Invalid json"},
},
)
async def create_batch(
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
batches: list[Batch] = Body(description="Batch config to apply to the given graph"),
) -> BatchProcess:
"""Creates and starts a new new batch process"""
batch_id = ApiDependencies.invoker.services.batch_manager.create_batch_process(batches, graph)
ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_id)
return {"batch_id":batch_id}
@session_router.delete(
"{batch_process_id}/batch",
operation_id="cancel_batch",
responses={202: {"description": "The batch is canceled"}},
)
async def cancel_batch(
batch_process_id: str = Path(description="The id of the batch process to cancel"),
) -> Response:
"""Creates and starts a new new batch process"""
ApiDependencies.invoker.services.batch_manager.cancel_batch_process(batch_process_id)
return Response(status_code=202)
@session_router.get(
"/",
operation_id="list_sessions",

View File

@ -37,8 +37,6 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService
from invokeai.app.services.batch_manager import BatchManager
from invokeai.app.services.batch_manager_storage import SqliteBatchProcessStorage
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
@ -302,16 +300,12 @@ def invoke_cli():
)
)
batch_manager_storage = SqliteBatchProcessStorage(db_location)
batch_manager = BatchManager(batch_manager_storage)
services = InvocationServices(
model_manager=model_manager,
events=events,
latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
images=images,
boards=boards,
batch_manager=batch_manager,
board_images=board_images,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),

View File

@ -1,139 +0,0 @@
import networkx as nx
import copy
from abc import ABC, abstractmethod
from itertools import product
from pydantic import BaseModel, Field
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import Graph, GraphExecutionState
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.batch_manager_storage import (
BatchProcessStorageBase,
BatchSessionNotFoundException,
Batch,
BatchProcess,
BatchSession,
BatchSessionChanges,
)
class BatchManagerBase(ABC):
@abstractmethod
def start(self, invoker: Invoker):
pass
@abstractmethod
def create_batch_process(self, batches: list[Batch], graph: Graph) -> str:
pass
@abstractmethod
def run_batch_process(self, batch_id: str):
pass
@abstractmethod
def cancel_batch_process(self, batch_process_id: str):
pass
class BatchManager(BatchManagerBase):
"""Responsible for managing currently running and scheduled batch jobs"""
__invoker: Invoker
__batches: list[BatchProcess]
__batch_process_storage: BatchProcessStorageBase
def __init__(self, batch_process_storage: BatchProcessStorageBase) -> None:
super().__init__()
self.__batch_process_storage = batch_process_storage
def start(self, invoker: Invoker) -> None:
# if we do want multithreading at some point, we could make this configurable
self.__invoker = invoker
self.__batches = list()
local_handler.register(event_name=EventServiceBase.session_event, _func=self.on_event)
async def on_event(self, event: Event):
event_name = event[1]["event"]
match event_name:
case "graph_execution_state_complete":
await self.process(event, False)
case "invocation_error":
await self.process(event, True)
return event
async def process(self, event: Event, err: bool):
data = event[1]["data"]
batch_session = self.__batch_process_storage.get_session(data["graph_execution_state_id"])
if not batch_session:
return
updateSession = BatchSessionChanges(
state='error' if err else 'completed'
)
batch_session = self.__batch_process_storage.update_session_state(
batch_session.batch_id,
batch_session.session_id,
updateSession,
)
self.run_batch_process(batch_session.batch_id)
def _create_batch_session(self, batch_process: BatchProcess, batch_indices: list[int]) -> GraphExecutionState:
graph = copy.deepcopy(batch_process.graph)
batches = batch_process.batches
g = graph.nx_graph_flat()
sorted_nodes = nx.topological_sort(g)
for npath in sorted_nodes:
node = graph.get_node(npath)
(index, batch) = next(((i, b) for i, b in enumerate(batches) if b.node_id in node.id), (None, None))
if batch:
batch_index = batch_indices[index]
datum = batch.data[batch_index]
for key in datum:
node.__dict__[key] = datum[key]
graph.update_node(npath, node)
return GraphExecutionState(graph=graph)
def run_batch_process(self, batch_id: str):
try:
created_session = self.__batch_process_storage.get_created_session(batch_id)
except BatchSessionNotFoundException:
return
ges = self.__invoker.services.graph_execution_manager.get(created_session.session_id)
self.__invoker.invoke(ges, invoke_all=True)
def _valid_batch_config(self, batch_process: BatchProcess) -> bool:
return True
def create_batch_process(self, batches: list[Batch], graph: Graph) -> str:
batch_process = BatchProcess(
batches=batches,
graph=graph,
)
if not self._valid_batch_config(batch_process):
return None
batch_process = self.__batch_process_storage.save(batch_process)
self._create_sessions(batch_process)
return batch_process.batch_id
def _create_sessions(self, batch_process: BatchProcess):
batch_indices = list()
for batch in batch_process.batches:
batch_indices.append(list(range(len(batch.data))))
all_batch_indices = product(*batch_indices)
for bi in all_batch_indices:
ges = self._create_batch_session(batch_process, bi)
self.__invoker.services.graph_execution_manager.set(ges)
batch_session = BatchSession(
batch_id=batch_process.batch_id,
session_id=ges.id,
state="created"
)
self.__batch_process_storage.create_session(batch_session)
def cancel_batch_process(self, batch_process_id: str):
self.__batches = [batch for batch in self.__batches if batch.id != batch_process_id]

View File

@ -1,505 +0,0 @@
from abc import ABC, abstractmethod
from typing import cast
import uuid
import sqlite3
import threading
from typing import (
Any,
List,
Literal,
Optional,
Union,
)
import json
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
)
from invokeai.app.services.graph import Graph
from invokeai.app.models.image import ImageField
from pydantic import BaseModel, Field, Extra, parse_raw_as
invocations = BaseInvocation.get_invocations()
InvocationsUnion = Union[invocations] # type: ignore
BatchDataType = Union[str, int, float, ImageField]
class Batch(BaseModel):
data: list[dict[str, BatchDataType]] = Field(description="Mapping of node field to data value")
node_id: str = Field(description="ID of the node to batch")
class BatchSession(BaseModel):
batch_id: str = Field(description="Identifier for which batch this Index belongs to")
session_id: str = Field(description="Session ID Created for this Batch Index")
state: Literal["created", "completed", "inprogress", "error"] = Field(
description="Is this session created, completed, in progress, or errored?"
)
def uuid_string():
res = uuid.uuid4()
return str(res)
class BatchProcess(BaseModel):
batch_id: Optional[str] = Field(default_factory=uuid_string, description="Identifier for this batch")
batches: List[Batch] = Field(
description="List of batch configs to apply to this session",
default_factory=list,
)
graph: Graph = Field(description="The graph being executed")
class BatchSessionChanges(BaseModel, extra=Extra.forbid):
state: Literal["created", "completed", "inprogress", "error"] = Field(
description="Is this session created, completed, in progress, or errored?"
)
class BatchProcessNotFoundException(Exception):
"""Raised when an Batch Process record is not found."""
def __init__(self, message="BatchProcess record not found"):
super().__init__(message)
class BatchProcessSaveException(Exception):
"""Raised when an Batch Process record cannot be saved."""
def __init__(self, message="BatchProcess record not saved"):
super().__init__(message)
class BatchProcessDeleteException(Exception):
"""Raised when an Batch Process record cannot be deleted."""
def __init__(self, message="BatchProcess record not deleted"):
super().__init__(message)
class BatchSessionNotFoundException(Exception):
"""Raised when an Batch Session record is not found."""
def __init__(self, message="BatchSession record not found"):
super().__init__(message)
class BatchSessionSaveException(Exception):
"""Raised when an Batch Session record cannot be saved."""
def __init__(self, message="BatchSession record not saved"):
super().__init__(message)
class BatchSessionDeleteException(Exception):
"""Raised when an Batch Session record cannot be deleted."""
def __init__(self, message="BatchSession record not deleted"):
super().__init__(message)
class BatchProcessStorageBase(ABC):
"""Low-level service responsible for interfacing with the Batch Process record store."""
@abstractmethod
def delete(self, batch_id: str) -> None:
"""Deletes a Batch Process record."""
pass
@abstractmethod
def save(
self,
batch_process: BatchProcess,
) -> BatchProcess:
"""Saves a Batch Process record."""
pass
@abstractmethod
def get(
self,
batch_id: str,
) -> BatchProcess:
"""Gets a Batch Process record."""
pass
@abstractmethod
def create_session(
self,
session: BatchSession,
) -> BatchSession:
"""Creates a Batch Session attached to a Batch Process."""
pass
@abstractmethod
def get_session(
self,
session_id: str
) -> BatchSession:
"""Gets session by session_id"""
pass
@abstractmethod
def get_created_session(
self,
batch_id: str
) -> BatchSession:
"""Gets all created Batch Sessions for a given Batch Process id."""
pass
@abstractmethod
def get_created_sessions(
self,
batch_id: str
) -> List[BatchSession]:
"""Gets all created Batch Sessions for a given Batch Process id."""
pass
@abstractmethod
def update_session_state(
self,
batch_id: str,
session_id: str,
changes: BatchSessionChanges,
) -> BatchSession:
"""Updates the state of a Batch Session record."""
pass
class SqliteBatchProcessStorage(BatchProcessStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = threading.Lock()
try:
self._lock.acquire()
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `batch_process` table and `batch_session` junction table."""
# Create the `batch_process` table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS batch_process (
batch_id TEXT NOT NULL PRIMARY KEY,
batches TEXT NOT NULL,
graph TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME
);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_batch_process_created_at ON batch_process (created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_batch_process_updated_at
AFTER UPDATE
ON batch_process FOR EACH ROW
BEGIN
UPDATE batch_process SET updated_at = current_timestamp
WHERE batch_id = old.batch_id;
END;
"""
)
# Create the `batch_session` junction table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS batch_session (
batch_id TEXT NOT NULL,
session_id TEXT NOT NULL,
state TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
-- enforce one-to-many relationship between batch_process and batch_session using PK
-- (we can extend this to many-to-many later)
PRIMARY KEY (batch_id,session_id),
FOREIGN KEY (batch_id) REFERENCES batch_process (batch_id) ON DELETE CASCADE
);
"""
)
# Add index for batch id
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id ON batch_session (batch_id);
"""
)
# Add index for batch id, sorted by created_at
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id_created_at ON batch_session (batch_id,created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_batch_session_updated_at
AFTER UPDATE
ON batch_session FOR EACH ROW
BEGIN
UPDATE batch_session SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE batch_id = old.batch_id AND session_id = old.session_id;
END;
"""
)
def delete(self, batch_id: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM batch_process
WHERE batch_id = ?;
""",
(batch_id,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BatchProcessDeleteException from e
except Exception as e:
self._conn.rollback()
raise BatchProcessDeleteException from e
finally:
self._lock.release()
def save(
self,
batch_process: BatchProcess,
) -> BatchProcess:
try:
self._lock.acquire()
batches = [batch.json() for batch in batch_process.batches]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO batch_process (batch_id, batches, graph)
VALUES (?, ?, ?);
""",
(batch_process.batch_id, json.dumps(batches), batch_process.graph.json()),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BatchProcessSaveException from e
finally:
self._lock.release()
return self.get(batch_process.batch_id)
def _deserialize_batch_process(self, session_dict: dict) -> BatchProcess:
"""Deserializes a batch session."""
# Retrieve all the values, setting "reasonable" defaults if they are not present.
batch_id = session_dict.get("batch_id", "unknown")
batches_raw = session_dict.get("batches", "unknown")
graph_raw = session_dict.get("graph", "unknown")
batches = json.loads(batches_raw)
batches = [parse_raw_as(Batch, batch) for batch in batches]
return BatchProcess(
batch_id=batch_id,
batches=batches,
graph=parse_raw_as(Graph, graph_raw),
)
def get(
self,
batch_id: str,
) -> BatchProcess:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM batch_process
WHERE batch_id = ?;
""",
(batch_id,)
)
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
except sqlite3.Error as e:
self._conn.rollback()
raise BatchProcessNotFoundException from e
finally:
self._lock.release()
if result is None:
raise BatchProcessNotFoundException
return self._deserialize_batch_process(dict(result))
def create_session(
self,
session: BatchSession,
) -> BatchSession:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state)
VALUES (?, ?, ?);
""",
(session.batch_id, session.session_id, session.state),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BatchSessionSaveException from e
finally:
self._lock.release()
return self.get_session(session.session_id)
def get_session(
self,
session_id: str
) -> BatchSession:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM batch_session
WHERE session_id= ?;
""",
(session_id,),
)
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
except sqlite3.Error as e:
self._conn.rollback()
raise BatchSessionNotFoundException from e
finally:
self._lock.release()
if result is None:
raise BatchSessionNotFoundException
return self._deserialize_batch_session(dict(result))
def _deserialize_batch_session(self, session_dict: dict) -> BatchSession:
"""Deserializes a batch session."""
# Retrieve all the values, setting "reasonable" defaults if they are not present.
batch_id = session_dict.get("batch_id", "unknown")
session_id = session_dict.get("session_id", "unknown")
state = session_dict.get("state", "unknown")
return BatchSession(
batch_id=batch_id,
session_id=session_id,
state=state,
)
def get_created_session(
self,
batch_id: str
) -> BatchSession:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM batch_session
WHERE batch_id = ? AND state = 'created';
""",
(batch_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchone())
except sqlite3.Error as e:
self._conn.rollback()
raise BatchSessionNotFoundException from e
finally:
self._lock.release()
if result is None:
raise BatchSessionNotFoundException
session = self._deserialize_batch_session(dict(result))
return session
def get_created_sessions(
self,
batch_id: str
) -> List[BatchSession]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM batch_session
WHERE batch_id = ? AND state = created;
""",
(batch_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
except sqlite3.Error as e:
self._conn.rollback()
raise BatchSessionNotFoundException from e
finally:
self._lock.release()
if result is None:
raise BatchSessionNotFoundException
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
return sessions
def update_session_state(
self,
batch_id: str,
session_id: str,
changes: BatchSessionChanges,
) -> BatchSession:
try:
self._lock.acquire()
# Change the state of a batch session
if changes.state is not None:
self._cursor.execute(
f"""--sql
UPDATE batch_session
SET state = ?
WHERE batch_id = ? AND session_id = ?;
""",
(changes.state, batch_id, session_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BatchSessionSaveException from e
finally:
self._lock.release()
return self.get_session(session_id)

View File

@ -4,7 +4,6 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from logging import Logger
from invokeai.app.services.batch_manager import BatchManagerBase
from invokeai.app.services.board_images import BoardImagesServiceABC
from invokeai.app.services.boards import BoardServiceABC
from invokeai.app.services.images import ImageServiceABC
@ -22,7 +21,6 @@ class InvocationServices:
"""Services that can be used by invocations"""
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
batch_manager: "BatchManagerBase"
board_images: "BoardImagesServiceABC"
boards: "BoardServiceABC"
configuration: "InvokeAIAppConfig"
@ -38,7 +36,6 @@ class InvocationServices:
def __init__(
self,
batch_manager: "BatchManagerBase",
board_images: "BoardImagesServiceABC",
boards: "BoardServiceABC",
configuration: "InvokeAIAppConfig",
@ -52,7 +49,6 @@ class InvocationServices:
processor: "InvocationProcessorABC",
queue: "InvocationQueueABC",
):
self.batch_manager = batch_manager
self.board_images = board_images
self.boards = boards
self.boards = boards

View File

@ -29,7 +29,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._conn = sqlite3.connect(
self._filename, check_same_thread=False
) # TODO: figure out a better threading solution
self._conn.execute('pragma journal_mode=wal')
self._cursor = self._conn.cursor()
self._create_table()

View File

@ -651,7 +651,10 @@ class TextualInversionModel:
file_path = Path(file_path)
result = cls() # TODO:
result.name = file_path.stem # TODO:
if file_path.name == "learned_embeds.bin":
result.name = file_path.parent.name
else:
result.name = file_path.stem
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")

View File

@ -188,7 +188,7 @@ class ModelCache(object):
cache_entry = self._cached_models.get(key, None)
if cache_entry is None:
self.logger.info(
f"Loading model {model_path}, type {base_model.value}:{model_type.value}:{submodel.value if submodel else ''}"
f"Loading model {model_path}, type {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
)
# this will remove older cached models until

View File

@ -472,7 +472,7 @@ class ModelManager(object):
if submodel_type is not None and hasattr(model_config, submodel_type):
override_path = getattr(model_config, submodel_type)
if override_path:
model_path = self.app_config.root_path / override_path
model_path = self.resolve_path(override_path)
model_type = submodel_type
submodel_type = None
model_class = MODEL_CLASSES[base_model][model_type]
@ -670,7 +670,7 @@ class ModelManager(object):
# TODO: if path changed and old_model.path inside models folder should we delete this too?
# remove conversion cache as config changed
old_model_path = self.app_config.root_path / old_model.path
old_model_path = self.resolve_model_path(old_model.path)
old_model_cache = self._get_model_cache_path(old_model_path)
if old_model_cache.exists():
if old_model_cache.is_dir():
@ -780,7 +780,7 @@ class ModelManager(object):
model_type,
**submodel,
)
checkpoint_path = self.app_config.root_path / info["path"]
checkpoint_path = self.resolve_model_path(info["path"])
old_diffusers_path = self.resolve_model_path(model.location)
new_diffusers_path = (
dest_directory or self.app_config.models_path / base_model.value / model_type.value
@ -992,7 +992,7 @@ class ModelManager(object):
model_manager=self,
prediction_type_helper=ask_user_for_prediction_type,
)
known_paths = {config.root_path / x["path"] for x in self.list_models()}
known_paths = {self.resolve_model_path(x["path"]) for x in self.list_models()}
directories = {
config.root_path / x
for x in [

View File

@ -1,4 +1,4 @@
import { ButtonGroup, Flex, Spinner, Text } from '@chakra-ui/react';
import { ButtonGroup, Flex, Text } from '@chakra-ui/react';
import { EntityState } from '@reduxjs/toolkit';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
@ -6,23 +6,23 @@ import { forEach } from 'lodash-es';
import type { ChangeEvent, PropsWithChildren } from 'react';
import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants';
import {
LoRAModelConfigEntity,
MainModelConfigEntity,
OnnxModelConfigEntity,
useGetLoRAModelsQuery,
useGetMainModelsQuery,
useGetOnnxModelsQuery,
useGetLoRAModelsQuery,
LoRAModelConfigEntity,
} from 'services/api/endpoints/models';
import ModelListItem from './ModelListItem';
import { ALL_BASE_MODELS } from 'services/api/constants';
type ModelListProps = {
selectedModelId: string | undefined;
setSelectedModelId: (name: string | undefined) => void;
};
type ModelFormat = 'all' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx';
type ModelFormat = 'images' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx';
type ModelType = 'main' | 'lora' | 'onnx';
@ -33,63 +33,47 @@ const ModelList = (props: ModelListProps) => {
const { t } = useTranslation();
const [nameFilter, setNameFilter] = useState<string>('');
const [modelFormatFilter, setModelFormatFilter] =
useState<CombinedModelFormat>('all');
useState<CombinedModelFormat>('images');
const { filteredDiffusersModels, isLoadingDiffusersModels } =
useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data, isLoading }) => ({
filteredDiffusersModels: modelsFilter(
data,
'main',
'diffusers',
nameFilter
),
isLoadingDiffusersModels: isLoading,
}),
});
const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({
filteredDiffusersModels: modelsFilter(
data,
'main',
'diffusers',
nameFilter
),
}),
});
const { filteredCheckpointModels, isLoadingCheckpointModels } =
useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data, isLoading }) => ({
filteredCheckpointModels: modelsFilter(
data,
'main',
'checkpoint',
nameFilter
),
isLoadingCheckpointModels: isLoading,
}),
});
const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({
filteredCheckpointModels: modelsFilter(
data,
'main',
'checkpoint',
nameFilter
),
}),
});
const { filteredLoraModels, isLoadingLoraModels } = useGetLoRAModelsQuery(
undefined,
{
selectFromResult: ({ data, isLoading }) => ({
filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter),
isLoadingLoraModels: isLoading,
}),
}
);
const { filteredLoraModels } = useGetLoRAModelsQuery(undefined, {
selectFromResult: ({ data }) => ({
filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter),
}),
});
const { filteredOnnxModels, isLoadingOnnxModels } = useGetOnnxModelsQuery(
ALL_BASE_MODELS,
{
selectFromResult: ({ data, isLoading }) => ({
filteredOnnxModels: modelsFilter(data, 'onnx', 'onnx', nameFilter),
isLoadingOnnxModels: isLoading,
}),
}
);
const { filteredOnnxModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({
filteredOnnxModels: modelsFilter(data, 'onnx', 'onnx', nameFilter),
}),
});
const { filteredOliveModels, isLoadingOliveModels } = useGetOnnxModelsQuery(
ALL_BASE_MODELS,
{
selectFromResult: ({ data, isLoading }) => ({
filteredOliveModels: modelsFilter(data, 'onnx', 'olive', nameFilter),
isLoadingOliveModels: isLoading,
}),
}
);
const { filteredOliveModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({
filteredOliveModels: modelsFilter(data, 'onnx', 'olive', nameFilter),
}),
});
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setNameFilter(e.target.value);
@ -100,8 +84,8 @@ const ModelList = (props: ModelListProps) => {
<Flex flexDirection="column" gap={4} paddingInlineEnd={4}>
<ButtonGroup isAttached>
<IAIButton
onClick={() => setModelFormatFilter('all')}
isChecked={modelFormatFilter === 'all'}
onClick={() => setModelFormatFilter('images')}
isChecked={modelFormatFilter === 'images'}
size="sm"
>
{t('modelManager.allModels')}
@ -155,76 +139,95 @@ const ModelList = (props: ModelListProps) => {
maxHeight={window.innerHeight - 280}
overflow="scroll"
>
{/* Diffusers List */}
{isLoadingDiffusersModels && (
<FetchingModelsLoader loadingMessage="Loading Diffusers..." />
)}
{['all', 'diffusers'].includes(modelFormatFilter) &&
!isLoadingDiffusersModels &&
{['images', 'diffusers'].includes(modelFormatFilter) &&
filteredDiffusersModels.length > 0 && (
<ModelListWrapper
title="Diffusers"
modelList={filteredDiffusersModels}
selected={{ selectedModelId, setSelectedModelId }}
key="diffusers"
/>
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
Diffusers
</Text>
{filteredDiffusersModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
)}
{/* Checkpoints List */}
{isLoadingCheckpointModels && (
<FetchingModelsLoader loadingMessage="Loading Checkpoints..." />
)}
{['all', 'checkpoint'].includes(modelFormatFilter) &&
!isLoadingCheckpointModels &&
{['images', 'checkpoint'].includes(modelFormatFilter) &&
filteredCheckpointModels.length > 0 && (
<ModelListWrapper
title="Checkpoints"
modelList={filteredCheckpointModels}
selected={{ selectedModelId, setSelectedModelId }}
key="checkpoints"
/>
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
Checkpoints
</Text>
{filteredCheckpointModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
)}
{/* LoRAs List */}
{isLoadingLoraModels && (
<FetchingModelsLoader loadingMessage="Loading LoRAs..." />
)}
{['all', 'lora'].includes(modelFormatFilter) &&
!isLoadingLoraModels &&
filteredLoraModels.length > 0 && (
<ModelListWrapper
title="LoRAs"
modelList={filteredLoraModels}
selected={{ selectedModelId, setSelectedModelId }}
key="loras"
/>
)}
{/* Olive List */}
{isLoadingOliveModels && (
<FetchingModelsLoader loadingMessage="Loading Olives..." />
)}
{['all', 'olive'].includes(modelFormatFilter) &&
!isLoadingOliveModels &&
{['images', 'olive'].includes(modelFormatFilter) &&
filteredOliveModels.length > 0 && (
<ModelListWrapper
title="Olives"
modelList={filteredOliveModels}
selected={{ selectedModelId, setSelectedModelId }}
key="olive"
/>
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
Olives
</Text>
{filteredOliveModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
)}
{/* Onnx List */}
{isLoadingOnnxModels && (
<FetchingModelsLoader loadingMessage="Loading ONNX..." />
)}
{['all', 'onnx'].includes(modelFormatFilter) &&
!isLoadingOnnxModels &&
{['images', 'onnx'].includes(modelFormatFilter) &&
filteredOnnxModels.length > 0 && (
<ModelListWrapper
title="ONNX"
modelList={filteredOnnxModels}
selected={{ selectedModelId, setSelectedModelId }}
key="onnx"
/>
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
Onnx
</Text>
{filteredOnnxModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
)}
{['images', 'lora'].includes(modelFormatFilter) &&
filteredLoraModels.length > 0 && (
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
LoRAs
</Text>
{filteredLoraModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
)}
</Flex>
</Flex>
@ -284,52 +287,3 @@ const StyledModelContainer = (props: PropsWithChildren) => {
</Flex>
);
};
type ModelListWrapperProps = {
title: string;
modelList:
| MainModelConfigEntity[]
| LoRAModelConfigEntity[]
| OnnxModelConfigEntity[];
selected: ModelListProps;
};
function ModelListWrapper(props: ModelListWrapperProps) {
const { title, modelList, selected } = props;
return (
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
{title}
</Text>
{modelList.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selected.selectedModelId === model.id}
setSelectedModelId={selected.setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
);
}
function FetchingModelsLoader({ loadingMessage }: { loadingMessage?: string }) {
return (
<StyledModelContainer>
<Flex
justifyContent="center"
alignItems="center"
flexDirection="column"
p={4}
gap={8}
>
<Spinner />
<Text variant="subtext">
{loadingMessage ? loadingMessage : 'Fetching...'}
</Text>
</Flex>
</StyledModelContainer>
);
}

View File

@ -181,7 +181,7 @@ export const modelsApi = api.injectEndpoints({
},
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ type: 'OnnxModel', id: LIST_TAG },
{ id: 'OnnxModel', type: LIST_TAG },
];
if (result) {
@ -266,7 +266,6 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
],
}),
importMainModels: build.mutation<
@ -283,7 +282,6 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
],
}),
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
@ -297,7 +295,6 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
],
}),
deleteMainModels: build.mutation<
@ -313,7 +310,6 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
],
}),
convertMainModels: build.mutation<
@ -330,7 +326,6 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
],
}),
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
@ -344,7 +339,6 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
],
}),
syncModels: build.mutation<SyncModelsResponse, void>({
@ -357,7 +351,6 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
],
}),
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({

View File

@ -48,7 +48,6 @@ def mock_services() -> InvocationServices:
images=None, # type: ignore
latents=None, # type: ignore
boards=None, # type: ignore
batch_manager=None, # type: ignore
board_images=None, # type: ignore
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"),

View File

@ -40,7 +40,6 @@ def mock_services() -> InvocationServices:
logger=None, # type: ignore
images=None, # type: ignore
latents=None, # type: ignore
batch_manager=None, # type: ignore
boards=None, # type: ignore
board_images=None, # type: ignore
queue=MemoryInvocationQueue(),