Merge branch 'main' into build/gitignore

This commit is contained in:
Lincoln Stein 2023-07-05 13:35:54 -04:00 committed by GitHub
commit 94740e440d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 30 additions and 9 deletions

View File

@ -9,6 +9,7 @@ from compel.prompt_parser import (Blend, Conjunction,
FlattenedPrompt, Fragment) FlattenedPrompt, Fragment)
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from ...backend.model_management.models import ModelNotFoundException
from ...backend.model_management import BaseModelType, ModelType, SubModelType from ...backend.model_management import BaseModelType, ModelType, SubModelType
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
@ -86,10 +87,10 @@ class CompelInvocation(BaseInvocation):
model_type=ModelType.TextualInversion, model_type=ModelType.TextualInversion,
).context.model ).context.model
) )
except Exception: except ModelNotFoundException:
# print(e) # print(e)
#import traceback #import traceback
# print(traceback.format_exc()) #print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found") print(f"Warn: trigger: \"{trigger}\" not found")
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\ with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\

View File

@ -29,7 +29,7 @@ import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from .model_manager import ModelManager from .model_manager import ModelManager
from .model_cache import ModelCache from picklescan.scanner import scan_file_path
from .models import BaseModelType, ModelVariantType from .models import BaseModelType, ModelVariantType
try: try:
@ -1014,7 +1014,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint = load_file(checkpoint_path) checkpoint = load_file(checkpoint_path)
else: else:
if scan_needed: if scan_needed:
ModelCache.scan_model(checkpoint_path, checkpoint_path) # scan model
scan_result = scan_file_path(checkpoint_path)
if scan_result.infected_files != 0:
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
checkpoint = torch.load(checkpoint_path) checkpoint = torch.load(checkpoint_path)
# sometimes there is a state_dict key and sometimes not # sometimes there is a state_dict key and sometimes not

View File

@ -655,6 +655,9 @@ class TextualInversionModel:
else: else:
result.embedding = next(iter(state_dict.values())) result.embedding = next(iter(state_dict.values()))
if len(result.embedding.shape) == 1:
result.embedding = result.embedding.unsqueeze(0)
if not isinstance(result.embedding, torch.Tensor): if not isinstance(result.embedding, torch.Tensor):
raise ValueError(f"Invalid embeddings file: {file_path.name}") raise ValueError(f"Invalid embeddings file: {file_path.name}")

View File

@ -249,7 +249,7 @@ from .model_cache import ModelCache, ModelLocker
from .models import ( from .models import (
BaseModelType, ModelType, SubModelType, BaseModelType, ModelType, SubModelType,
ModelError, SchedulerPredictionType, MODEL_CLASSES, ModelError, SchedulerPredictionType, MODEL_CLASSES,
ModelConfigBase, ModelConfigBase, ModelNotFoundException,
) )
# We are only starting to number the config file with release 3. # We are only starting to number the config file with release 3.
@ -409,7 +409,7 @@ class ModelManager(object):
if model_key not in self.models: if model_key not in self.models:
self.scan_models_directory(base_model=base_model, model_type=model_type) self.scan_models_directory(base_model=base_model, model_type=model_type)
if model_key not in self.models: if model_key not in self.models:
raise Exception(f"Model not found - {model_key}") raise ModelNotFoundException(f"Model not found - {model_key}")
model_config = self.models[model_key] model_config = self.models[model_key]
model_path = self.app_config.root_path / model_config.path model_path = self.app_config.root_path / model_config.path
@ -421,7 +421,7 @@ class ModelManager(object):
else: else:
self.models.pop(model_key, None) self.models.pop(model_key, None)
raise Exception(f"Model not found - {model_key}") raise ModelNotFoundException(f"Model not found - {model_key}")
# vae/movq override # vae/movq override
# TODO: # TODO:

View File

@ -2,7 +2,7 @@ import inspect
from enum import Enum from enum import Enum
from pydantic import BaseModel from pydantic import BaseModel
from typing import Literal, get_origin from typing import Literal, get_origin
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .vae import VaeModel from .vae import VaeModel
from .lora import LoRAModel from .lora import LoRAModel

View File

@ -15,6 +15,9 @@ from contextlib import suppress
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
class ModelNotFoundException(Exception):
pass
class BaseModelType(str, Enum): class BaseModelType(str, Enum):
StableDiffusion1 = "sd-1" StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2" StableDiffusion2 = "sd-2"

View File

@ -8,6 +8,7 @@ from .base import (
ModelType, ModelType,
SubModelType, SubModelType,
classproperty, classproperty,
ModelNotFoundException,
) )
# TODO: naming # TODO: naming
from ..lora import TextualInversionModel as TextualInversionModelRaw from ..lora import TextualInversionModel as TextualInversionModelRaw
@ -37,8 +38,15 @@ class TextualInversionModel(ModelBase):
if child_type is not None: if child_type is not None:
raise Exception("There is no child models in textual inversion") raise Exception("There is no child models in textual inversion")
checkpoint_path = self.model_path
if os.path.isdir(checkpoint_path):
checkpoint_path = os.path.join(checkpoint_path, "learned_embeds.bin")
if not os.path.exists(checkpoint_path):
raise ModelNotFoundException()
model = TextualInversionModelRaw.from_checkpoint( model = TextualInversionModelRaw.from_checkpoint(
file_path=self.model_path, file_path=checkpoint_path,
dtype=torch_dtype, dtype=torch_dtype,
) )

View File

@ -23,6 +23,7 @@ import { stateSelector } from 'app/store/store';
import { import {
imageDeletionConfirmed, imageDeletionConfirmed,
imageToDeleteCleared, imageToDeleteCleared,
isModalOpenChanged,
selectImageUsage, selectImageUsage,
} from '../store/imageDeletionSlice'; } from '../store/imageDeletionSlice';
@ -63,6 +64,7 @@ const DeleteImageModal = () => {
const handleClose = useCallback(() => { const handleClose = useCallback(() => {
dispatch(imageToDeleteCleared()); dispatch(imageToDeleteCleared());
dispatch(isModalOpenChanged(false));
}, [dispatch]); }, [dispatch]);
const handleDelete = useCallback(() => { const handleDelete = useCallback(() => {

View File

@ -31,6 +31,7 @@ const imageDeletion = createSlice({
}, },
imageToDeleteCleared: (state) => { imageToDeleteCleared: (state) => {
state.imageToDelete = null; state.imageToDelete = null;
state.isModalOpen = false;
}, },
}, },
}); });