Add support for controlnet & sdxl checkpoint conversion (#3905)

## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ X] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [X ] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [ ] Yes
- [ X] No - not yet WIP


## Description

This PR adds support for loading and converting checkpoint-format
ControlNet and SDXL models. The SDXL and SDXL-refiner model conversions
are working; however saving the unet in safetensors format leads to
corrupted model files, so currently is saving in .bin format (after
scanning the input model).

ControlNet conversion seems to be working but needs further testing.

To use this PR, you will need to copy the files
`invokeai/configs/stable-diffusion/sd_xl_base.yaml` and
`invokeai/configs/stable-diffusion/sd_xl_refiner.yaml` into
`INVOKEAI/configs/stable-diffusion`. You will also need to run
`invokeai-configure --yes --skip-sd` in order to install additional core
model files needed by the converter.
This commit is contained in:
blessedcoolant 2023-07-27 01:50:38 +12:00 committed by GitHub
commit 3dccc4d61e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 1580 additions and 685 deletions

View File

@ -203,7 +203,10 @@ def invoke_api():
return find_port(port=port + 1)
else:
return port
from invokeai.backend.install.check_root import check_invokeai_root
check_invokeai_root(app_config) # note, may exit with an exception if root not set up
port = find_port(app_config.port)
if port != app_config.port:
logger.warn(f"Port {app_config.port} in use, using port {port}")

View File

@ -0,0 +1,31 @@
"""
Check that the invokeai_root is correctly configured and exit if not.
"""
import sys
from invokeai.app.services.config import (
InvokeAIAppConfig,
)
def check_invokeai_root(config: InvokeAIAppConfig):
try:
assert config.model_conf_path.exists()
assert config.db_path.exists()
assert config.models_path.exists()
for model in [
'CLIP-ViT-bigG-14-laion2B-39B-b160k',
'bert-base-uncased',
'clip-vit-large-patch14',
'sd-vae-ft-mse',
'stable-diffusion-2-clip',
'stable-diffusion-safety-checker']:
assert (config.models_path / f'core/convert/{model}').exists()
except:
print()
print('== STARTUP ABORTED ==')
print('** One or more necessary files is missing from your InvokeAI root directory **')
print('** Please rerun the configuration script to fix this problem. **')
print('** From the launcher, selection option [7]. **')
print('** From the command line, activate the virtual environment and run "invokeai-configure --yes --skip-sd-weights" **')
input('Press any key to continue...')
sys.exit(0)

View File

@ -32,6 +32,7 @@ from omegaconf import OmegaConf
from tqdm import tqdm
from transformers import (
CLIPTextModel,
CLIPTextConfig,
CLIPTokenizer,
AutoFeatureExtractor,
BertTokenizerFast,
@ -55,6 +56,7 @@ from invokeai.frontend.install.widgets import (
from invokeai.backend.install.legacy_arg_parsing import legacy_parser
from invokeai.backend.install.model_install_backend import (
hf_download_from_pretrained,
hf_download_with_resume,
InstallSelections,
ModelInstall,
)
@ -204,6 +206,15 @@ def download_conversion_models():
pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'text_encoder', safe_serialization=True)
# sd-xl - tokenizer_2
repo_id = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
_, model_name = repo_id.split('/')
pipeline = CLIPTokenizer.from_pretrained(repo_id, **kwargs)
pipeline.save_pretrained(target_dir / model_name, safe_serialization=True)
pipeline = CLIPTextConfig.from_pretrained(repo_id, **kwargs)
pipeline.save_pretrained(target_dir / model_name, safe_serialization=True)
# VAE
logger.info('Downloading stable diffusion VAE')
vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', **kwargs)

View File

@ -58,7 +58,15 @@ LEGACY_CONFIGS = {
SchedulerPredictionType.Epsilon: 'v2-inpainting-inference.yaml',
SchedulerPredictionType.VPrediction: 'v2-inpainting-inference-v.yaml',
}
}
},
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: 'sd_xl_base.yaml',
},
BaseModelType.StableDiffusionXLRefiner: {
ModelVariantType.Normal: 'sd_xl_refiner.yaml',
},
}
@dataclass
@ -329,6 +337,7 @@ class ModelInstall(object):
description = str(description),
model_format = info.format,
)
legacy_conf = None
if info.model_type == ModelType.Main:
attributes.update(dict(variant = info.variant_type,))
if info.format=="checkpoint":
@ -343,11 +352,17 @@ class ModelInstall(object):
except KeyError:
legacy_conf = Path(self.config.legacy_conf_dir, 'v1-inference.yaml') # best guess
attributes.update(
dict(
config = str(legacy_conf)
)
if info.model_type == ModelType.ControlNet and info.format=="checkpoint":
possible_conf = path.with_suffix('.yaml')
if possible_conf.exists():
legacy_conf = str(self.relative_to_root(possible_conf))
if legacy_conf:
attributes.update(
dict(
config = str(legacy_conf)
)
)
return attributes
def relative_to_root(self, path: Path)->Path:

File diff suppressed because it is too large Load Diff

View File

@ -673,6 +673,7 @@ class ModelManager(object):
self.models[model_key] = model_config
self.commit()
return AddModelResult(
name = model_name,
model_type = model_type,
@ -840,7 +841,7 @@ class ModelManager(object):
Returns the preamble for the config file.
"""
return textwrap.dedent(
"""\
"""
# This file describes the alternative machine learning models
# available to InvokeAI script.
#

View File

@ -253,10 +253,13 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
return BaseModelType.StableDiffusion1
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2
# TODO: Verify that this is correct! Need an XL checkpoint file for this.
key_name = 'model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight'
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
return BaseModelType.StableDiffusionXL
raise InvalidModelException("Cannot determine base type")
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
return BaseModelType.StableDiffusionXLRefiner
else:
raise InvalidModelException("Cannot determine base type")
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
type = self.get_base_type()

View File

@ -1,7 +1,8 @@
import os
import torch
from enum import Enum
from typing import Optional
from pathlib import Path
from typing import Optional, Literal
from .base import (
ModelBase,
ModelConfigBase,
@ -15,6 +16,7 @@ from .base import (
InvalidModelException,
ModelNotFoundException,
)
from invokeai.app.services.config import InvokeAIAppConfig
class ControlNetModelFormat(str, Enum):
Checkpoint = "checkpoint"
@ -24,8 +26,12 @@ class ControlNetModel(ModelBase):
#model_class: Type
#model_size: int
class Config(ModelConfigBase):
model_format: ControlNetModelFormat
class DiffusersConfig(ModelConfigBase):
model_format: Literal[ControlNetModelFormat.Diffusers]
class CheckpointConfig(ModelConfigBase):
model_format: Literal[ControlNetModelFormat.Checkpoint]
config: str
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.ControlNet
@ -99,13 +105,51 @@ class ControlNetModel(ModelBase):
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) == ControlNetModelFormat.Checkpoint:
return _convert_controlnet_ckpt_and_cache(
model_path = model_path,
model_config = config.config,
output_path = output_path,
base_model = base_model,
)
else:
return model_path
@classmethod
def _convert_controlnet_ckpt_and_cache(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase, # empty config or config of parent model
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) != ControlNetModelFormat.Diffusers:
raise NotImplementedError("Checkpoint controlnet models currently unsupported")
else:
return model_path
model_config: ControlNetModel.CheckpointConfig,
) -> str:
"""
Convert the controlnet from checkpoint format to diffusers format,
cache it to disk, and return Path to converted
file. If already on disk then just returns Path.
"""
app_config = InvokeAIAppConfig.get_config()
weights = app_config.root_path / model_path
output_path = Path(output_path)
# return cached version if it exists
if output_path.exists():
return output_path
# to avoid circular import errors
from ..convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
convert_controlnet_to_diffusers(
weights,
output_path,
original_config_file = app_config.root_path / model_config,
image_size = 512,
scan_needed = True,
from_safetensors = weights.suffix == ".safetensors"
)
return output_path

View File

@ -1,5 +1,6 @@
import os
import json
import invokeai.backend.util.logging as logger
from enum import Enum
from pydantic import Field
from typing import Literal, Optional
@ -48,7 +49,7 @@ class StableDiffusionXLModel(DiffusersModel):
if model_format == StableDiffusionXLModelFormat.Checkpoint:
if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path)
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
else:
checkpoint = read_checkpoint_meta(path)
@ -108,7 +109,20 @@ class StableDiffusionXLModel(DiffusersModel):
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
# The convert script adapted from the diffusers package uses
# strings for the base model type. To avoid making too many
# source code changes, we simply translate here
model_base_to_model_type = {BaseModelType.StableDiffusionXL: 'SDXL',
BaseModelType.StableDiffusionXLRefiner: 'SDXL-Refiner',
}
if isinstance(config, cls.CheckpointConfig):
raise NotImplementedError('conversion of SDXL checkpoint models to diffusers format is not yet supported')
from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache
return _convert_ckpt_and_cache(
version=base_model,
model_config=config,
output_path=output_path,
model_type=model_base_to_model_type[base_model],
use_safetensors=False, # corrupts sdxl models for some reason
)
else:
return model_path

View File

@ -15,9 +15,12 @@ from .base import (
classproperty,
InvalidModelException,
)
from .sdxl import StableDiffusionXLModel
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from omegaconf import OmegaConf
class StableDiffusion1ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
@ -235,42 +238,17 @@ class StableDiffusion2Model(DiffusersModel):
else:
return model_path
def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
ckpt_configs = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: "v1-inference.yaml",
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
},
BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: "v2-inference-v.yaml", # best guess, as we can't differentiate with base(512)
ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
ModelVariantType.Depth: "v2-midas-inference.yaml",
},
# note that these .yaml files don't yet exist!
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: "xl-inference-v.yaml",
ModelVariantType.Inpaint: "xl-inpainting-inference.yaml",
ModelVariantType.Depth: "xl-midas-inference.yaml",
}
}
app_config = InvokeAIAppConfig.get_config()
try:
config_path = app_config.legacy_conf_path / ckpt_configs[version][variant]
if config_path.is_relative_to(app_config.root_path):
config_path = config_path.relative_to(app_config.root_path)
return str(config_path)
except:
return None
# TODO: rework
# Note that convert_ckpt_to_diffuses does not currently support conversion of SDXL models
# pass precision - currently defaulting to fp16
def _convert_ckpt_and_cache(
version: BaseModelType,
model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig],
output_path: str,
version: BaseModelType,
model_config: Union[StableDiffusion1Model.CheckpointConfig,
StableDiffusion2Model.CheckpointConfig,
StableDiffusionXLModel.CheckpointConfig,
],
output_path: str,
use_save_model: bool=False,
**kwargs,
) -> str:
"""
Convert the checkpoint model indicated in mconfig into a
@ -289,6 +267,9 @@ def _convert_ckpt_and_cache(
# to avoid circular import errors
from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
from ...util.devices import choose_torch_device, torch_dtype
logger.info(f'Converting {weights} to diffusers format')
with SilenceWarnings():
convert_ckpt_to_diffusers(
weights,
@ -298,5 +279,43 @@ def _convert_ckpt_and_cache(
original_config_file=config_file,
extract_ema=True,
scan_needed=True,
from_safetensors = weights.suffix == ".safetensors",
precision = torch_dtype(choose_torch_device()),
**kwargs,
)
return output_path
def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
ckpt_configs = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: "v1-inference.yaml",
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
},
BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: "v2-inference-v.yaml", # best guess, as we can't differentiate with base(512)
ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
ModelVariantType.Depth: "v2-midas-inference.yaml",
},
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: "sd_xl_base.yaml",
ModelVariantType.Inpaint: None,
ModelVariantType.Depth: None,
},
BaseModelType.StableDiffusionXLRefiner: {
ModelVariantType.Normal: "sd_xl_refiner.yaml",
ModelVariantType.Inpaint: None,
ModelVariantType.Depth: None,
},
}
app_config = InvokeAIAppConfig.get_config()
try:
config_path = app_config.legacy_conf_path / ckpt_configs[version][variant]
if config_path.is_relative_to(app_config.root_path):
config_path = config_path.relative_to(app_config.root_path)
return str(config_path)
except:
return None

