Merge branch 'main' into feat/safety-checker-node

Lincoln Stein 2023-07-26 10:04:23 -04:00 committed by GitHub
commit 4bea846199
14 changed files with 1580 additions and 685 deletions

View File

@@ -203,7 +203,10 @@ def invoke_api():
            return find_port(port=port + 1)
        else:
            return port

    from invokeai.backend.install.check_root import check_invokeai_root
    check_invokeai_root(app_config)  # note, may exit with an exception if root not set up

    port = find_port(app_config.port)
    if port != app_config.port:
        logger.warn(f"Port {app_config.port} in use, using port {port}")

View File

@@ -0,0 +1,31 @@
"""
Check that the invokeai_root is correctly configured and exit if not.
"""
import sys
from invokeai.app.services.config import (
    InvokeAIAppConfig,
)


def check_invokeai_root(config: InvokeAIAppConfig):
    try:
        assert config.model_conf_path.exists()
        assert config.db_path.exists()
        assert config.models_path.exists()
        for model in [
                'CLIP-ViT-bigG-14-laion2B-39B-b160k',
                'bert-base-uncased',
                'clip-vit-large-patch14',
                'sd-vae-ft-mse',
                'stable-diffusion-2-clip',
                'stable-diffusion-safety-checker']:
            assert (config.models_path / f'core/convert/{model}').exists()
    except:
        print()
        print('== STARTUP ABORTED ==')
        print('** One or more necessary files are missing from your InvokeAI root directory **')
        print('** Please rerun the configuration script to fix this problem. **')
        print('** From the launcher, select option [7]. **')
        print('** From the command line, activate the virtual environment and run "invokeai-configure --yes --skip-sd-weights" **')
        input('Press any key to continue...')
        sys.exit(0)
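A minimal usage sketch, mirroring the call added to invoke_api() above:

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.install.check_root import check_invokeai_root

app_config = InvokeAIAppConfig.get_config()
check_invokeai_root(app_config)  # prints the abort banner and exits if the root is incomplete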

View File

@@ -32,6 +32,7 @@ from omegaconf import OmegaConf
from tqdm import tqdm
from transformers import (
CLIPTextModel,
CLIPTextConfig,
CLIPTokenizer,
AutoFeatureExtractor,
BertTokenizerFast,
@@ -55,6 +56,7 @@ from invokeai.frontend.install.widgets import (
from invokeai.backend.install.legacy_arg_parsing import legacy_parser
from invokeai.backend.install.model_install_backend import (
hf_download_from_pretrained,
hf_download_with_resume,
InstallSelections,
ModelInstall,
)
@@ -204,6 +206,15 @@ def download_conversion_models():
pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'text_encoder', safe_serialization=True)
# sd-xl - tokenizer_2
repo_id = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
_, model_name = repo_id.split('/')
pipeline = CLIPTokenizer.from_pretrained(repo_id, **kwargs)
pipeline.save_pretrained(target_dir / model_name, safe_serialization=True)
pipeline = CLIPTextConfig.from_pretrained(repo_id, **kwargs)
pipeline.save_pretrained(target_dir / model_name, safe_serialization=True)
# VAE
logger.info('Downloading stable diffusion VAE')
vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', **kwargs)
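Once cached this way, the tokenizer and text config can be reloaded offline from the conversion cache; a sketch, assuming the models/core/convert layout checked by check_invokeai_root above:

from pathlib import Path
from transformers import CLIPTextConfig, CLIPTokenizer

cache_dir = Path("models/core/convert/CLIP-ViT-bigG-14-laion2B-39B-b160k")  # assumed location
tokenizer = CLIPTokenizer.from_pretrained(cache_dir)
text_config = CLIPTextConfig.from_pretrained(cache_dir)
print(len(tokenizer), text_config.hidden_size)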

View File

@@ -58,7 +58,15 @@ LEGACY_CONFIGS = {
SchedulerPredictionType.Epsilon: 'v2-inpainting-inference.yaml',
SchedulerPredictionType.VPrediction: 'v2-inpainting-inference-v.yaml',
}
}
},
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: 'sd_xl_base.yaml',
},
BaseModelType.StableDiffusionXLRefiner: {
ModelVariantType.Normal: 'sd_xl_refiner.yaml',
},
}
@dataclass
@@ -329,6 +337,7 @@ class ModelInstall(object):
description = str(description),
model_format = info.format,
)
legacy_conf = None
if info.model_type == ModelType.Main:
attributes.update(dict(variant = info.variant_type,))
if info.format=="checkpoint":
@@ -343,11 +352,17 @@ class ModelInstall(object):
except KeyError:
legacy_conf = Path(self.config.legacy_conf_dir, 'v1-inference.yaml') # best guess
attributes.update(
dict(
config = str(legacy_conf)
)
if info.model_type == ModelType.ControlNet and info.format=="checkpoint":
possible_conf = path.with_suffix('.yaml')
if possible_conf.exists():
legacy_conf = str(self.relative_to_root(possible_conf))
if legacy_conf:
attributes.update(
dict(
config = str(legacy_conf)
)
)
return attributes
def relative_to_root(self, path: Path)->Path:
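The new ControlNet branch simply looks for a .yaml sidecar next to the checkpoint; a standalone sketch of that rule (the file name is hypothetical):

from pathlib import Path

weights = Path("controlnet/control_v11p_sd15_canny.safetensors")  # hypothetical checkpoint
possible_conf = weights.with_suffix(".yaml")
legacy_conf = str(possible_conf) if possible_conf.exists() else None  # ends up under attributes["config"]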

File diff suppressed because it is too large

View File

@@ -673,6 +673,7 @@ class ModelManager(object):
self.models[model_key] = model_config
self.commit()
return AddModelResult(
name = model_name,
model_type = model_type,
@@ -840,7 +841,7 @@ class ModelManager(object):
Returns the preamble for the config file.
"""
return textwrap.dedent(
"""\
"""
# This file describes the alternative machine learning models
# available to InvokeAI script.
#

View File

@@ -253,10 +253,13 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
return BaseModelType.StableDiffusion1
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2
# TODO: Verify that this is correct! Need an XL checkpoint file for this.
key_name = 'model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight'
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
return BaseModelType.StableDiffusionXL
raise InvalidModelException("Cannot determine base type")
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
return BaseModelType.StableDiffusionXLRefiner
else:
raise InvalidModelException("Cannot determine base type")
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
type = self.get_base_type()
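The probe distinguishes the new bases purely by the width of one cross-attention projection; the same check can be run standalone against a checkpoint, e.g. (the file name is hypothetical, and safetensors is just one way to obtain the state dict):

from safetensors.torch import load_file

state_dict = load_file("sd_xl_base_1.0.safetensors")  # hypothetical local file
key = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
width = state_dict[key].shape[-1] if key in state_dict else None
print({2048: "sdxl", 1280: "sdxl-refiner"}.get(width, "unknown"))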

View File

@@ -1,7 +1,8 @@
import os
import torch
from enum import Enum
from typing import Optional
from pathlib import Path
from typing import Optional, Literal
from .base import (
ModelBase,
ModelConfigBase,
@@ -15,6 +16,7 @@ from .base import (
InvalidModelException,
ModelNotFoundException,
)
from invokeai.app.services.config import InvokeAIAppConfig
class ControlNetModelFormat(str, Enum):
Checkpoint = "checkpoint"
@@ -24,8 +26,12 @@ class ControlNetModel(ModelBase):
#model_class: Type
#model_size: int
class Config(ModelConfigBase):
model_format: ControlNetModelFormat
class DiffusersConfig(ModelConfigBase):
model_format: Literal[ControlNetModelFormat.Diffusers]
class CheckpointConfig(ModelConfigBase):
model_format: Literal[ControlNetModelFormat.Checkpoint]
config: str
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.ControlNet
@@ -99,13 +105,51 @@ class ControlNetModel(ModelBase):
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) == ControlNetModelFormat.Checkpoint:
return _convert_controlnet_ckpt_and_cache(
model_path = model_path,
model_config = config.config,
output_path = output_path,
base_model = base_model,
)
else:
return model_path
@classmethod
def _convert_controlnet_ckpt_and_cache(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase, # empty config or config of parent model
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) != ControlNetModelFormat.Diffusers:
raise NotImplementedError("Checkpoint controlnet models currently unsupported")
else:
return model_path
model_config: ControlNetModel.CheckpointConfig,
) -> str:
"""
Convert the controlnet from checkpoint format to diffusers format,
cache it to disk, and return Path to converted
file. If already on disk then just returns Path.
"""
app_config = InvokeAIAppConfig.get_config()
weights = app_config.root_path / model_path
output_path = Path(output_path)
# return cached version if it exists
if output_path.exists():
return output_path
# to avoid circular import errors
from ..convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
convert_controlnet_to_diffusers(
weights,
output_path,
original_config_file = app_config.root_path / model_config,
image_size = 512,
scan_needed = True,
from_safetensors = weights.suffix == ".safetensors"
)
return output_path
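A hypothetical invocation of the converter above; the output directory acts as the cache key, so a second call with the same output_path returns the already-converted folder (import paths follow the package layout referenced elsewhere in this diff, and the paths are made up):

from invokeai.backend.model_management.models.base import BaseModelType
from invokeai.backend.model_management.models.controlnet import ControlNetModel

# checkpoint_config would be the CheckpointConfig record the model manager already holds
converted = ControlNetModel.convert_if_required(
    model_path="controlnet/canny.ckpt",               # hypothetical checkpoint path
    output_path="models/converted/controlnet-canny",  # hypothetical cache location
    config=checkpoint_config,
    base_model=BaseModelType.StableDiffusion1,
)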

View File

@@ -1,5 +1,6 @@
import os
import json
import invokeai.backend.util.logging as logger
from enum import Enum
from pydantic import Field
from typing import Literal, Optional
@@ -48,7 +49,7 @@ class StableDiffusionXLModel(DiffusersModel):
if model_format == StableDiffusionXLModelFormat.Checkpoint:
if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path)
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
else:
checkpoint = read_checkpoint_meta(path)
@@ -108,7 +109,20 @@ class StableDiffusionXLModel(DiffusersModel):
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
# The convert script adapted from the diffusers package uses
# strings for the base model type. To avoid making too many
# source code changes, we simply translate here
model_base_to_model_type = {BaseModelType.StableDiffusionXL: 'SDXL',
BaseModelType.StableDiffusionXLRefiner: 'SDXL-Refiner',
}
if isinstance(config, cls.CheckpointConfig):
raise NotImplementedError('conversion of SDXL checkpoint models to diffusers format is not yet supported')
from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache
return _convert_ckpt_and_cache(
version=base_model,
model_config=config,
output_path=output_path,
model_type=model_base_to_model_type[base_model],
use_safetensors=False, # corrupts sdxl models for some reason
)
else:
return model_path
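Once converted, the cached SDXL output is an ordinary diffusers pipeline directory and can be loaded directly; a sketch, assuming a diffusers release with SDXL support (the path is made up):

from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained("models/converted/sd_xl_base")  # hypothetical path
image = pipe("a lighthouse at dusk", num_inference_steps=30).images[0]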

View File

@@ -15,9 +15,12 @@ from .base import (
classproperty,
InvalidModelException,
)
from .sdxl import StableDiffusionXLModel
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from omegaconf import OmegaConf
class StableDiffusion1ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
@@ -235,42 +238,17 @@ class StableDiffusion2Model(DiffusersModel):
else:
return model_path
def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
ckpt_configs = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: "v1-inference.yaml",
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
},
BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: "v2-inference-v.yaml", # best guess, as we can't differentiate with base(512)
ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
ModelVariantType.Depth: "v2-midas-inference.yaml",
},
# note that these .yaml files don't yet exist!
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: "xl-inference-v.yaml",
ModelVariantType.Inpaint: "xl-inpainting-inference.yaml",
ModelVariantType.Depth: "xl-midas-inference.yaml",
}
}
app_config = InvokeAIAppConfig.get_config()
try:
config_path = app_config.legacy_conf_path / ckpt_configs[version][variant]
if config_path.is_relative_to(app_config.root_path):
config_path = config_path.relative_to(app_config.root_path)
return str(config_path)
except:
return None
# TODO: rework
# Note that convert_ckpt_to_diffusers does not currently support conversion of SDXL models
# pass precision - currently defaulting to fp16
def _convert_ckpt_and_cache(
version: BaseModelType,
model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig],
output_path: str,
version: BaseModelType,
model_config: Union[StableDiffusion1Model.CheckpointConfig,
StableDiffusion2Model.CheckpointConfig,
StableDiffusionXLModel.CheckpointConfig,
],
output_path: str,
use_save_model: bool=False,
**kwargs,
) -> str:
"""
Convert the checkpoint model indicated in mconfig into a
@@ -289,6 +267,9 @@ def _convert_ckpt_and_cache(
# to avoid circular import errors
from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
from ...util.devices import choose_torch_device, torch_dtype
logger.info(f'Converting {weights} to diffusers format')
with SilenceWarnings():
convert_ckpt_to_diffusers(
weights,
@@ -298,5 +279,43 @@ def _convert_ckpt_and_cache(
original_config_file=config_file,
extract_ema=True,
scan_needed=True,
from_safetensors = weights.suffix == ".safetensors",
precision = torch_dtype(choose_torch_device()),
**kwargs,
)
return output_path
def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
    ckpt_configs = {
        BaseModelType.StableDiffusion1: {
            ModelVariantType.Normal: "v1-inference.yaml",
            ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
        },
        BaseModelType.StableDiffusion2: {
            ModelVariantType.Normal: "v2-inference-v.yaml",  # best guess, as we can't differentiate with base(512)
            ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
            ModelVariantType.Depth: "v2-midas-inference.yaml",
        },
        BaseModelType.StableDiffusionXL: {
            ModelVariantType.Normal: "sd_xl_base.yaml",
            ModelVariantType.Inpaint: None,
            ModelVariantType.Depth: None,
        },
        BaseModelType.StableDiffusionXLRefiner: {
            ModelVariantType.Normal: "sd_xl_refiner.yaml",
            ModelVariantType.Inpaint: None,
            ModelVariantType.Depth: None,
        },
    }

    app_config = InvokeAIAppConfig.get_config()
    try:
        config_path = app_config.legacy_conf_path / ckpt_configs[version][variant]
        if config_path.is_relative_to(app_config.root_path):
            config_path = config_path.relative_to(app_config.root_path)
        return str(config_path)
    except:
        return None
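Illustrative call of the selector above, using the names defined in this module:

conf = _select_ckpt_config(BaseModelType.StableDiffusionXL, ModelVariantType.Normal)
print(conf)  # e.g. a root-relative path ending in "sd_xl_base.yaml", or None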

View File

@@ -1,7 +1,7 @@
# Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team
"""
invokeai.util.logging
invokeai.backend.util.logging
Logging class for InvokeAI that produces console messages

View File

@@ -0,0 +1,98 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        adm_in_channels: 2816
        num_classes: sequential
        use_checkpoint: True
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2]
        num_res_blocks: 2
        channel_mult: [1, 2, 4]
        num_head_channels: 64
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: [1, 2, 10]  # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
        context_dim: 2048
        spatial_transformer_attn_type: softmax-xformers
        legacy: False

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          # crossattn cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
            params:
              layer: hidden
              layer_idx: 11
          # crossattn and vector cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
            params:
              arch: ViT-bigG-14
              version: laion2b_s39b_b160k
              freeze: True
              layer: penultimate
              always_return_pooled: True
              legacy: False
          # vector cond
          - is_trainable: False
            input_key: original_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: crop_coords_top_left
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: target_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
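The new config can be loaded with OmegaConf, just as the StableDiffusionXLModel hunk above reads ckpt_config; a sketch, with the install path assumed:

from omegaconf import OmegaConf

ckpt_config = OmegaConf.load("configs/stable-diffusion/sd_xl_base.yaml")  # assumed location
unet = ckpt_config["model"]["params"]["network_config"]["params"]
print(unet["in_channels"], unet["context_dim"])  # 4, 2048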

View File

@@ -0,0 +1,91 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        adm_in_channels: 2560
        num_classes: sequential
        use_checkpoint: True
        in_channels: 4
        out_channels: 4
        model_channels: 384
        attention_resolutions: [4, 2]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 4
        context_dim: [1280, 1280, 1280, 1280]  # 1280
        spatial_transformer_attn_type: softmax-xformers
        legacy: False

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          # crossattn and vector cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
            params:
              arch: ViT-bigG-14
              version: laion2b_s39b_b160k
              legacy: False
              freeze: True
              layer: penultimate
              always_return_pooled: True
          # vector cond
          - is_trainable: False
            input_key: original_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: crop_coords_top_left
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: aesthetic_score
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by one

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

View File

@@ -1,6 +1,5 @@
import { Badge, Divider, Flex, Text } from '@chakra-ui/react';
import { useForm } from '@mantine/form';
import { makeToast } from 'features/system/util/makeToast';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
@@ -8,6 +7,7 @@ import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import {
@@ -115,7 +115,7 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
{MODEL_TYPE_MAP[model.base_model]} Model
</Text>
</Flex>
{!['sdxl', 'sdxl-refiner'].includes(model.base_model) ? (
{![''].includes(model.base_model) ? (
<ModelConvert model={model} />
) : (
<Badge