Fixes, add sd variant detection

This commit is contained in:
Sergey Borisov
2023-06-12 05:52:30 +03:00
parent 893f776f1d
commit 9fa78443de
5 changed files with 143 additions and 19 deletions

View File

@ -5,26 +5,30 @@ from .lora import LoRAModel
#from .controlnet import ControlNetModel # TODO:
from .textual_inversion import TextualInversionModel
# TODO:
class ControlNetModel:
pass
MODEL_CLASSES = {
BaseModelType.StableDiffusion1_5: {
ModelType.Pipeline: StableDiffusion15Model,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
#ModelType.ControlNet: ControlNetModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
BaseModelType.StableDiffusion2: {
ModelType.Pipeline: StableDiffusion2Model,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
#ModelType.ControlNet: ControlNetModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
BaseModelType.StableDiffusion2Base: {
ModelType.Pipeline: StableDiffusion2BaseModel,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
#ModelType.ControlNet: ControlNetModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
#BaseModelType.Kandinsky2_1: {
@ -35,3 +39,11 @@ MODEL_CLASSES = {
# ModelType.TextualInversion: TextualInversionModel,
#},
}
# TODO: check with openapi annotation
def get_all_model_configs():
configs = []
for models in MODEL_CLASSES.values():
for model in models.values():
configs.extend(model._get_configs())
return configs

View File

@ -1,5 +1,7 @@
import os
import json
import torch
import safetensors.torch
from pydantic import Field
from typing import Literal, Optional
from .base import (
@ -8,10 +10,13 @@ from .base import (
BaseModelType,
ModelType,
SubModelType,
VariantType,
DiffusersModel,
)
from invokeai.app.services.config import InvokeAIAppConfig
ModelVariantType = VariantType # TODO:
# TODO: how to name properly
class StableDiffusion15Model(DiffusersModel):
@ -20,11 +25,13 @@ class StableDiffusion15Model(DiffusersModel):
class DiffusersConfig(ModelConfigBase):
format: Literal["diffusers"]
vae: Optional[str] = Field(None)
variant: ModelVariantType
class CheckpointConfig(ModelConfigBase):
format: Literal["checkpoint"]
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
variant: ModelVariantType
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
@ -36,6 +43,86 @@ class StableDiffusion15Model(DiffusersModel):
model_type=ModelType.Pipeline,
)
@staticmethod
def _fast_safetensors_reader(path: str):
checkpoint = dict()
device = torch.device("meta")
with open(path, "rb") as f:
definition_len = int.from_bytes(f.read(8), 'little')
definition_json = f.read(definition_len)
definition = json.loads(definition_json)
if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {"pt", "torch", "pytorch"}:
raise Exception("Supported only pytorch safetensors files")
definition.pop("__metadata__", None)
for key, info in definition.items():
dtype = {
"I8": torch.int8,
"I16": torch.int16,
"I32": torch.int32,
"I64": torch.int64,
"F16": torch.float16,
"F32": torch.float32,
"F64": torch.float64,
}[info["dtype"]]
checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)
return checkpoint
@classmethod
def read_checkpoint_meta(cls, path: str):
if path.endswith(".safetensors"):
try:
checkpoint = cls._fast_safetensors_reader(path)
except:
checkpoint = safetensors.torch.load_file(path, device="cpu") # TODO: create issue for support "meta"?
else:
checkpoint = torch.load(path, map_location=torch.device("meta"))
return checkpoint
@classmethod
def build_config(cls, **kwargs):
if "format" not in kwargs:
kwargs["format"] = cls.detect_format(kwargs["path"])
if "variant" not in kwargs:
if kwargs["format"] == "checkpoint":
if "config" in kwargs:
ckpt_config = OmegaConf.load(kwargs["config"])
in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
else:
checkpoint = cls.read_checkpoint_meta(kwargs["path"])
checkpoint = checkpoint.get('state_dict', checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif kwargs["format"] == "diffusers":
unet_config_path = os.path.join(kwargs["path"], "unet", "config.json")
if os.path.exists(unet_config_path):
unet_config = json.loads(unet_config_path)
in_channels = unet_config['in_channels']
else:
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
else:
raise NotImplementedError(f"Unknown stable diffusion format: {kwargs['format']}")
if in_channels == 9:
kwargs["variant"] = ModelVariantType.Inpaint
elif in_channels == 5:
kwargs["variant"] = ModelVariantType.Depth
elif in_channels == 4:
kwargs["variant"] = ModelVariantType.Normal
else:
raise Exception("Unkown stable diffusion model format")
return super().build_config(**kwargs)
@classmethod
def save_to_config(cls) -> bool:
return True