View File

@ -1,7 +1,7 @@
# Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team
"""
invokeai.util.logging
invokeai.backend.util.logging
Logging class for InvokeAI that produces console messages

View File

@ -0,0 +1,98 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.13025
disable_first_stage_autocast: True
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
adm_in_channels: 2816
num_classes: sequential
use_checkpoint: True
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2]
num_res_blocks: 2
channel_mult: [1, 2, 4]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
context_dim: 2048
spatial_transformer_attn_type: softmax-xformers
legacy: False
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
params:
layer: hidden
layer_idx: 11
# crossattn and vector cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
params:
arch: ViT-bigG-14
version: laion2b_s39b_b160k
freeze: True
layer: penultimate
always_return_pooled: True
legacy: False
# vector cond
- is_trainable: False
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: crop_coords_top_left
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: target_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity

View File

@ -0,0 +1,91 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.13025
disable_first_stage_autocast: True
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
adm_in_channels: 2560
num_classes: sequential
use_checkpoint: True
in_channels: 4
out_channels: 4
model_channels: 384
attention_resolutions: [4, 2]
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 4
context_dim: [1280, 1280, 1280, 1280] # 1280
spatial_transformer_attn_type: softmax-xformers
legacy: False
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn and vector cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
params:
arch: ViT-bigG-14
version: laion2b_s39b_b160k
legacy: False
freeze: True
layer: penultimate
always_return_pooled: True
# vector cond
- is_trainable: False
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: crop_coords_top_left
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: aesthetic_score
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by one
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity

View File

@ -1,6 +1,5 @@
import { Badge, Divider, Flex, Text } from '@chakra-ui/react';
import { useForm } from '@mantine/form';
import { makeToast } from 'features/system/util/makeToast';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
@ -8,6 +7,7 @@ import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import {
@ -115,7 +115,7 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
{MODEL_TYPE_MAP[model.base_model]} Model
</Text>
</Flex>
{!['sdxl', 'sdxl-refiner'].includes(model.base_model) ? (
{![''].includes(model.base_model) ? (
<ModelConvert model={model} />
) : (
<Badge