Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into ryan/multi-image-ip
@@ -20,12 +20,12 @@ class InvisibleWatermark:
     """
 
     @classmethod
-    def invisible_watermark_available(self) -> bool:
+    def invisible_watermark_available(cls) -> bool:
         return config.invisible_watermark
 
     @classmethod
-    def add_watermark(self, image: Image, watermark_text: str) -> Image:
-        if not self.invisible_watermark_available():
+    def add_watermark(cls, image: Image.Image, watermark_text: str) -> Image.Image:
+        if not cls.invisible_watermark_available():
            return image
         logger.debug(f'Applying invisible watermark "{watermark_text}"')
         bgr = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
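Reviewer note: the hunks in this file are one repeated fix — the first parameter of each `@classmethod` is renamed from the misleading `self` to the conventional `cls`, and bare `Image` annotations become `Image.Image` (the actual class). Python binds the class object to the first parameter either way; only the name was wrong. A minimal standalone sketch with hypothetical names, not InvokeAI code:

class Watermarker:
    enabled = True

    @classmethod
    def available(cls) -> bool:
        # `cls` here is the class object (Watermarker or a subclass), not an instance
        return cls.enabled

assert Watermarker.available()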
@@ -26,8 +26,8 @@ class SafetyChecker:
     tried_load: bool = False
 
     @classmethod
-    def _load_safety_checker(self):
-        if self.tried_load:
+    def _load_safety_checker(cls):
+        if cls.tried_load:
             return
 
         if config.nsfw_checker:
@@ -35,31 +35,31 @@ class SafetyChecker:
                 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
                 from transformers import AutoFeatureExtractor
 
-                self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH)
-                self.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH)
+                cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH)
+                cls.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH)
                 logger.info("NSFW checker initialized")
             except Exception as e:
                 logger.warning(f"Could not load NSFW checker: {str(e)}")
         else:
             logger.info("NSFW checker loading disabled")
-        self.tried_load = True
+        cls.tried_load = True
 
     @classmethod
-    def safety_checker_available(self) -> bool:
-        self._load_safety_checker()
-        return self.safety_checker is not None
+    def safety_checker_available(cls) -> bool:
+        cls._load_safety_checker()
+        return cls.safety_checker is not None
 
     @classmethod
-    def has_nsfw_concept(self, image: Image) -> bool:
-        if not self.safety_checker_available():
+    def has_nsfw_concept(cls, image: Image.Image) -> bool:
+        if not cls.safety_checker_available():
             return False
 
         device = choose_torch_device()
-        features = self.feature_extractor([image], return_tensors="pt")
+        features = cls.feature_extractor([image], return_tensors="pt")
         features.to(device)
-        self.safety_checker.to(device)
+        cls.safety_checker.to(device)
         x_image = np.array(image).astype(np.float32) / 255.0
         x_image = x_image[None].transpose(0, 3, 1, 2)
         with SilenceWarnings():
-            checked_image, has_nsfw_concept = self.safety_checker(images=x_image, clip_input=features.pixel_values)
+            checked_image, has_nsfw_concept = cls.safety_checker(images=x_image, clip_input=features.pixel_values)
         return has_nsfw_concept[0]
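Reviewer note: `SafetyChecker` uses a lazy, load-once pattern on class attributes, which is why every method here is a `@classmethod` operating on `cls` state. A self-contained sketch of that pattern with hypothetical names, not InvokeAI code:

from typing import Optional

class LazyResource:
    _resource: Optional[str] = None
    _tried_load: bool = False

    @classmethod
    def _load(cls) -> None:
        if cls._tried_load:
            return  # attempt the expensive load at most once
        try:
            cls._resource = "loaded"  # stand-in for loading a model from disk
        except Exception:
            cls._resource = None
        cls._tried_load = True

    @classmethod
    def available(cls) -> bool:
        cls._load()
        return cls._resource is not None

assert LazyResource.available()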
@@ -41,18 +41,18 @@ config = InvokeAIAppConfig.get_config()
 
 
 class SegmentedGrayscale(object):
-    def __init__(self, image: Image, heatmap: torch.Tensor):
+    def __init__(self, image: Image.Image, heatmap: torch.Tensor):
         self.heatmap = heatmap
         self.image = image
 
-    def to_grayscale(self, invert: bool = False) -> Image:
+    def to_grayscale(self, invert: bool = False) -> Image.Image:
         return self._rescale(Image.fromarray(np.uint8(255 - self.heatmap * 255 if invert else self.heatmap * 255)))
 
-    def to_mask(self, threshold: float = 0.5) -> Image:
+    def to_mask(self, threshold: float = 0.5) -> Image.Image:
         discrete_heatmap = self.heatmap.lt(threshold).int()
         return self._rescale(Image.fromarray(np.uint8(discrete_heatmap * 255), mode="L"))
 
-    def to_transparent(self, invert: bool = False) -> Image:
+    def to_transparent(self, invert: bool = False) -> Image.Image:
         transparent_image = self.image.copy()
         # For img2img, we want the selected regions to be transparent,
         # but to_grayscale() returns the opposite. Thus invert.
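Reviewer note: the annotation fixes in this file all hinge on the same distinction — `Image` imported via `from PIL import Image` is a module, and the image type is the `Image.Image` class inside it, so the old annotations never named a real type. A quick sketch:

import types

from PIL import Image

img = Image.new("RGB", (8, 8))
assert isinstance(img, Image.Image)  # Image.Image is the class
assert isinstance(Image, types.ModuleType)  # bare `Image` is only the module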
@@ -61,7 +61,7 @@ class SegmentedGrayscale(object):
         return transparent_image
 
     # unscales and uncrops the 352x352 heatmap so that it matches the image again
-    def _rescale(self, heatmap: Image) -> Image:
+    def _rescale(self, heatmap: Image.Image) -> Image.Image:
         size = self.image.width if (self.image.width > self.image.height) else self.image.height
         resized_image = heatmap.resize((size, size), resample=Image.Resampling.LANCZOS)
         return resized_image.crop((0, 0, self.image.width, self.image.height))
@@ -82,7 +82,7 @@ class Txt2Mask(object):
         self.model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_MODEL, cache_dir=config.cache_dir)
 
     @torch.no_grad()
-    def segment(self, image, prompt: str) -> SegmentedGrayscale:
+    def segment(self, image: Image.Image, prompt: str) -> SegmentedGrayscale:
         """
         Given a prompt string such as "a bagel", tries to identify the object in the
         provided image and returns a SegmentedGrayscale object in which the brighter
@@ -99,7 +99,7 @@ class Txt2Mask(object):
         heatmap = torch.sigmoid(outputs.logits)
         return SegmentedGrayscale(image, heatmap)
 
-    def _scale_and_crop(self, image: Image) -> Image:
+    def _scale_and_crop(self, image: Image.Image) -> Image.Image:
         scaled_image = Image.new("RGB", (CLIPSEG_SIZE, CLIPSEG_SIZE))
         if image.width > image.height:  # width is constraint
             scale = CLIPSEG_SIZE / image.width
@@ -9,7 +9,7 @@ class InitImageResizer:
     def __init__(self, Image):
         self.image = Image
 
-    def resize(self, width=None, height=None) -> Image:
+    def resize(self, width=None, height=None) -> Image.Image:
         """
         Return a copy of the image resized to fit within
         a box width x height. The aspect ratio is
@@ -793,7 +793,11 @@ def migrate_init_file(legacy_format: Path):
     old = legacy_parser.parse_args([f"@{str(legacy_format)}"])
     new = InvokeAIAppConfig.get_config()
 
-    fields = [x for x, y in InvokeAIAppConfig.__fields__.items() if y.field_info.extra.get("category") != "DEPRECATED"]
+    fields = [
+        x
+        for x, y in InvokeAIAppConfig.model_fields.items()
+        if (y.json_schema_extra.get("category", None) if y.json_schema_extra else None) != "DEPRECATED"
+    ]
     for attr in fields:
         if hasattr(old, attr):
             try:
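Reviewer note: this is the Pydantic v2 spelling of per-field metadata — `__fields__[name].field_info.extra` became `model_fields[name].json_schema_extra`, which can be `None`, hence the added guard. A standalone sketch with a hypothetical model (Pydantic v2 assumed):

from pydantic import BaseModel, Field

class Settings(BaseModel):
    width: int = Field(256, json_schema_extra={"category": "Generation"})
    old_flag: bool = Field(False, json_schema_extra={"category": "DEPRECATED"})

live = [
    name
    for name, field in Settings.model_fields.items()
    if (field.json_schema_extra.get("category", None) if field.json_schema_extra else None) != "DEPRECATED"
]
assert live == ["width"]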
@@ -236,13 +236,13 @@ import types
 from dataclasses import dataclass
 from pathlib import Path
 from shutil import move, rmtree
-from typing import Callable, Dict, List, Literal, Optional, Set, Tuple, Union
+from typing import Callable, Dict, List, Literal, Optional, Set, Tuple, Union, cast
 
 import torch
 import yaml
 from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field
 
 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config import InvokeAIAppConfig
@@ -294,6 +294,8 @@ class AddModelResult(BaseModel):
     base_model: BaseModelType = Field(description="The base model")
     config: ModelConfigBase = Field(description="The configuration of the model")
 
+    model_config = ConfigDict(protected_namespaces=())
+
 
 MAX_CACHE_SIZE = 6.0  # GB
 
@@ -576,7 +578,7 @@ class ModelManager(object):
         """
         model_key = self.create_key(model_name, base_model, model_type)
         if model_key in self.models:
-            return self.models[model_key].dict(exclude_defaults=True)
+            return self.models[model_key].model_dump(exclude_defaults=True)
         else:
             return None  # TODO: None or empty dict on not found
 
@@ -632,7 +634,7 @@ class ModelManager(object):
                 continue
 
             model_dict = dict(
-                **model_config.dict(exclude_defaults=True),
+                **model_config.model_dump(exclude_defaults=True),
                 # OpenAPIModelInfoBase
                 model_name=cur_model_name,
                 base_model=cur_base_model,
@@ -900,14 +902,16 @@ class ModelManager(object):
         Write current configuration out to the indicated file.
         """
         data_to_save = dict()
-        data_to_save["__metadata__"] = self.config_meta.dict()
+        data_to_save["__metadata__"] = self.config_meta.model_dump()
 
         for model_key, model_config in self.models.items():
             model_name, base_model, model_type = self.parse_key(model_key)
             model_class = self._get_implementation(base_model, model_type)
             if model_class.save_to_config:
                 # TODO: or exclude_unset better fits here?
-                data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
+                data_to_save[model_key] = cast(BaseModel, model_config).model_dump(
+                    exclude_defaults=True, exclude={"error"}, mode="json"
+                )
                 # alias for config file
                 data_to_save[model_key]["format"] = data_to_save[model_key].pop("model_format")
 
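Reviewer note: the three hunks above are the mechanical Pydantic v2 rename of `BaseModel.dict()` to `model_dump()`; the added `mode="json"` in the last one forces JSON-safe output (enums become their values), which matters when the result is written out to a config file. A sketch with a hypothetical model:

from enum import Enum
from typing import Optional

from pydantic import BaseModel

class ModelError(str, Enum):
    not_found = "not_found"

class Record(BaseModel):
    name: str = "default"
    error: Optional[ModelError] = None

r = Record(error=ModelError.not_found)
# v1 spelled this r.dict(exclude_defaults=True)
assert r.model_dump(exclude_defaults=True) == {"error": ModelError.not_found}
assert r.model_dump(exclude_defaults=True, mode="json") == {"error": "not_found"}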
@@ -2,7 +2,7 @@ import inspect
 from enum import Enum
 from typing import Literal, get_origin
 
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict, create_model
 
 from .base import (  # noqa: F401
     BaseModelType,
@@ -106,6 +106,8 @@ class OpenAPIModelInfoBase(BaseModel):
     base_model: BaseModelType
     model_type: ModelType
 
+    model_config = ConfigDict(protected_namespaces=())
+
 
 for base_model, models in MODEL_CLASSES.items():
     for model_type, model_class in models.items():
@@ -121,17 +123,11 @@ for base_model, models in MODEL_CLASSES.items():
         if openapi_cfg_name in vars():
             continue
 
-        api_wrapper = type(
+        api_wrapper = create_model(
             openapi_cfg_name,
-            (cfg, OpenAPIModelInfoBase),
-            dict(
-                __annotations__=dict(
-                    model_type=Literal[model_type.value],
-                ),
-            ),
+            __base__=(cfg, OpenAPIModelInfoBase),
+            model_type=(Literal[model_type], model_type),  # type: ignore
         )
 
         # globals()[openapi_cfg_name] = api_wrapper
         vars()[openapi_cfg_name] = api_wrapper
         OPENAPI_MODEL_CONFIGS.append(api_wrapper)
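Reviewer note: building model subclasses with raw `type(...)` and an `__annotations__` dict is not something Pydantic v2 picks up as fields; `create_model(name, __base__=..., field=(annotation, default))` is the supported factory. A reduced sketch with hypothetical names:

from typing import Literal

from pydantic import BaseModel, create_model

class InfoBase(BaseModel):
    name: str

Wrapped = create_model(
    "Wrapped",
    __base__=InfoBase,
    kind=(Literal["demo"], "demo"),  # (annotation, default) tuple per field
)
assert Wrapped(name="x").kind == "demo"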
@@ -19,7 +19,7 @@ from diffusers import logging as diffusers_logging
 from onnx import numpy_helper
 from onnxruntime import InferenceSession, SessionOptions, get_available_providers
 from picklescan.scanner import scan_file_path
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field
 from transformers import logging as transformers_logging
 
 
@@ -86,14 +86,21 @@ class ModelError(str, Enum):
     NotFound = "not_found"
 
 
+def model_config_json_schema_extra(schema: dict[str, Any]) -> None:
+    if "required" not in schema:
+        schema["required"] = []
+    schema["required"].append("model_type")
+
+
 class ModelConfigBase(BaseModel):
     path: str  # or Path
     description: Optional[str] = Field(None)
     model_format: Optional[str] = Field(None)
     error: Optional[ModelError] = Field(None)
 
-    class Config:
-        use_enum_values = True
+    model_config = ConfigDict(
+        use_enum_values=True, protected_namespaces=(), json_schema_extra=model_config_json_schema_extra
+    )
 
 
 class EmptyConfigLoader(ConfigMixin):
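Reviewer note: this hunk bundles three Pydantic v2 config idioms — the inner `class Config` becomes a `model_config = ConfigDict(...)` attribute, `protected_namespaces=()` silences the new warning about field names starting with `model_` (here, `model_format`), and `json_schema_extra` may be a callable that mutates the generated schema in place, which is how `model_type` gets forced into `required`. A sketch with a hypothetical model:

from typing import Any, Optional

from pydantic import BaseModel, ConfigDict

def require_kind(schema: dict[str, Any]) -> None:
    # mark an otherwise-optional field as required in the emitted JSON schema
    schema.setdefault("required", []).append("kind")

class Cfg(BaseModel):
    model_config = ConfigDict(protected_namespaces=(), json_schema_extra=require_kind)

    model_format: Optional[str] = None  # no warning, thanks to protected_namespaces=()
    kind: Optional[str] = None

assert "kind" in Cfg.model_json_schema()["required"]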
@@ -58,14 +58,16 @@ class IPAdapterModel(ModelBase):
 
     def get_model(
         self,
-        torch_dtype: Optional[torch.dtype],
+        torch_dtype: torch.dtype,
         child_type: Optional[SubModelType] = None,
     ) -> typing.Union[IPAdapter, IPAdapterPlus]:
         if child_type is not None:
             raise ValueError("There are no child models in an IP-Adapter model.")
 
         model = build_ip_adapter(
-            ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"), device="cpu", dtype=torch_dtype
+            ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"),
+            device=torch.device("cpu"),
+            dtype=torch_dtype,
         )
 
         self.model_size = model.calc_size()
@@ -96,7 +96,7 @@ def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axe
     finally:
         for module, orig_conv_forward in to_restore:
             module._conv_forward = orig_conv_forward
-            if hasattr(m, "asymmetric_padding_mode"):
-                del m.asymmetric_padding_mode
-            if hasattr(m, "asymmetric_padding"):
-                del m.asymmetric_padding
+            if hasattr(module, "asymmetric_padding_mode"):
+                del module.asymmetric_padding_mode
+            if hasattr(module, "asymmetric_padding"):
+                del module.asymmetric_padding
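Reviewer note: this one is a real bug fix, not a rename — the cleanup tested `m`, a name apparently left bound by an earlier loop in the same function, instead of the current loop variable `module`, so the attribute deletions ran against a stale object. A tiny sketch of the hazard (hypothetical code):

class Layer:
    pass

layers = [Layer(), Layer()]
for m in layers:
    m.tag = "setup"  # first loop; `m` stays bound after it ends

for module in layers:
    assert m is layers[-1]  # stale binding: `m` is still the first loop's last item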
@@ -1,7 +1,8 @@
 import math
+from typing import Optional
 
 import PIL
 import torch
 from PIL import Image
 from torchvision.transforms.functional import InterpolationMode
 from torchvision.transforms.functional import resize as tv_resize
@@ -11,7 +12,7 @@ class AttentionMapSaver:
         self.token_ids = token_ids
         self.latents_shape = latents_shape
         # self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
-        self.collated_maps = {}
+        self.collated_maps: dict[str, torch.Tensor] = {}
 
     def clear_maps(self):
         self.collated_maps = {}
@@ -38,9 +39,10 @@ class AttentionMapSaver:
 
     def write_maps_to_disk(self, path: str):
         pil_image = self.get_stacked_maps_image()
-        pil_image.save(path, "PNG")
+        if pil_image is not None:
+            pil_image.save(path, "PNG")
 
-    def get_stacked_maps_image(self) -> PIL.Image:
+    def get_stacked_maps_image(self) -> Optional[Image.Image]:
         """
         Scale all collected attention maps to the same size, blend them together and return as an image.
         :return: An image containing a vertical stack of blended attention maps, one for each requested token.
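Reviewer note: `get_stacked_maps_image()` can return `None` (see the `return None` path in the final hunk), so the return type is now honest and the caller guards before saving. The general shape of that guard:

from typing import Optional

from PIL import Image

def maybe_image(ok: bool) -> Optional[Image.Image]:
    return Image.new("L", (4, 4)) if ok else None

img = maybe_image(False)
if img is not None:  # mirrors `if pil_image is not None: pil_image.save(...)`
    img.save("out.png", "PNG")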
@@ -95,4 +97,4 @@ class AttentionMapSaver:
             return None
 
         merged_bytes = merged.mul(0xFF).byte()
-        return PIL.Image.fromarray(merged_bytes.numpy(), mode="L")
+        return Image.fromarray(merged_bytes.numpy(), mode="L")