Compare commits

...

9 Commits

6 changed files with 78 additions and 82 deletions

View File

@@ -451,6 +451,7 @@ async def add_model_record(
 )
 async def install_model(
     source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
+    inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
     # TODO(MM2): Can we type this?
     config: Optional[Dict[str, Any]] = Body(
         description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
@@ -493,6 +494,7 @@ async def install_model(
             source=source,
             config=config,
             access_token=access_token,
+            inplace=bool(inplace),
         )
         logger.info(f"Started installation of {source}")
     except UnknownModelException as e:
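
Note: the new inplace flag threads through to the installer service (see the later files in this compare); when true, a local model is meant to be registered where it sits rather than copied into the models directory. A minimal client sketch, with the caveat that the host and route path are placeholders (the compare view elides the full URL; only the query parameters are confirmed by the diff):

    import requests

    # Placeholder host and route; only "source" and "inplace" come from the diff.
    response = requests.post(
        "http://localhost:9090/api/v2/models/install",
        params={"source": "/stash/models/juggernaut.safetensors", "inplace": True},
        json={},  # optional config overrides: name, description, prediction_type, ...
    )
    print(response.json())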

View File

@@ -2,6 +2,7 @@
 # which are imported/used before parse_args() is called will get the default config values instead of the
 # values from the command line or config file.
 import sys
+from contextlib import asynccontextmanager
 
 from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
 from invokeai.version.invokeai_version import __version__
@@ -71,9 +72,25 @@ logger = InvokeAILogger.get_logger(config=app_config)
 mimetypes.add_type("application/javascript", ".js")
 mimetypes.add_type("text/css", ".css")
 
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    # Add startup event to load dependencies
+    ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
+    yield
+    # Shut down threads
+    ApiDependencies.shutdown()
+
+
 # Create the app
 # TODO: create this all in a method so configuration/etc. can be passed in?
-app = FastAPI(title="Invoke - Community Edition", docs_url=None, redoc_url=None, separate_input_output_schemas=False)
+app = FastAPI(
+    title="Invoke - Community Edition",
+    docs_url=None,
+    redoc_url=None,
+    separate_input_output_schemas=False,
+    lifespan=lifespan,
+)
 
 # Add event handler
 event_handler_id: int = id(app)
@@ -96,18 +113,6 @@ app.add_middleware(
 app.add_middleware(GZipMiddleware, minimum_size=1000)
 
-
-# Add startup event to load dependencies
-@app.on_event("startup")
-async def startup_event() -> None:
-    ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
-
-
-# Shut down threads
-@app.on_event("shutdown")
-async def shutdown_event() -> None:
-    ApiDependencies.shutdown()
-
 # Include all routers
 app.include_router(utilities.utilities_router, prefix="/api")
 app.include_router(model_manager.model_manager_router, prefix="/api")
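
Note: this replaces FastAPI's deprecated @app.on_event("startup") / @app.on_event("shutdown") hooks (removed in the last hunk) with the lifespan context manager, which pairs setup and teardown around a single yield. The generic pattern, as a runnable sketch:

    from contextlib import asynccontextmanager

    from fastapi import FastAPI


    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # Everything before the yield runs once, before the first request is served.
        app.state.ready = True
        yield
        # Everything after the yield runs once, at shutdown.
        app.state.ready = False


    app = FastAPI(lifespan=lifespan)

One detail worth noting in the diff: lifespan closes over event_handler_id, which is assigned only after the app is created; that works because the name is resolved when startup actually runs, not when the function is defined.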

View File

@@ -1,17 +1,11 @@
-from typing import Iterator, List, Optional, Tuple, Union
+from typing import Iterator, List, Optional, Tuple, Union, cast
 
 import torch
 from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
 
-from invokeai.app.invocations.fields import (
-    FieldDescriptions,
-    Input,
-    InputField,
-    OutputField,
-    UIComponent,
-)
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.ti_utils import generate_ti_list
@@ -25,12 +19,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
 )
 from invokeai.backend.util.devices import torch_dtype
 
-from .baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    invocation,
-    invocation_output,
-)
+from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from .model import ClipField
 
 # unconditioned: Optional[torch.Tensor]
@@ -149,7 +138,7 @@ class SDXLPromptInvocationBase:
         assert isinstance(tokenizer_model, CLIPTokenizer)
         text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
         text_encoder_model = text_encoder_info.model
-        assert isinstance(text_encoder_model, CLIPTextModel)
+        assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
 
         # return zero on empty
         if prompt == "" and zero_on_empty:
@@ -196,7 +185,8 @@ class SDXLPromptInvocationBase:
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
         ):
-            assert isinstance(text_encoder, CLIPTextModel)
+            assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
+            text_encoder = cast(CLIPTextModel, text_encoder)
             compel = Compel(
                 tokenizer=tokenizer,
                 text_encoder=text_encoder,
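
Note: the widened asserts accept either CLIP encoder variant at runtime, while cast() exists only to satisfy compel's CLIPTextModel annotation; typing.cast performs no conversion. A self-contained illustration of the pattern (every name below is invented for the example):

    from typing import Union, cast


    class EncoderA:
        def encode(self, prompt: str) -> list[float]:
            return [0.0]


    class EncoderB:  # similar interface, but not a subclass of EncoderA
        def encode(self, prompt: str) -> list[float]:
            return [1.0]


    def build_pipeline(encoder: EncoderA) -> None:
        """Stands in for a third-party API annotated for one type only."""
        encoder.encode("test")


    def run(encoder: Union[EncoderA, EncoderB]) -> None:
        assert isinstance(encoder, (EncoderA, EncoderB))  # genuine runtime check
        # cast() narrows the static type for the checker; it is a no-op at runtime.
        build_pipeline(cast(EncoderA, encoder))


    run(EncoderB())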

View File

@@ -181,13 +181,14 @@ class ModelInstallService(ModelInstallServiceBase):
         source: str,
         config: Optional[Dict[str, Any]] = None,
         access_token: Optional[str] = None,
+        inplace: bool = False,
     ) -> ModelInstallJob:
         variants = "|".join(ModelRepoVariant.__members__.values())
         hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
         source_obj: Optional[StringLikeSource] = None
 
         if Path(source).exists():  # A local file or directory
-            source_obj = LocalModelSource(path=Path(source))
+            source_obj = LocalModelSource(path=Path(source), inplace=inplace)
         elif match := re.match(hf_repoid_re, source):
             source_obj = HFModelSource(
                 repo_id=match.group(1),
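
Note: the heuristic tries a local path first, then the Hugging Face repo-id pattern above. A quick probe of what that regex accepts; the variant names are hardcoded here for illustration, whereas the service builds them from ModelRepoVariant:

    import re

    # Assumed variant names, for illustration only.
    variants = "|".join(["fp16", "fp32", "onnx", "flax"])
    hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"

    for source in [
        "stabilityai/stable-diffusion-xl-base-1.0",       # bare repo id
        "stabilityai/stable-diffusion-xl-base-1.0:fp16",  # repo id + variant
        "org/repo::subfolder/model.safetensors",          # repo id + file path, no variant
    ]:
        match = re.match(hf_repoid_re, source)
        print(match.groups() if match else "no match")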

View File

@@ -28,6 +28,7 @@ from typing import Callable, Optional, Set, Union
 from pydantic import BaseModel, Field
 
+from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.util.logging import InvokeAILogger
 
 default_logger: Logger = InvokeAILogger.get_logger()
@@ -117,13 +118,10 @@ class ModelSearch(ModelSearchBase):
     """
 
     models_found: Set[Path] = Field(default_factory=set)
-    scanned_dirs: Set[Path] = Field(default_factory=set)
-    pruned_paths: Set[Path] = Field(default_factory=set)
+    config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
 
     def search_started(self) -> None:
         self.models_found = set()
-        self.scanned_dirs = set()
-        self.pruned_paths = set()
         if self.on_search_started:
             self.on_search_started(self._directory)
@@ -139,53 +137,53 @@ class ModelSearch(ModelSearchBase):
     def search(self, directory: Union[Path, str]) -> Set[Path]:
         self._directory = Path(directory)
+        if not self._directory.is_absolute():
+            self._directory = self.config.models_path / self._directory
         self.stats = SearchStats()  # zero out
         self.search_started()  # This will initialize _models_found to empty
-        self._walk_directory(directory)
+        self._walk_directory(self._directory)
         self.search_completed()
         return self.models_found
 
-    def _walk_directory(self, path: Union[Path, str]) -> None:
-        for root, dirs, files in os.walk(path, followlinks=True):
-            # don't descend into directories that start with a "."
-            # to avoid the Mac .DS_STORE issue.
-            if str(Path(root).name).startswith("."):
-                self.pruned_paths.add(Path(root))
-            if any(Path(root).is_relative_to(x) for x in self.pruned_paths):
-                continue
-
-            self.stats.items_scanned += len(dirs) + len(files)
-            for d in dirs:
-                path = Path(root) / d
-                if path.parent in self.scanned_dirs:
-                    self.scanned_dirs.add(path)
-                    continue
-                if any(
-                    (path / x).exists()
-                    for x in [
-                        "config.json",
-                        "model_index.json",
-                        "learned_embeds.bin",
-                        "pytorch_lora_weights.bin",
-                        "image_encoder.txt",
-                    ]
-                ):
-                    self.scanned_dirs.add(path)
-                    try:
-                        self.model_found(path)
-                    except KeyboardInterrupt:
-                        raise
-                    except Exception as e:
-                        self.logger.warning(str(e))
-
-            for f in files:
-                path = Path(root) / f
-                if path.parent in self.scanned_dirs:
-                    continue
-                if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
-                    try:
-                        self.model_found(path)
-                    except KeyboardInterrupt:
-                        raise
-                    except Exception as e:
-                        self.logger.warning(str(e))
+    def _walk_directory(self, path: Union[Path, str], max_depth: int = 20) -> None:
+        absolute_path = Path(path)
+        if (
+            len(absolute_path.parts) - len(self._directory.parts) > max_depth
+            or not absolute_path.exists()
+            or absolute_path.parent in self.models_found
+        ):
+            return
+        entries = os.scandir(absolute_path.as_posix())
+        entries = [entry for entry in entries if not entry.name.startswith(".")]
+        dirs = [entry for entry in entries if entry.is_dir()]
+        file_names = [entry.name for entry in entries if entry.is_file()]
+        if any(
+            x in file_names
+            for x in [
+                "config.json",
+                "model_index.json",
+                "learned_embeds.bin",
+                "pytorch_lora_weights.bin",
+                "image_encoder.txt",
+            ]
+        ):
+            try:
+                self.model_found(absolute_path)
+                return
+            except KeyboardInterrupt:
+                raise
+            except Exception as e:
+                self.logger.warning(str(e))
+                return
+
+        for n in file_names:
+            if n.endswith((".ckpt", ".bin", ".pth", ".safetensors", ".pt")):
+                try:
+                    self.model_found(absolute_path / n)
+                except KeyboardInterrupt:
+                    raise
+                except Exception as e:
+                    self.logger.warning(str(e))
+
+        for d in dirs:
+            self._walk_directory(absolute_path / d)

View File

@@ -4,12 +4,12 @@ from __future__ import annotations
 import pickle
 from contextlib import contextmanager
-from typing import Any, Dict, Iterator, List, Optional, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 from diffusers import OnnxRuntimeModel, UNet2DConditionModel
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
 
 from invokeai.app.shared.models import FreeUConfig
 from invokeai.backend.model_manager import AnyModel
@@ -168,7 +168,7 @@ class ModelPatcher:
     def apply_ti(
         cls,
         tokenizer: CLIPTokenizer,
-        text_encoder: CLIPTextModel,
+        text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
         ti_list: List[Tuple[str, TextualInversionModelRaw]],
     ) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
         init_tokens_count = None
@@ -265,7 +265,7 @@ class ModelPatcher:
     @contextmanager
     def apply_clip_skip(
         cls,
-        text_encoder: CLIPTextModel,
+        text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
         clip_skip: int,
     ) -> None:
         skipped_layers = []
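
Note: apply_ti and apply_clip_skip now accept either encoder class; both expose the same text_model.encoder.layers stack in transformers, which is all clip-skip needs. A sketch of the idea under that assumption: temporarily pop the last N transformer layers and restore them on exit. It mirrors the shape of the patched method but is not a copy of it:

    from contextlib import contextmanager
    from typing import Iterator, Union

    from transformers import CLIPTextModel, CLIPTextModelWithProjection


    @contextmanager
    def clip_skip(
        text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
        skip: int,
    ) -> Iterator[None]:
        skipped = []
        try:
            # nn.ModuleList supports pop/append, so layers can be detached in place.
            for _ in range(skip):
                skipped.append(text_encoder.text_model.encoder.layers.pop(-1))
            yield
        finally:
            while skipped:
                text_encoder.text_model.encoder.layers.append(skipped.pop())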