Compare commits

...

9 Commits

6 changed files with 78 additions and 82 deletions

View File

@ -451,6 +451,7 @@ async def add_model_record(
)
async def install_model(
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
# TODO(MM2): Can we type this?
config: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
@ -493,6 +494,7 @@ async def install_model(
source=source,
config=config,
access_token=access_token,
inplace=bool(inplace),
)
logger.info(f"Started installation of {source}")
except UnknownModelException as e:

View File

@ -2,6 +2,7 @@
# which are imported/used before parse_args() is called will get the default config values instead of the
# values from the command line or config file.
import sys
from contextlib import asynccontextmanager
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.version.invokeai_version import __version__
@ -71,9 +72,25 @@ logger = InvokeAILogger.get_logger(config=app_config)
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")
@asynccontextmanager
async def lifespan(app: FastAPI):
# Add startup event to load dependencies
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
yield
# Shut down threads
ApiDependencies.shutdown()
# Create the app
# TODO: create this all in a method so configuration/etc. can be passed in?
app = FastAPI(title="Invoke - Community Edition", docs_url=None, redoc_url=None, separate_input_output_schemas=False)
app = FastAPI(
title="Invoke - Community Edition",
docs_url=None,
redoc_url=None,
separate_input_output_schemas=False,
lifespan=lifespan,
)
# Add event handler
event_handler_id: int = id(app)
@ -96,18 +113,6 @@ app.add_middleware(
app.add_middleware(GZipMiddleware, minimum_size=1000)
# Add startup event to load dependencies
@app.on_event("startup")
async def startup_event() -> None:
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
# Shut down threads
@app.on_event("shutdown")
async def shutdown_event() -> None:
ApiDependencies.shutdown()
# Include all routers
app.include_router(utilities.utilities_router, prefix="/api")
app.include_router(model_manager.model_manager_router, prefix="/api")

View File

@ -1,17 +1,11 @@
from typing import Iterator, List, Optional, Tuple, Union
from typing import Iterator, List, Optional, Tuple, Union, cast
import torch
from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
OutputField,
UIComponent,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
@ -25,12 +19,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
)
from invokeai.backend.util.devices import torch_dtype
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from .model import ClipField
# unconditioned: Optional[torch.Tensor]
@ -149,7 +138,7 @@ class SDXLPromptInvocationBase:
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, CLIPTextModel)
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
# return zero on empty
if prompt == "" and zero_on_empty:
@ -196,7 +185,8 @@ class SDXLPromptInvocationBase:
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
):
assert isinstance(text_encoder, CLIPTextModel)
assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
text_encoder = cast(CLIPTextModel, text_encoder)
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,

View File

@ -181,13 +181,14 @@ class ModelInstallService(ModelInstallServiceBase):
source: str,
config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None,
inplace: bool = False,
) -> ModelInstallJob:
variants = "|".join(ModelRepoVariant.__members__.values())
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
source_obj: Optional[StringLikeSource] = None
if Path(source).exists(): # A local file or directory
source_obj = LocalModelSource(path=Path(source))
source_obj = LocalModelSource(path=Path(source), inplace=inplace)
elif match := re.match(hf_repoid_re, source):
source_obj = HFModelSource(
repo_id=match.group(1),

View File

@ -28,6 +28,7 @@ from typing import Callable, Optional, Set, Union
from pydantic import BaseModel, Field
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
default_logger: Logger = InvokeAILogger.get_logger()
@ -117,13 +118,10 @@ class ModelSearch(ModelSearchBase):
"""
models_found: Set[Path] = Field(default_factory=set)
scanned_dirs: Set[Path] = Field(default_factory=set)
pruned_paths: Set[Path] = Field(default_factory=set)
config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
def search_started(self) -> None:
self.models_found = set()
self.scanned_dirs = set()
self.pruned_paths = set()
if self.on_search_started:
self.on_search_started(self._directory)
@ -139,29 +137,28 @@ class ModelSearch(ModelSearchBase):
def search(self, directory: Union[Path, str]) -> Set[Path]:
self._directory = Path(directory)
if not self._directory.is_absolute():
self._directory = self.config.models_path / self._directory
self.stats = SearchStats() # zero out
self.search_started() # This will initialize _models_found to empty
self._walk_directory(directory)
self._walk_directory(self._directory)
self.search_completed()
return self.models_found
def _walk_directory(self, path: Union[Path, str]) -> None:
for root, dirs, files in os.walk(path, followlinks=True):
# don't descend into directories that start with a "."
# to avoid the Mac .DS_STORE issue.
if str(Path(root).name).startswith("."):
self.pruned_paths.add(Path(root))
if any(Path(root).is_relative_to(x) for x in self.pruned_paths):
continue
self.stats.items_scanned += len(dirs) + len(files)
for d in dirs:
path = Path(root) / d
if path.parent in self.scanned_dirs:
self.scanned_dirs.add(path)
continue
def _walk_directory(self, path: Union[Path, str], max_depth: int = 20) -> None:
absolute_path = Path(path)
if (
len(absolute_path.parts) - len(self._directory.parts) > max_depth
or not absolute_path.exists()
or absolute_path.parent in self.models_found
):
return
entries = os.scandir(absolute_path.as_posix())
entries = [entry for entry in entries if not entry.name.startswith(".")]
dirs = [entry for entry in entries if entry.is_dir()]
file_names = [entry.name for entry in entries if entry.is_file()]
if any(
(path / x).exists()
x in file_names
for x in [
"config.json",
"model_index.json",
@ -170,22 +167,23 @@ class ModelSearch(ModelSearchBase):
"image_encoder.txt",
]
):
self.scanned_dirs.add(path)
try:
self.model_found(path)
self.model_found(absolute_path)
return
except KeyboardInterrupt:
raise
except Exception as e:
self.logger.warning(str(e))
return
for n in file_names:
if n.endswith((".ckpt", ".bin", ".pth", ".safetensors", ".pt")):
try:
self.model_found(absolute_path / n)
except KeyboardInterrupt:
raise
except Exception as e:
self.logger.warning(str(e))
for f in files:
path = Path(root) / f
if path.parent in self.scanned_dirs:
continue
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
try:
self.model_found(path)
except KeyboardInterrupt:
raise
except Exception as e:
self.logger.warning(str(e))
for d in dirs:
self._walk_directory(absolute_path / d)

View File

@ -4,12 +4,12 @@ from __future__ import annotations
import pickle
from contextlib import contextmanager
from typing import Any, Dict, Iterator, List, Optional, Tuple
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers import OnnxRuntimeModel, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager import AnyModel
@ -168,7 +168,7 @@ class ModelPatcher:
def apply_ti(
cls,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
ti_list: List[Tuple[str, TextualInversionModelRaw]],
) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
init_tokens_count = None
@ -265,7 +265,7 @@ class ModelPatcher:
@contextmanager
def apply_clip_skip(
cls,
text_encoder: CLIPTextModel,
text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
clip_skip: int,
) -> None:
skipped_layers = []