Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit 5607794dbb (parent 907ff165be): add support for controlnet & sdxl conversion - not fully working
@@ -55,6 +55,7 @@ from invokeai.frontend.install.widgets import (
 from invokeai.backend.install.legacy_arg_parsing import legacy_parser
 from invokeai.backend.install.model_install_backend import (
     hf_download_from_pretrained,
+    hf_download_with_resume,
     InstallSelections,
     ModelInstall,
 )
@@ -204,6 +205,13 @@ def download_conversion_models():
         pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
         pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'text_encoder', safe_serialization=True)
 
+        repo_id = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
+        _, model_name = repo_id.split('/')
+        tokenizer_2 = CLIPTokenizer.from_pretrained(repo_id, **kwargs)
+        tokenizer_2.save_pretrained(target_dir / model_name, safe_serialization=True)
+        # for some reason config.json never downloads
+        hf_download_with_resume(repo_id, target_dir / model_name, "config.json")
+
         # VAE
         logger.info('Downloading stable diffusion VAE')
         vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', **kwargs)
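The added block grabs the OpenCLIP bigG tokenizer that SDXL conversion needs, then fetches config.json by hand: save_pretrained() on a tokenizer writes only tokenizer files, never the repo's model config. A stand-alone sketch of the same workaround, using huggingface_hub's hf_hub_download in place of InvokeAI's hf_download_with_resume helper (the target directory below is hypothetical):

    from pathlib import Path
    from huggingface_hub import hf_hub_download
    from transformers import CLIPTokenizer

    repo_id = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
    target_dir = Path("models/core/convert")      # hypothetical location
    model_dir = target_dir / repo_id.split("/")[1]

    # The tokenizer downloads and saves cleanly...
    tokenizer_2 = CLIPTokenizer.from_pretrained(repo_id)
    tokenizer_2.save_pretrained(model_dir)

    # ...but the repo's model config.json is not among the tokenizer files,
    # so it must be fetched explicitly, as the diff does with
    # hf_download_with_resume.
    hf_hub_download(repo_id=repo_id, filename="config.json", local_dir=model_dir)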
@@ -58,7 +58,15 @@ LEGACY_CONFIGS = {
             SchedulerPredictionType.Epsilon: 'v2-inpainting-inference.yaml',
             SchedulerPredictionType.VPrediction: 'v2-inpainting-inference-v.yaml',
         }
-    }
+    },
+
+    BaseModelType.StableDiffusionXL: {
+        ModelVariantType.Normal: 'sd_xl_base.yaml',
+    },
+
+    BaseModelType.StableDiffusionXLRefiner: {
+        ModelVariantType.Normal: 'sd_xl_refiner.yaml',
+    },
 }
 
 @dataclass
@@ -329,6 +337,7 @@ class ModelInstall(object):
             description = str(description),
             model_format = info.format,
         )
+        legacy_conf = None
         if info.model_type == ModelType.Main:
             attributes.update(dict(variant = info.variant_type,))
             if info.format=="checkpoint":
@@ -343,6 +352,12 @@ class ModelInstall(object):
                 except KeyError:
                     legacy_conf = Path(self.config.legacy_conf_dir, 'v1-inference.yaml') # best guess
+
+        if info.model_type == ModelType.ControlNet and info.format=="checkpoint":
+            possible_conf = path.with_suffix('.yaml')
+            if possible_conf.exists():
+                legacy_conf = str(self.relative_to_root(possible_conf))
+
+        if legacy_conf:
             attributes.update(
                 dict(
                     config = str(legacy_conf)
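The new ControlNet branch infers the legacy config by looking for a .yaml sitting next to the checkpoint. That sidecar convention is easy to reproduce in isolation; a minimal sketch (function and file names hypothetical):

    from pathlib import Path
    from typing import Optional

    def find_sidecar_config(checkpoint_path: Path) -> Optional[Path]:
        # Mirrors the diff's ControlNet branch: a checkpoint such as
        # control_foo.safetensors is paired with control_foo.yaml in the
        # same directory, if one exists.
        candidate = checkpoint_path.with_suffix(".yaml")
        return candidate if candidate.exists() else None

    # hypothetical usage:
    conf = find_sidecar_config(Path("models/controlnet/control_v11p_sd15_canny.safetensors"))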
(Diff for one file suppressed because it is too large.)
@@ -673,6 +673,7 @@ class ModelManager(object):
 
         self.models[model_key] = model_config
         self.commit()
+
         return AddModelResult(
             name = model_name,
             model_type = model_type,
@@ -840,7 +841,7 @@ class ModelManager(object):
         Returns the preamble for the config file.
         """
         return textwrap.dedent(
-            """\
+            """
             # This file describes the alternative machine learning models
             # available to InvokeAI script.
             #
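Worth noting: dropping the backslash after the opening triple quote changes the output. The `"""\` form suppresses the first newline, so the dedented preamble now begins with a blank line. A quick illustration:

    import textwrap

    with_backslash = textwrap.dedent(
        """\
        # header line
        """
    )
    without_backslash = textwrap.dedent(
        """
        # header line
        """
    )

    assert with_backslash.startswith("# header")  # backslash eats the leading newline
    assert without_backslash.startswith("\n")     # plain form keeps a leading blank line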
@@ -253,7 +253,8 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
             return BaseModelType.StableDiffusion1
         if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
             return BaseModelType.StableDiffusion2
-        # TODO: Verify that this is correct! Need an XL checkpoint file for this.
+        # TODO: This is just a guess based on N=1
+        key_name = 'model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight'
         if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
             return BaseModelType.StableDiffusionXL
         raise InvalidModelException("Cannot determine base type")
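The probe works because the last dimension of a cross-attention to_k weight equals the text-embedding width: 768 for SD 1.x, 1024 for SD 2.x, and 2048 for SDXL's concatenated encoders. A self-contained sketch of the same check; the SD-1/SD-2 key below is an assumption based on common SD checkpoint layouts, and the return values are illustrative:

    import torch

    def guess_base_model(state_dict: dict) -> str:
        # Cross-attention key projections map text embeddings into the UNet,
        # so their width reveals which text encoder the model expects.
        key = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        if key in state_dict:
            if state_dict[key].shape[-1] == 768:
                return "sd-1"
            if state_dict[key].shape[-1] == 1024:
                return "sd-2"
        key = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
        if key in state_dict and state_dict[key].shape[-1] == 2048:
            return "sdxl"
        raise ValueError("Cannot determine base type")

    # hypothetical usage:
    # sd = torch.load("model.ckpt", map_location="cpu")["state_dict"]
    # print(guess_base_model(sd))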
@@ -1,7 +1,8 @@
 import os
 import torch
 from enum import Enum
-from typing import Optional
+from pathlib import Path
+from typing import Optional, Literal
 from .base import (
     ModelBase,
     ModelConfigBase,
@@ -15,6 +16,7 @@ from .base import (
     InvalidModelException,
     ModelNotFoundException,
 )
+from invokeai.app.services.config import InvokeAIAppConfig
 
 class ControlNetModelFormat(str, Enum):
     Checkpoint = "checkpoint"
@@ -24,8 +26,12 @@ class ControlNetModel(ModelBase):
     #model_class: Type
     #model_size: int
 
-    class Config(ModelConfigBase):
-        model_format: ControlNetModelFormat
+    class DiffusersConfig(ModelConfigBase):
+        model_format: Literal[ControlNetModelFormat.Diffusers]
+
+    class CheckpointConfig(ModelConfigBase):
+        model_format: Literal[ControlNetModelFormat.Checkpoint]
+        config: str
 
     def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
         assert model_type == ModelType.ControlNet
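Splitting the old Config into DiffusersConfig and CheckpointConfig with Literal-typed model_format fields makes the format a discriminator, so validation can select the right config class and require config: only for checkpoints. A minimal sketch of the pattern using pydantic v2's TypeAdapter (class names here are illustrative; InvokeAI's ModelConfigBase is its own base class):

    from enum import Enum
    from typing import Literal, Union
    from pydantic import BaseModel, TypeAdapter

    class Format(str, Enum):
        Checkpoint = "checkpoint"
        Diffusers = "diffusers"

    class DiffusersCfg(BaseModel):
        model_format: Literal[Format.Diffusers]

    class CheckpointCfg(BaseModel):
        model_format: Literal[Format.Checkpoint]
        config: str  # legacy .yaml path, required only for checkpoint models

    # The Literal fields let validation dispatch on model_format.
    adapter = TypeAdapter(Union[DiffusersCfg, CheckpointCfg])
    cfg = adapter.validate_python({"model_format": Format.Checkpoint, "config": "foo.yaml"})
    assert isinstance(cfg, CheckpointCfg)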
@@ -102,10 +108,48 @@ class ControlNetModel(ModelBase):
         cls,
         model_path: str,
         output_path: str,
-        config: ModelConfigBase,  # empty config or config of parent model
+        config: ModelConfigBase,
         base_model: BaseModelType,
     ) -> str:
-        if cls.detect_format(model_path) != ControlNetModelFormat.Diffusers:
-            raise NotImplementedError("Checkpoint controlnet models currently unsupported")
+        if cls.detect_format(model_path) == ControlNetModelFormat.Checkpoint:
+            return _convert_controlnet_ckpt_and_cache(
+                model_path = model_path,
+                model_config = config.config,
+                output_path = output_path,
+                base_model = base_model,
+            )
         else:
             return model_path
+
+    @classmethod
+    def _convert_controlnet_ckpt_and_cache(
+        cls,
+        model_path: str,
+        output_path: str,
+        base_model: BaseModelType,
+        model_config: ControlNetModel.CheckpointConfig,
+    ) -> str:
+        """
+        Convert the controlnet from checkpoint format to diffusers format,
+        cache it to disk, and return Path to converted file.
+        If already on disk then just returns Path.
+        """
+        app_config = InvokeAIAppConfig.get_config()
+        weights = app_config.root_path / model_path
+        output_path = Path(output_path)
+
+        # return cached version if it exists
+        if output_path.exists():
+            return output_path
+
+        # to avoid circular import errors
+        from ..convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
+        convert_controlnet_to_diffusers(
+            weights,
+            output_path,
+            original_config_file = app_config.root_path / model_config,
+            image_size = 512,
+            scan_needed = True,
+            from_safetensors = weights.suffix == ".safetensors"
+        )
+        return output_path
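The new helper follows a convert-once-and-cache contract: if the diffusers-format output directory already exists it is returned untouched, otherwise the expensive conversion runs exactly once. Stripped of InvokeAI specifics, a sketch of that contract (the convert argument is a stand-in for convert_controlnet_to_diffusers):

    from pathlib import Path
    from typing import Callable

    def convert_and_cache(weights: Path, output_dir: Path,
                          convert: Callable[[Path, Path], None]) -> Path:
        if output_dir.exists():
            return output_dir                 # cached from an earlier run
        output_dir.parent.mkdir(parents=True, exist_ok=True)
        convert(weights, output_dir)          # expensive step, runs at most once
        return output_dir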
@@ -48,7 +48,7 @@ class StableDiffusionXLModel(DiffusersModel):
         if model_format == StableDiffusionXLModelFormat.Checkpoint:
             if ckpt_config_path:
                 ckpt_config = OmegaConf.load(ckpt_config_path)
-                ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
+                in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
 
             else:
                 checkpoint = read_checkpoint_meta(path)
@@ -109,6 +109,13 @@ class StableDiffusionXLModel(DiffusersModel):
         base_model: BaseModelType,
     ) -> str:
         if isinstance(config, cls.CheckpointConfig):
-            raise NotImplementedError('conversion of SDXL checkpoint models to diffusers format is not yet supported')
+            from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache
+            return _convert_ckpt_and_cache(
+                version=base_model,
+                model_config=config,
+                output_path=output_path,
+                model_type='SDXL',
+                no_safetensors=True,  # giving errors for some reason
+            )
         else:
             return model_path
@@ -15,6 +15,7 @@ from .base import (
     classproperty,
     InvalidModelException,
 )
+from .sdxl import StableDiffusionXLModel
 from invokeai.app.services.config import InvokeAIAppConfig
 from omegaconf import OmegaConf
 
@@ -235,42 +236,16 @@ class StableDiffusion2Model(DiffusersModel):
         else:
             return model_path
 
-def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
-    ckpt_configs = {
-        BaseModelType.StableDiffusion1: {
-            ModelVariantType.Normal: "v1-inference.yaml",
-            ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
-        },
-        BaseModelType.StableDiffusion2: {
-            ModelVariantType.Normal: "v2-inference-v.yaml", # best guess, as we can't differentiate with base(512)
-            ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
-            ModelVariantType.Depth: "v2-midas-inference.yaml",
-        },
-        # note that these .yaml files don't yet exist!
-        BaseModelType.StableDiffusionXL: {
-            ModelVariantType.Normal: "xl-inference-v.yaml",
-            ModelVariantType.Inpaint: "xl-inpainting-inference.yaml",
-            ModelVariantType.Depth: "xl-midas-inference.yaml",
-        }
-    }
-
-    app_config = InvokeAIAppConfig.get_config()
-    try:
-        config_path = app_config.legacy_conf_path / ckpt_configs[version][variant]
-        if config_path.is_relative_to(app_config.root_path):
-            config_path = config_path.relative_to(app_config.root_path)
-        return str(config_path)
-
-    except:
-        return None
-
-
 # TODO: rework
-# Note that convert_ckpt_to_diffuses does not currently support conversion of SDXL models
 def _convert_ckpt_and_cache(
     version: BaseModelType,
-    model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig],
+    model_config: Union[StableDiffusion1Model.CheckpointConfig,
+                        StableDiffusion2Model.CheckpointConfig,
+                        StableDiffusionXLModel.CheckpointConfig,
+                        ],
     output_path: str,
+    use_save_model: bool=False,
+    **kwargs,
 ) -> str:
     """
     Convert the checkpoint model indicated in mconfig into a
@@ -298,5 +273,42 @@ def _convert_ckpt_and_cache(
         original_config_file=config_file,
         extract_ema=True,
         scan_needed=True,
+        from_safetensors = weights.suffix == ".safetensors",
+        **kwargs,
     )
     return output_path
+
+
+def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
+    ckpt_configs = {
+        BaseModelType.StableDiffusion1: {
+            ModelVariantType.Normal: "v1-inference.yaml",
+            ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
+        },
+        BaseModelType.StableDiffusion2: {
+            ModelVariantType.Normal: "v2-inference-v.yaml", # best guess, as we can't differentiate with base(512)
+            ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
+            ModelVariantType.Depth: "v2-midas-inference.yaml",
+        },
+        BaseModelType.StableDiffusionXL: {
+            ModelVariantType.Normal: "sd_xl_base.yaml",
+            ModelVariantType.Inpaint: None,
+            ModelVariantType.Depth: None,
+        },
+        BaseModelType.StableDiffusionXLRefiner: {
+            ModelVariantType.Normal: "sd_xl_refiner.yaml",
+            ModelVariantType.Inpaint: None,
+            ModelVariantType.Depth: None,
+        },
+    }
+
+    app_config = InvokeAIAppConfig.get_config()
+    try:
+        config_path = app_config.legacy_conf_path / ckpt_configs[version][variant]
+        if config_path.is_relative_to(app_config.root_path):
+            config_path = config_path.relative_to(app_config.root_path)
+        return str(config_path)
+
+    except:
+        return None
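The relative_to logic at the end keeps stored paths portable: when the legacy_conf directory lives under the InvokeAI root, only the root-relative fragment is persisted. A small sketch of that behavior (paths hypothetical; Path.is_relative_to requires Python 3.9+):

    from pathlib import Path

    root = Path("/home/user/invokeai")               # hypothetical app root
    legacy_conf = root / "configs" / "stable-diffusion"

    config_path = legacy_conf / "sd_xl_base.yaml"
    if config_path.is_relative_to(root):
        config_path = config_path.relative_to(root)  # store a portable relative path

    print(config_path)  # configs/stable-diffusion/sd_xl_base.yaml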
invokeai/configs/stable-diffusion/sd_xl_base.yaml (new file, 98 lines)
@@ -0,0 +1,98 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        adm_in_channels: 2816
        num_classes: sequential
        use_checkpoint: True
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2]
        num_res_blocks: 2
        channel_mult: [1, 2, 4]
        num_head_channels: 64
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: [1, 2, 10]  # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
        context_dim: 2048
        spatial_transformer_attn_type: softmax-xformers
        legacy: False

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          # crossattn cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
            params:
              layer: hidden
              layer_idx: 11
          # crossattn and vector cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
            params:
              arch: ViT-bigG-14
              version: laion2b_s39b_b160k
              freeze: True
              layer: penultimate
              always_return_pooled: True
              legacy: False
          # vector cond
          - is_trainable: False
            input_key: original_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: crop_coords_top_left
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: target_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
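One caveat worth flagging: the probe change in sdxl.py reads ckpt_config["model"]["params"]["unet_config"]..., but this sgm-style file nests the UNet under network_config, so reading it takes a different path, which may be part of why the commit is flagged "not fully working". A quick sketch with OmegaConf (the file path assumes an InvokeAI checkout):

    from omegaconf import OmegaConf

    ckpt_config = OmegaConf.load("invokeai/configs/stable-diffusion/sd_xl_base.yaml")
    # sgm-style configs use network_config where v1/v2 configs use unet_config
    in_channels = ckpt_config["model"]["params"]["network_config"]["params"]["in_channels"]
    print(in_channels)  # 4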
invokeai/configs/stable-diffusion/sd_xl_refiner.yaml (new file, 91 lines)
@@ -0,0 +1,91 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        adm_in_channels: 2560
        num_classes: sequential
        use_checkpoint: True
        in_channels: 4
        out_channels: 4
        model_channels: 384
        attention_resolutions: [4, 2]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 4
        context_dim: [1280, 1280, 1280, 1280]  # 1280
        spatial_transformer_attn_type: softmax-xformers
        legacy: False

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          # crossattn and vector cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
            params:
              arch: ViT-bigG-14
              version: laion2b_s39b_b160k
              legacy: False
              freeze: True
              layer: penultimate
              always_return_pooled: True
          # vector cond
          - is_trainable: False
            input_key: original_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: crop_coords_top_left
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: aesthetic_score
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by one

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity