Merge branch 'main' into fix/controlnet_cfg_inj_cond

This commit is contained in:
Lincoln Stein 2023-06-28 15:44:50 -04:00 committed by GitHub
commit 20fbe81395
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
40 changed files with 2652 additions and 2043 deletions

View File

@ -87,16 +87,16 @@ Prior to installing PyPatchMatch, you need to take the following steps:
sudo pacman -S --needed base-devel sudo pacman -S --needed base-devel
``` ```
2. Install `opencv`: 2. Install `opencv` and `blas`:
```sh ```sh
sudo pacman -S opencv sudo pacman -S opencv blas
``` ```
or for CUDA support or for CUDA support
```sh ```sh
sudo pacman -S opencv-cuda sudo pacman -S opencv-cuda blas
``` ```
3. Fix the naming of the `opencv` package configuration file: 3. Fix the naming of the `opencv` package configuration file:

View File

@ -1,13 +1,13 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername) # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
from typing import Annotated, Literal, Optional, Union, Dict from typing import Literal, Optional, Union
from fastapi import Query from fastapi import Query
from fastapi.routing import APIRouter, HTTPException from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
from invokeai.backend import BaseModelType, ModelType from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)] MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
models_router = APIRouter(prefix="/v1/models", tags=["models"]) models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -51,12 +51,15 @@ class CreateModelResponse(BaseModel):
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info") info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
status: str = Field(description="The status of the API response") status: str = Field(description="The status of the API response")
class ImportModelRequest(BaseModel):
name: str = Field(description="A model path, repo_id or URL to import")
prediction_type: Optional[Literal['epsilon','v_prediction','sample']] = Field(description='Prediction type for SDv2 checkpoint files')
class ConversionRequest(BaseModel): class ConversionRequest(BaseModel):
name: str = Field(description="The name of the new model") name: str = Field(description="The name of the new model")
info: CkptModelInfo = Field(description="The converted model info") info: CkptModelInfo = Field(description="The converted model info")
save_location: str = Field(description="The path to save the converted model weights") save_location: str = Field(description="The path to save the converted model weights")
class ConvertedModelResponse(BaseModel): class ConvertedModelResponse(BaseModel):
name: str = Field(description="The name of the new model") name: str = Field(description="The name of the new model")
info: DiffusersModelInfo = Field(description="The converted model info") info: DiffusersModelInfo = Field(description="The converted model info")
@ -105,6 +108,28 @@ async def update_model(
return model_response return model_response
@models_router.post(
"/",
operation_id="import_model",
responses={200: {"status": "success"}},
)
async def import_model(
model_request: ImportModelRequest
) -> None:
""" Add Model """
items_to_import = set([model_request.name])
prediction_types = { x.value: x for x in SchedulerPredictionType }
logger = ApiDependencies.invoker.services.logger
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
items_to_import = items_to_import,
prediction_type_helper = lambda x: prediction_types.get(model_request.prediction_type)
)
if len(installed_models) > 0:
logger.info(f'Successfully imported {model_request.name}')
else:
logger.error(f'Model {model_request.name} not imported')
raise HTTPException(status_code=500, detail=f'Model {model_request.name} not imported')
@models_router.delete( @models_router.delete(
"/{model_name}", "/{model_name}",

View File

@ -1,10 +1,11 @@
# InvokeAI nodes for ControlNet image preprocessors # Invocations for ControlNet image preprocessors
# initial implementation by Gregg Helt, 2023 # initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import float, bool from builtins import float, bool
import cv2
import numpy as np import numpy as np
from typing import Literal, Optional, Union, List from typing import Literal, Optional, Union, List, Dict
from PIL import Image, ImageFilter, ImageOps from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, validator
@ -29,8 +30,13 @@ from controlnet_aux import (
ContentShuffleDetector, ContentShuffleDetector,
ZoeDetector, ZoeDetector,
MediapipeFaceDetector, MediapipeFaceDetector,
SamDetector,
LeresDetector,
) )
from controlnet_aux.util import HWC3, ade_palette
from .image import ImageOutput, PILInvocationConfig from .image import ImageOutput, PILInvocationConfig
CONTROLNET_DEFAULT_MODELS = [ CONTROLNET_DEFAULT_MODELS = [
@ -95,6 +101,9 @@ CONTROLNET_DEFAULT_MODELS = [
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)] CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])] CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]
# crop and fill options not ready yet
# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]
class ControlField(BaseModel): class ControlField(BaseModel):
image: ImageField = Field(default=None, description="The control image") image: ImageField = Field(default=None, description="The control image")
@ -105,7 +114,8 @@ class ControlField(BaseModel):
description="When the ControlNet is first applied (% of total steps)") description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1, end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)") description="When the ControlNet is last applied (% of total steps)")
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The contorl mode to use") control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
# resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@validator("control_weight") @validator("control_weight")
def abs_le_one(cls, v): def abs_le_one(cls, v):
@ -180,7 +190,7 @@ class ControlNetInvocation(BaseInvocation):
), ),
) )
# TODO: move image processors to separate file (image_analysis.py
class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig): class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
"""Base class for invocations that preprocess images for ControlNet""" """Base class for invocations that preprocess images for ControlNet"""
@ -452,6 +462,104 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
# fmt: on # fmt: on
def run_processor(self, image): def run_processor(self, image):
# MediaPipeFaceDetector throws an error if image has alpha channel
# so convert to RGB if needed
if image.mode == 'RGBA':
image = image.convert('RGB')
mediapipe_face_processor = MediapipeFaceDetector() mediapipe_face_processor = MediapipeFaceDetector()
processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence) processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
return processed_image return processed_image
class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies leres processing to image"""
# fmt: off
type: Literal["leres_image_processor"] = "leres_image_processor"
# Inputs
thr_a: float = Field(default=0, description="Leres parameter `thr_a`")
thr_b: float = Field(default=0, description="Leres parameter `thr_b`")
boost: bool = Field(default=False, description="Whether to use boost mode")
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
# fmt: on
def run_processor(self, image):
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
processed_image = leres_processor(image,
thr_a=self.thr_a,
thr_b=self.thr_b,
boost=self.boost,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution)
return processed_image
class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
# fmt: off
type: Literal["tile_image_processor"] = "tile_image_processor"
# Inputs
#res: int = Field(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
down_sampling_rate: float = Field(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
# fmt: on
# tile_resample copied from sd-webui-controlnet/scripts/processor.py
def tile_resample(self,
np_img: np.ndarray,
res=512, # never used?
down_sampling_rate=1.0,
):
np_img = HWC3(np_img)
if down_sampling_rate < 1.1:
return np_img
H, W, C = np_img.shape
H = int(float(H) / float(down_sampling_rate))
W = int(float(W) / float(down_sampling_rate))
np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
return np_img
def run_processor(self, img):
np_img = np.array(img, dtype=np.uint8)
processed_np_image = self.tile_resample(np_img,
#res=self.tile_size,
down_sampling_rate=self.down_sampling_rate
)
processed_image = Image.fromarray(processed_np_image)
return processed_image
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies segment anything processing to image"""
# fmt: off
type: Literal["segment_anything_processor"] = "segment_anything_processor"
# fmt: on
def run_processor(self, image):
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
np_img = np.array(image, dtype=np.uint8)
processed_image = segment_anything_processor(np_img)
return processed_image
class SamDetectorReproducibleColors(SamDetector):
# overriding SamDetector.show_anns() method to use reproducible colors for segmentation image
# base class show_anns() method randomizes colors,
# which seems to also lead to non-reproducible image generation
# so using ADE20k color palette instead
def show_anns(self, anns: List[Dict]):
if len(anns) == 0:
return
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
h, w = anns[0]['segmentation'].shape
final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
palette = ade_palette()
for i, ann in enumerate(sorted_anns):
m = ann['segmentation']
img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8)
# doing modulo just in case number of annotated regions exceeds number of colors in palette
ann_color = palette[i % len(palette)]
img[:, :] = ann_color
final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
return np.array(final_img, dtype=np.uint8)

View File

@ -23,7 +23,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
PostprocessingSettings PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_torch_device, torch_dtype from ...backend.util.devices import torch_dtype
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from .baseinvocation import (BaseInvocation, BaseInvocationOutput, from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext) InvocationConfig, InvocationContext)
@ -59,31 +59,12 @@ def build_latents_output(latents_name: str, latents: torch.Tensor):
height=latents.size()[2] * 8, height=latents.size()[2] * 8,
) )
class NoiseOutput(BaseInvocationOutput):
"""Invocation noise output"""
#fmt: off
type: Literal["noise_output"] = "noise_output"
# Inputs
noise: LatentsField = Field(default=None, description="The output noise")
width: int = Field(description="The width of the noise in pixels")
height: int = Field(description="The height of the noise in pixels")
#fmt: on
def build_noise_output(latents_name: str, latents: torch.Tensor):
return NoiseOutput(
noise=LatentsField(latents_name=latents_name),
width=latents.size()[3] * 8,
height=latents.size()[2] * 8,
)
SAMPLER_NAME_VALUES = Literal[ SAMPLER_NAME_VALUES = Literal[
tuple(list(SCHEDULER_MAP.keys())) tuple(list(SCHEDULER_MAP.keys()))
] ]
def get_scheduler( def get_scheduler(
context: InvocationContext, context: InvocationContext,
scheduler_info: ModelInfo, scheduler_info: ModelInfo,
@ -105,62 +86,6 @@ def get_scheduler(
return scheduler return scheduler
def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_channels:int=4, use_mps_noise:bool=False, downsampling_factor:int = 8):
# limit noise to only the diffusion image channels, not the mask channels
input_channels = min(latent_channels, 4)
use_device = "cpu" if (use_mps_noise or device.type == "mps") else device
generator = torch.Generator(device=use_device).manual_seed(seed)
x = torch.randn(
[
1,
input_channels,
height // downsampling_factor,
width // downsampling_factor,
],
dtype=torch_dtype(device),
device=use_device,
generator=generator,
).to(device)
# if self.perlin > 0.0:
# perlin_noise = self.get_perlin_noise(
# width // self.downsampling_factor, height // self.downsampling_factor
# )
# x = (1 - self.perlin) * x + self.perlin * perlin_noise
return x
class NoiseInvocation(BaseInvocation):
"""Generates latent noise."""
type: Literal["noise"] = "noise"
# Inputs
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use", default_factory=get_random_seed)
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting noise", )
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting noise", )
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents", "noise"],
},
}
@validator("seed", pre=True)
def modulo_seed(cls, v):
"""Returns the seed modulo SEED_MAX to ensure it is within the valid range."""
return v % SEED_MAX
def invoke(self, context: InvocationContext) -> NoiseOutput:
device = torch.device(choose_torch_device())
noise = get_noise(self.width, self.height, device, self.seed)
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.save(name, noise)
return build_noise_output(latents_name=name, latents=noise)
# Text to image # Text to image
class TextToLatentsInvocation(BaseInvocation): class TextToLatentsInvocation(BaseInvocation):
"""Generates latents from conditionings.""" """Generates latents from conditionings."""

View File

@ -73,7 +73,7 @@ class PipelineModelLoaderInvocation(BaseInvocation):
base_model = self.model.base_model base_model = self.model.base_model
model_name = self.model.model_name model_name = self.model.model_name
model_type = ModelType.Pipeline model_type = ModelType.Main
# TODO: not found exceptions # TODO: not found exceptions
if not context.services.model_manager.model_exists( if not context.services.model_manager.model_exists(

View File

@ -0,0 +1,134 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
import math
from typing import Literal
from pydantic import Field, validator
import torch
from invokeai.app.invocations.latent import LatentsField
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from ...backend.util.devices import choose_torch_device, torch_dtype
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationConfig,
InvocationContext,
)
"""
Utilities
"""
def get_noise(
width: int,
height: int,
device: torch.device,
seed: int = 0,
latent_channels: int = 4,
downsampling_factor: int = 8,
use_cpu: bool = True,
perlin: float = 0.0,
):
"""Generate noise for a given image size."""
noise_device_type = "cpu" if (use_cpu or device.type == "mps") else device.type
# limit noise to only the diffusion image channels, not the mask channels
input_channels = min(latent_channels, 4)
generator = torch.Generator(device=noise_device_type).manual_seed(seed)
noise_tensor = torch.randn(
[
1,
input_channels,
height // downsampling_factor,
width // downsampling_factor,
],
dtype=torch_dtype(device),
device=noise_device_type,
generator=generator,
).to(device)
return noise_tensor
"""
Nodes
"""
class NoiseOutput(BaseInvocationOutput):
"""Invocation noise output"""
# fmt: off
type: Literal["noise_output"] = "noise_output"
# Inputs
noise: LatentsField = Field(default=None, description="The output noise")
width: int = Field(description="The width of the noise in pixels")
height: int = Field(description="The height of the noise in pixels")
# fmt: on
def build_noise_output(latents_name: str, latents: torch.Tensor):
return NoiseOutput(
noise=LatentsField(latents_name=latents_name),
width=latents.size()[3] * 8,
height=latents.size()[2] * 8,
)
class NoiseInvocation(BaseInvocation):
"""Generates latent noise."""
type: Literal["noise"] = "noise"
# Inputs
seed: int = Field(
ge=0,
le=SEED_MAX,
description="The seed to use",
default_factory=get_random_seed,
)
width: int = Field(
default=512,
multiple_of=8,
gt=0,
description="The width of the resulting noise",
)
height: int = Field(
default=512,
multiple_of=8,
gt=0,
description="The height of the resulting noise",
)
use_cpu: bool = Field(
default=True,
description="Use CPU for noise generation (for reproducible results across platforms)",
)
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents", "noise"],
},
}
@validator("seed", pre=True)
def modulo_seed(cls, v):
"""Returns the seed modulo SEED_MAX to ensure it is within the valid range."""
return v % SEED_MAX
def invoke(self, context: InvocationContext) -> NoiseOutput:
noise = get_noise(
width=self.width,
height=self.height,
device=choose_torch_device(),
seed=self.seed,
use_cpu=self.use_cpu,
)
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, noise)
return build_noise_output(latents_name=name, latents=noise)

View File

@ -133,20 +133,19 @@ class StepParamEasingInvocation(BaseInvocation):
postlist = list(num_poststeps * [self.post_end_value]) postlist = list(num_poststeps * [self.post_end_value])
if log_diagnostics: if log_diagnostics:
logger = InvokeAILogger.getLogger(name="StepParamEasing") context.services.logger.debug("start_step: " + str(start_step))
logger.debug("start_step: " + str(start_step)) context.services.logger.debug("end_step: " + str(end_step))
logger.debug("end_step: " + str(end_step)) context.services.logger.debug("num_easing_steps: " + str(num_easing_steps))
logger.debug("num_easing_steps: " + str(num_easing_steps)) context.services.logger.debug("num_presteps: " + str(num_presteps))
logger.debug("num_presteps: " + str(num_presteps)) context.services.logger.debug("num_poststeps: " + str(num_poststeps))
logger.debug("num_poststeps: " + str(num_poststeps)) context.services.logger.debug("prelist size: " + str(len(prelist)))
logger.debug("prelist size: " + str(len(prelist))) context.services.logger.debug("postlist size: " + str(len(postlist)))
logger.debug("postlist size: " + str(len(postlist))) context.services.logger.debug("prelist: " + str(prelist))
logger.debug("prelist: " + str(prelist)) context.services.logger.debug("postlist: " + str(postlist))
logger.debug("postlist: " + str(postlist))
easing_class = EASING_FUNCTIONS_MAP[self.easing] easing_class = EASING_FUNCTIONS_MAP[self.easing]
if log_diagnostics: if log_diagnostics:
logger.debug("easing class: " + str(easing_class)) context.services.logger.debug("easing class: " + str(easing_class))
easing_list = list() easing_list = list()
if self.mirror: # "expected" mirroring if self.mirror: # "expected" mirroring
# if number of steps is even, squeeze duration down to (number_of_steps)/2 # if number of steps is even, squeeze duration down to (number_of_steps)/2
@ -156,7 +155,7 @@ class StepParamEasingInvocation(BaseInvocation):
# but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always # but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always
base_easing_duration = int(np.ceil(num_easing_steps/2.0)) base_easing_duration = int(np.ceil(num_easing_steps/2.0))
if log_diagnostics: logger.debug("base easing duration: " + str(base_easing_duration)) if log_diagnostics: context.services.logger.debug("base easing duration: " + str(base_easing_duration))
even_num_steps = (num_easing_steps % 2 == 0) # even number of steps even_num_steps = (num_easing_steps % 2 == 0) # even number of steps
easing_function = easing_class(start=self.start_value, easing_function = easing_class(start=self.start_value,
end=self.end_value, end=self.end_value,
@ -166,14 +165,14 @@ class StepParamEasingInvocation(BaseInvocation):
easing_val = easing_function.ease(step_index) easing_val = easing_function.ease(step_index)
base_easing_vals.append(easing_val) base_easing_vals.append(easing_val)
if log_diagnostics: if log_diagnostics:
logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val)) context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
if even_num_steps: if even_num_steps:
mirror_easing_vals = list(reversed(base_easing_vals)) mirror_easing_vals = list(reversed(base_easing_vals))
else: else:
mirror_easing_vals = list(reversed(base_easing_vals[0:-1])) mirror_easing_vals = list(reversed(base_easing_vals[0:-1]))
if log_diagnostics: if log_diagnostics:
logger.debug("base easing vals: " + str(base_easing_vals)) context.services.logger.debug("base easing vals: " + str(base_easing_vals))
logger.debug("mirror easing vals: " + str(mirror_easing_vals)) context.services.logger.debug("mirror easing vals: " + str(mirror_easing_vals))
easing_list = base_easing_vals + mirror_easing_vals easing_list = base_easing_vals + mirror_easing_vals
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely # FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
@ -206,12 +205,12 @@ class StepParamEasingInvocation(BaseInvocation):
step_val = easing_function.ease(step_index) step_val = easing_function.ease(step_index)
easing_list.append(step_val) easing_list.append(step_val)
if log_diagnostics: if log_diagnostics:
logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val)) context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))
if log_diagnostics: if log_diagnostics:
logger.debug("prelist size: " + str(len(prelist))) context.services.logger.debug("prelist size: " + str(len(prelist)))
logger.debug("easing_list size: " + str(len(easing_list))) context.services.logger.debug("easing_list size: " + str(len(easing_list)))
logger.debug("postlist size: " + str(len(postlist))) context.services.logger.debug("postlist size: " + str(len(postlist)))
param_list = prelist + easing_list + postlist param_list = prelist + easing_list + postlist

View File

@ -15,7 +15,7 @@ InvokeAI:
conf_path: configs/models.yaml conf_path: configs/models.yaml
legacy_conf_dir: configs/stable-diffusion legacy_conf_dir: configs/stable-diffusion
outdir: outputs outdir: outputs
autoconvert_dir: null autoimport_dir: null
Models: Models:
model: stable-diffusion-1.5 model: stable-diffusion-1.5
embeddings: true embeddings: true
@ -367,16 +367,19 @@ setting environment variables INVOKEAI_<setting>.
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance') always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance') free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
max_loaded_models : int = Field(default=2, gt=0, description="Maximum number of models to keep in memory for rapid switching", category='Memory/Performance') max_loaded_models : int = Field(default=3, gt=0, description="Maximum number of models to keep in memory for rapid switching", category='Memory/Performance')
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance') precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance')
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance') sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance') xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance') tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths') root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.', category='Paths') autoimport_dir : Path = Field(default='autoimport/main', description='Path to a directory of models files to be imported on startup.', category='Paths')
lora_dir : Path = Field(default='autoimport/lora', description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', category='Paths')
embedding_dir : Path = Field(default='autoimport/embedding', description='Path to a directory of Textual Inversion embeddings to be imported on startup.', category='Paths')
controlnet_dir : Path = Field(default='autoimport/controlnet', description='Path to a directory of ControlNet embeddings to be imported on startup.', category='Paths')
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths') conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
models_dir : Path = Field(default='./models', description='Path to the models directory', category='Paths') models_dir : Path = Field(default='models', description='Path to the models directory', category='Paths')
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths') legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
db_dir : Path = Field(default='databases', description='Path to InvokeAI databases directory', category='Paths') db_dir : Path = Field(default='databases', description='Path to InvokeAI databases directory', category='Paths')
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths') outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')

View File

@ -1,4 +1,5 @@
from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation from ..invocations.latent import LatentsToImageInvocation, TextToLatentsInvocation
from ..invocations.noise import NoiseInvocation
from ..invocations.compel import CompelInvocation from ..invocations.compel import CompelInvocation
from ..invocations.params import ParamIntInvocation from ..invocations.params import ParamIntInvocation
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph

View File

@ -7,8 +7,6 @@
# Coauthor: Kevin Turner http://github.com/keturn # Coauthor: Kevin Turner http://github.com/keturn
# #
import sys import sys
print("Loading Python libraries...\n",file=sys.stderr)
import argparse import argparse
import io import io
import os import os
@ -16,6 +14,7 @@ import shutil
import textwrap import textwrap
import traceback import traceback
import warnings import warnings
import yaml
from argparse import Namespace from argparse import Namespace
from pathlib import Path from pathlib import Path
from shutil import get_terminal_size from shutil import get_terminal_size
@ -25,6 +24,7 @@ from urllib import request
import npyscreen import npyscreen
import transformers import transformers
from diffusers import AutoencoderKL from diffusers import AutoencoderKL
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from huggingface_hub import HfFolder from huggingface_hub import HfFolder
from huggingface_hub import login as hf_hub_login from huggingface_hub import login as hf_hub_login
from omegaconf import OmegaConf from omegaconf import OmegaConf
@ -34,6 +34,8 @@ from transformers import (
CLIPSegForImageSegmentation, CLIPSegForImageSegmentation,
CLIPTextModel, CLIPTextModel,
CLIPTokenizer, CLIPTokenizer,
AutoFeatureExtractor,
BertTokenizerFast,
) )
import invokeai.configs as configs import invokeai.configs as configs
@ -52,12 +54,13 @@ from invokeai.frontend.install.widgets import (
) )
from invokeai.backend.install.legacy_arg_parsing import legacy_parser from invokeai.backend.install.legacy_arg_parsing import legacy_parser
from invokeai.backend.install.model_install_backend import ( from invokeai.backend.install.model_install_backend import (
default_dataset, hf_download_from_pretrained,
download_from_hf, InstallSelections,
hf_download_with_resume, ModelInstall,
recommended_datasets,
UserSelections,
) )
from invokeai.backend.model_management.model_probe import (
ModelType, BaseModelType
)
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
@ -81,7 +84,7 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
# or renaming it and then running invokeai-configure again. # or renaming it and then running invokeai-configure again.
""" """
logger=None logger=InvokeAILogger.getLogger()
# -------------------------------------------- # --------------------------------------------
def postscript(errors: None): def postscript(errors: None):
@ -162,75 +165,91 @@ class ProgressBar:
# --------------------------------------------- # ---------------------------------------------
def download_with_progress_bar(model_url: str, model_dest: str, label: str = "the"): def download_with_progress_bar(model_url: str, model_dest: str, label: str = "the"):
try: try:
print(f"Installing {label} model file {model_url}...", end="", file=sys.stderr) logger.info(f"Installing {label} model file {model_url}...")
if not os.path.exists(model_dest): if not os.path.exists(model_dest):
os.makedirs(os.path.dirname(model_dest), exist_ok=True) os.makedirs(os.path.dirname(model_dest), exist_ok=True)
request.urlretrieve( request.urlretrieve(
model_url, model_dest, ProgressBar(os.path.basename(model_dest)) model_url, model_dest, ProgressBar(os.path.basename(model_dest))
) )
print("...downloaded successfully", file=sys.stderr) logger.info("...downloaded successfully")
else: else:
print("...exists", file=sys.stderr) logger.info("...exists")
except Exception: except Exception:
print("...download failed", file=sys.stderr) logger.info("...download failed")
print(f"Error downloading {label} model", file=sys.stderr) logger.info(f"Error downloading {label} model")
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
# --------------------------------------------- def download_conversion_models():
# this will preload the Bert tokenizer fles target_dir = config.root_path / 'models/core/convert'
def download_bert(): kwargs = dict() # for future use
print("Installing bert tokenizer...", file=sys.stderr) try:
logger.info('Downloading core tokenizers and text encoders')
# bert
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=DeprecationWarning)
from transformers import BertTokenizerFast bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs)
bert.save_pretrained(target_dir / 'bert-base-uncased', safe_serialization=True)
download_from_hf(BertTokenizerFast, "bert-base-uncased") # sd-1
repo_id = 'openai/clip-vit-large-patch14'
hf_download_from_pretrained(CLIPTokenizer, repo_id, target_dir / 'clip-vit-large-patch14')
hf_download_from_pretrained(CLIPTextModel, repo_id, target_dir / 'clip-vit-large-patch14')
# sd-2
repo_id = "stabilityai/stable-diffusion-2"
pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs)
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'tokenizer', safe_serialization=True)
# --------------------------------------------- pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
def download_sd1_clip(): pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'text_encoder', safe_serialization=True)
print("Installing SD1 clip model...", file=sys.stderr)
version = "openai/clip-vit-large-patch14"
download_from_hf(CLIPTokenizer, version)
download_from_hf(CLIPTextModel, version)
# VAE
logger.info('Downloading stable diffusion VAE')
vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', **kwargs)
vae.save_pretrained(target_dir / 'sd-vae-ft-mse', safe_serialization=True)
# --------------------------------------------- # safety checking
def download_sd2_clip(): logger.info('Downloading safety checker')
version = "stabilityai/stable-diffusion-2" repo_id = "CompVis/stable-diffusion-safety-checker"
print("Installing SD2 clip model...", file=sys.stderr) pipeline = AutoFeatureExtractor.from_pretrained(repo_id,**kwargs)
download_from_hf(CLIPTokenizer, version, subfolder="tokenizer") pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
download_from_hf(CLIPTextModel, version, subfolder="text_encoder")
pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id,**kwargs)
pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
# --------------------------------------------- # ---------------------------------------------
def download_realesrgan(): def download_realesrgan():
print("Installing models from RealESRGAN...", file=sys.stderr) logger.info("Installing models from RealESRGAN...")
model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth" model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
wdn_model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth" wdn_model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth"
model_dest = config.root_path / "models/realesrgan/realesr-general-x4v3.pth" model_dest = config.root_path / "models/core/upscaling/realesrgan/realesr-general-x4v3.pth"
wdn_model_dest = config.root_path / "models/realesrgan/realesr-general-wdn-x4v3.pth" wdn_model_dest = config.root_path / "models/core/upscaling/realesrgan/realesr-general-wdn-x4v3.pth"
download_with_progress_bar(model_url, str(model_dest), "RealESRGAN") download_with_progress_bar(model_url, str(model_dest), "RealESRGAN")
download_with_progress_bar(wdn_model_url, str(wdn_model_dest), "RealESRGANwdn") download_with_progress_bar(wdn_model_url, str(wdn_model_dest), "RealESRGANwdn")
def download_gfpgan(): def download_gfpgan():
print("Installing GFPGAN models...", file=sys.stderr) logger.info("Installing GFPGAN models...")
for model in ( for model in (
[ [
"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth", "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
"./models/gfpgan/GFPGANv1.4.pth", "./models/core/face_restoration/gfpgan/GFPGANv1.4.pth",
], ],
[ [
"https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth", "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth",
"./models/gfpgan/weights/detection_Resnet50_Final.pth", "./models/core/face_restoration/gfpgan/weights/detection_Resnet50_Final.pth",
], ],
[ [
"https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth", "https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth",
"./models/gfpgan/weights/parsing_parsenet.pth", "./models/core/face_restoration/gfpgan/weights/parsing_parsenet.pth",
], ],
): ):
model_url, model_dest = model[0], config.root_path / model[1] model_url, model_dest = model[0], config.root_path / model[1]
@ -239,70 +258,32 @@ def download_gfpgan():
# --------------------------------------------- # ---------------------------------------------
def download_codeformer(): def download_codeformer():
print("Installing CodeFormer model file...", file=sys.stderr) logger.info("Installing CodeFormer model file...")
model_url = ( model_url = (
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth" "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
) )
model_dest = config.root_path / "models/codeformer/codeformer.pth" model_dest = config.root_path / "models/core/face_restoration/codeformer/codeformer.pth"
download_with_progress_bar(model_url, str(model_dest), "CodeFormer") download_with_progress_bar(model_url, str(model_dest), "CodeFormer")
# --------------------------------------------- # ---------------------------------------------
def download_clipseg(): def download_clipseg():
print("Installing clipseg model for text-based masking...", file=sys.stderr) logger.info("Installing clipseg model for text-based masking...")
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined" CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
try: try:
download_from_hf(AutoProcessor, CLIPSEG_MODEL) hf_download_from_pretrained(AutoProcessor, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg')
download_from_hf(CLIPSegForImageSegmentation, CLIPSEG_MODEL) hf_download_from_pretrained(CLIPSegForImageSegmentation, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg')
except Exception: except Exception:
print("Error installing clipseg model:") logger.info("Error installing clipseg model:")
print(traceback.format_exc()) logger.info(traceback.format_exc())
# ------------------------------------- def download_support_models():
def download_safety_checker(): download_realesrgan()
print("Installing model for NSFW content detection...", file=sys.stderr) download_gfpgan()
try: download_codeformer()
from diffusers.pipelines.stable_diffusion.safety_checker import ( download_clipseg()
StableDiffusionSafetyChecker, download_conversion_models()
)
from transformers import AutoFeatureExtractor
except ModuleNotFoundError:
print("Error installing NSFW checker model:")
print(traceback.format_exc())
return
safety_model_id = "CompVis/stable-diffusion-safety-checker"
print("AutoFeatureExtractor...", file=sys.stderr)
download_from_hf(AutoFeatureExtractor, safety_model_id)
print("StableDiffusionSafetyChecker...", file=sys.stderr)
download_from_hf(StableDiffusionSafetyChecker, safety_model_id)
# -------------------------------------
def download_vaes():
print("Installing stabilityai VAE...", file=sys.stderr)
try:
# first the diffusers version
repo_id = "stabilityai/sd-vae-ft-mse"
args = dict(
cache_dir=config.cache_dir,
)
if not AutoencoderKL.from_pretrained(repo_id, **args):
raise Exception(f"download of {repo_id} failed")
repo_id = "stabilityai/sd-vae-ft-mse-original"
model_name = "vae-ft-mse-840000-ema-pruned.ckpt"
# next the legacy checkpoint version
if not hf_download_with_resume(
repo_id=repo_id,
model_name=model_name,
model_dir=str(config.root_path / Model_dir / Weights_dir),
):
raise Exception(f"download of {model_name} failed")
except Exception as e:
print(f"Error downloading StabilityAI standard VAE: {str(e)}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
# ------------------------------------- # -------------------------------------
def get_root(root: str = None) -> str: def get_root(root: str = None) -> str:
@ -465,38 +446,18 @@ to allow InvokeAI to download restricted styles & subjects from the "Concept Lib
editable=False, editable=False,
color="CONTROL", color="CONTROL",
) )
self.embedding_dir = self.add_widget_intelligent( self.autoimport_dirs = {}
for description, config_name, path in autoimport_paths(old_opts):
self.autoimport_dirs[config_name] = self.add_widget_intelligent(
npyscreen.TitleFilename, npyscreen.TitleFilename,
name=" Textual Inversion Embeddings:", name=description+':',
value=str(default_embedding_dir()), value=str(path),
select_dir=True, select_dir=True,
must_exist=False, must_exist=False,
use_two_lines=False, use_two_lines=False,
labelColor="GOOD", labelColor="GOOD",
begin_entry_at=32, begin_entry_at=32,
scroll_exit=True, scroll_exit=True
)
self.lora_dir = self.add_widget_intelligent(
npyscreen.TitleFilename,
name=" LoRA and LyCORIS:",
value=str(default_lora_dir()),
select_dir=True,
must_exist=False,
use_two_lines=False,
labelColor="GOOD",
begin_entry_at=32,
scroll_exit=True,
)
self.controlnet_dir = self.add_widget_intelligent(
npyscreen.TitleFilename,
name=" ControlNets:",
value=str(default_controlnet_dir()),
select_dir=True,
must_exist=False,
use_two_lines=False,
labelColor="GOOD",
begin_entry_at=32,
scroll_exit=True,
) )
self.nextrely += 1 self.nextrely += 1
self.add_widget_intelligent( self.add_widget_intelligent(
@ -562,10 +523,6 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
bad_fields.append( bad_fields.append(
f"The output directory does not seem to be valid. Please check that {str(Path(opt.outdir).parent)} is an existing directory." f"The output directory does not seem to be valid. Please check that {str(Path(opt.outdir).parent)} is an existing directory."
) )
if not Path(opt.embedding_dir).parent.exists():
bad_fields.append(
f"The embedding directory does not seem to be valid. Please check that {str(Path(opt.embedding_dir).parent)} is an existing directory."
)
if len(bad_fields) > 0: if len(bad_fields) > 0:
message = "The following problems were detected and must be corrected:\n" message = "The following problems were detected and must be corrected:\n"
for problem in bad_fields: for problem in bad_fields:
@ -585,12 +542,15 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
"max_loaded_models", "max_loaded_models",
"xformers_enabled", "xformers_enabled",
"always_use_cpu", "always_use_cpu",
"embedding_dir",
"lora_dir",
"controlnet_dir",
]: ]:
setattr(new_opts, attr, getattr(self, attr).value) setattr(new_opts, attr, getattr(self, attr).value)
for attr in self.autoimport_dirs:
directory = Path(self.autoimport_dirs[attr].value)
if directory.is_relative_to(config.root_path):
directory = directory.relative_to(config.root_path)
setattr(new_opts, attr, directory)
new_opts.hf_token = self.hf_token.value new_opts.hf_token = self.hf_token.value
new_opts.license_acceptance = self.license_acceptance.value new_opts.license_acceptance = self.license_acceptance.value
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]] new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
@ -607,7 +567,8 @@ class EditOptApplication(npyscreen.NPSAppManaged):
self.program_opts = program_opts self.program_opts = program_opts
self.invokeai_opts = invokeai_opts self.invokeai_opts = invokeai_opts
self.user_cancelled = False self.user_cancelled = False
self.user_selections = default_user_selections(program_opts) self.autoload_pending = True
self.install_selections = default_user_selections(program_opts)
def onStart(self): def onStart(self):
npyscreen.setTheme(npyscreen.Themes.DefaultTheme) npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
@ -642,40 +603,61 @@ def default_startup_options(init_file: Path) -> Namespace:
opts.nsfw_checker = True opts.nsfw_checker = True
return opts return opts
def default_user_selections(program_opts: Namespace) -> UserSelections: def default_user_selections(program_opts: Namespace) -> InstallSelections:
return UserSelections( installer = ModelInstall(config)
install_models=default_dataset() models = installer.all_models()
return InstallSelections(
install_models=[models[installer.default_model()].path or models[installer.default_model()].repo_id]
if program_opts.default_only if program_opts.default_only
else recommended_datasets() else [models[x].path or models[x].repo_id for x in installer.recommended_models()]
if program_opts.yes_to_all if program_opts.yes_to_all
else dict(), else list(),
purge_deleted_models=False, # scan_directory=None,
scan_directory=None, # autoscan_on_startup=None,
autoscan_on_startup=None,
) )
# -------------------------------------
def autoimport_paths(config: InvokeAIAppConfig):
return [
('Checkpoints & diffusers models', 'autoimport_dir', config.root_path / config.autoimport_dir),
('LoRA/LyCORIS models', 'lora_dir', config.root_path / config.lora_dir),
('Controlnet models', 'controlnet_dir', config.root_path / config.controlnet_dir),
('Textual Inversion Embeddings', 'embedding_dir', config.root_path / config.embedding_dir),
]
# ------------------------------------- # -------------------------------------
def initialize_rootdir(root: Path, yes_to_all: bool = False): def initialize_rootdir(root: Path, yes_to_all: bool = False):
print("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **") logger.info("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **")
for name in ( for name in (
"models", "models",
"configs",
"embeddings",
"databases", "databases",
"loras",
"controlnets",
"text-inversion-output", "text-inversion-output",
"text-inversion-training-data", "text-inversion-training-data",
"configs"
): ):
os.makedirs(os.path.join(root, name), exist_ok=True) os.makedirs(os.path.join(root, name), exist_ok=True)
for model_type in ModelType:
Path(root, 'autoimport', model_type.value).mkdir(parents=True, exist_ok=True)
configs_src = Path(configs.__path__[0]) configs_src = Path(configs.__path__[0])
configs_dest = root / "configs" configs_dest = root / "configs"
if not os.path.samefile(configs_src, configs_dest): if not os.path.samefile(configs_src, configs_dest):
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True) shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
dest = root / 'models'
for model_base in BaseModelType:
for model_type in ModelType:
path = dest / model_base.value / model_type.value
path.mkdir(parents=True, exist_ok=True)
path = dest / 'core'
path.mkdir(parents=True, exist_ok=True)
with open(root / 'configs' / 'models.yaml','w') as yaml_file:
yaml_file.write(yaml.dump({'__metadata__':
{'version':'3.0.0'}
}
)
)
# ------------------------------------- # -------------------------------------
def run_console_ui( def run_console_ui(
@ -699,7 +681,7 @@ def run_console_ui(
if editApp.user_cancelled: if editApp.user_cancelled:
return (None, None) return (None, None)
else: else:
return (editApp.new_opts, editApp.user_selections) return (editApp.new_opts, editApp.install_selections)
# ------------------------------------- # -------------------------------------
@ -722,18 +704,6 @@ def write_opts(opts: Namespace, init_file: Path):
def default_output_dir() -> Path: def default_output_dir() -> Path:
return config.root_path / "outputs" return config.root_path / "outputs"
# -------------------------------------
def default_embedding_dir() -> Path:
return config.root_path / "embeddings"
# -------------------------------------
def default_lora_dir() -> Path:
return config.root_path / "loras"
# -------------------------------------
def default_controlnet_dir() -> Path:
return config.root_path / "controlnets"
# ------------------------------------- # -------------------------------------
def write_default_options(program_opts: Namespace, initfile: Path): def write_default_options(program_opts: Namespace, initfile: Path):
opt = default_startup_options(initfile) opt = default_startup_options(initfile)
@ -758,13 +728,41 @@ def migrate_init_file(legacy_format:Path):
new.nsfw_checker = old.safety_checker new.nsfw_checker = old.safety_checker
new.xformers_enabled = old.xformers new.xformers_enabled = old.xformers
new.conf_path = old.conf new.conf_path = old.conf
new.embedding_dir = old.embedding_path new.root = legacy_format.parent.resolve()
invokeai_yaml = legacy_format.parent / 'invokeai.yaml' invokeai_yaml = legacy_format.parent / 'invokeai.yaml'
with open(invokeai_yaml,"w", encoding="utf-8") as outfile: with open(invokeai_yaml,"w", encoding="utf-8") as outfile:
outfile.write(new.to_yaml()) outfile.write(new.to_yaml())
legacy_format.replace(legacy_format.parent / 'invokeai.init.old') legacy_format.replace(legacy_format.parent / 'invokeai.init.orig')
# -------------------------------------
def migrate_models(root: Path):
from invokeai.backend.install.migrate_to_3 import do_migrate
do_migrate(root, root)
def migrate_if_needed(opt: Namespace, root: Path)->bool:
# We check for to see if the runtime directory is correctly initialized.
old_init_file = root / 'invokeai.init'
new_init_file = root / 'invokeai.yaml'
old_hub = root / 'models/hub'
migration_needed = old_init_file.exists() and not new_init_file.exists() or old_hub.exists()
if migration_needed:
if opt.yes_to_all or \
yes_or_no(f'{str(config.root_path)} appears to be a 2.3 format root directory. Convert to version 3.0?'):
logger.info('** Migrating invokeai.init to invokeai.yaml')
migrate_init_file(old_init_file)
config.parse_args(argv=[],conf=OmegaConf.load(new_init_file))
if old_hub.exists():
migrate_models(config.root_path)
else:
print('Cannot continue without conversion. Aborting.')
return migration_needed
# ------------------------------------- # -------------------------------------
def main(): def main():
@ -831,20 +829,16 @@ def main():
errors = set() errors = set()
try: try:
models_to_download = default_user_selections(opt) # if we do a root migration/upgrade, then we are keeping previous
# configuration and we are done.
# We check for to see if the runtime directory is correctly initialized. if migrate_if_needed(opt, config.root_path):
old_init_file = config.root_path / 'invokeai.init' sys.exit(0)
new_init_file = config.root_path / 'invokeai.yaml'
if old_init_file.exists() and not new_init_file.exists():
print('** Migrating invokeai.init to invokeai.yaml')
migrate_init_file(old_init_file)
# Load new init file into config
config.parse_args(argv=[],conf=OmegaConf.load(new_init_file))
if not config.model_conf_path.exists(): if not config.model_conf_path.exists():
initialize_rootdir(config.root_path, opt.yes_to_all) initialize_rootdir(config.root_path, opt.yes_to_all)
models_to_download = default_user_selections(opt)
new_init_file = config.root_path / 'invokeai.yaml'
if opt.yes_to_all: if opt.yes_to_all:
write_default_options(opt, new_init_file) write_default_options(opt, new_init_file)
init_options = Namespace( init_options = Namespace(
@ -855,29 +849,21 @@ def main():
if init_options: if init_options:
write_opts(init_options, new_init_file) write_opts(init_options, new_init_file)
else: else:
print( logger.info(
'\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n' '\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n'
) )
sys.exit(0) sys.exit(0)
if opt.skip_support_models: if opt.skip_support_models:
print("\n** SKIPPING SUPPORT MODEL DOWNLOADS PER USER REQUEST **") logger.info("SKIPPING SUPPORT MODEL DOWNLOADS PER USER REQUEST")
else: else:
print("\n** CHECKING/UPDATING SUPPORT MODELS **") logger.info("CHECKING/UPDATING SUPPORT MODELS")
download_bert() download_support_models()
download_sd1_clip()
download_sd2_clip()
download_realesrgan()
download_gfpgan()
download_codeformer()
download_clipseg()
download_safety_checker()
download_vaes()
if opt.skip_sd_weights: if opt.skip_sd_weights:
print("\n** SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST **") logger.info("\n** SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST **")
elif models_to_download: elif models_to_download:
print("\n** DOWNLOADING DIFFUSION WEIGHTS **") logger.info("\n** DOWNLOADING DIFFUSION WEIGHTS **")
process_and_execute(opt, models_to_download) process_and_execute(opt, models_to_download)
postscript(errors=errors) postscript(errors=errors)

View File

@ -0,0 +1,581 @@
'''
Migrate the models directory and models.yaml file from an existing
InvokeAI 2.3 installation to 3.0.0.
'''
import io
import os
import argparse
import shutil
import yaml
import transformers
import diffusers
import warnings
from dataclasses import dataclass
from pathlib import Path
from omegaconf import OmegaConf, DictConfig
from typing import Union
from diffusers import StableDiffusionPipeline, AutoencoderKL
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import (
CLIPTextModel,
CLIPTokenizer,
AutoFeatureExtractor,
BertTokenizerFast,
)
import invokeai.backend.util.logging as logger
from invokeai.backend.model_management import ModelManager
from invokeai.backend.model_management.model_probe import (
ModelProbe, ModelType, BaseModelType, SchedulerPredictionType, ModelProbeInfo
)
warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error()
diffusers.logging.set_verbosity_error()
# holder for paths that we will migrate
@dataclass
class ModelPaths:
models: Path
embeddings: Path
loras: Path
controlnets: Path
class MigrateTo3(object):
def __init__(self,
root_directory: Path,
dest_models: Path,
yaml_file: io.TextIOBase,
src_paths: ModelPaths,
):
self.root_directory = root_directory
self.dest_models = dest_models
self.dest_yaml = yaml_file
self.model_names = set()
self.src_paths = src_paths
self._initialize_yaml()
def _initialize_yaml(self):
self.dest_yaml.write(
yaml.dump(
{
'__metadata__':
{
'version':'3.0.0'}
}
)
)
def unique_name(self,name,info)->str:
'''
Create a unique name for a model for use within models.yaml.
'''
done = False
key = ModelManager.create_key(name,info.base_type,info.model_type)
unique_name = key
counter = 1
while not done:
if unique_name in self.model_names:
unique_name = f'{key}-{counter:0>2d}'
counter += 1
else:
done = True
self.model_names.add(unique_name)
name,_,_ = ModelManager.parse_key(unique_name)
return name
def create_directory_structure(self):
'''
Create the basic directory structure for the models folder.
'''
for model_base in [BaseModelType.StableDiffusion1,BaseModelType.StableDiffusion2]:
for model_type in [ModelType.Main, ModelType.Vae, ModelType.Lora,
ModelType.ControlNet,ModelType.TextualInversion]:
path = self.dest_models / model_base.value / model_type.value
path.mkdir(parents=True, exist_ok=True)
path = self.dest_models / 'core'
path.mkdir(parents=True, exist_ok=True)
@staticmethod
def copy_file(src:Path,dest:Path):
'''
copy a single file with logging
'''
if dest.exists():
logger.info(f'Skipping existing {str(dest)}')
return
logger.info(f'Copying {str(src)} to {str(dest)}')
try:
shutil.copy(src, dest)
except Exception as e:
logger.error(f'COPY FAILED: {str(e)}')
@staticmethod
def copy_dir(src:Path,dest:Path):
'''
Recursively copy a directory with logging
'''
if dest.exists():
logger.info(f'Skipping existing {str(dest)}')
return
logger.info(f'Copying {str(src)} to {str(dest)}')
try:
shutil.copytree(src, dest)
except Exception as e:
logger.error(f'COPY FAILED: {str(e)}')
def migrate_models(self, src_dir: Path):
'''
Recursively walk through src directory, probe anything
that looks like a model, and copy the model into the
appropriate location within the destination models directory.
'''
for root, dirs, files in os.walk(src_dir):
for f in files:
# hack - don't copy raw learned_embeds.bin, let them
# be copied as part of a tree copy operation
if f == 'learned_embeds.bin':
continue
try:
model = Path(root,f)
info = ModelProbe().heuristic_probe(model)
if not info:
continue
dest = self._model_probe_to_path(info) / f
self.copy_file(model, dest)
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
for d in dirs:
try:
model = Path(root,d)
info = ModelProbe().heuristic_probe(model)
if not info:
continue
dest = self._model_probe_to_path(info) / model.name
self.copy_dir(model, dest)
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
def migrate_support_models(self):
'''
Copy the clipseg, upscaler, and restoration models to their new
locations.
'''
dest_directory = self.dest_models
if (self.root_directory / 'models/clipseg').exists():
self.copy_dir(self.root_directory / 'models/clipseg', dest_directory / 'core/misc/clipseg')
if (self.root_directory / 'models/realesrgan').exists():
self.copy_dir(self.root_directory / 'models/realesrgan', dest_directory / 'core/upscaling/realesrgan')
for d in ['codeformer','gfpgan']:
path = self.root_directory / 'models' / d
if path.exists():
self.copy_dir(path,dest_directory / f'core/face_restoration/{d}')
def migrate_tuning_models(self):
'''
Migrate the embeddings, loras and controlnets directories to their new homes.
'''
for src in [self.src_paths.embeddings, self.src_paths.loras, self.src_paths.controlnets]:
if not src:
continue
if src.is_dir():
logger.info(f'Scanning {src}')
self.migrate_models(src)
else:
logger.info(f'{src} directory not found; skipping')
continue
def migrate_conversion_models(self):
'''
Migrate all the models that are needed by the ckpt_to_diffusers conversion
script.
'''
dest_directory = self.dest_models
kwargs = dict(
cache_dir = self.root_directory / 'models/hub',
#local_files_only = True
)
try:
logger.info('Migrating core tokenizers and text encoders')
target_dir = dest_directory / 'core' / 'convert'
self._migrate_pretrained(BertTokenizerFast,
repo_id='bert-base-uncased',
dest = target_dir / 'bert-base-uncased',
**kwargs)
# sd-1
repo_id = 'openai/clip-vit-large-patch14'
self._migrate_pretrained(CLIPTokenizer,
repo_id= repo_id,
dest= target_dir / 'clip-vit-large-patch14' / 'tokenizer',
**kwargs)
self._migrate_pretrained(CLIPTextModel,
repo_id = repo_id,
dest = target_dir / 'clip-vit-large-patch14' / 'text_encoder',
**kwargs)
# sd-2
repo_id = "stabilityai/stable-diffusion-2"
self._migrate_pretrained(CLIPTokenizer,
repo_id = repo_id,
dest = target_dir / 'stable-diffusion-2-clip' / 'tokenizer',
**{'subfolder':'tokenizer',**kwargs}
)
self._migrate_pretrained(CLIPTextModel,
repo_id = repo_id,
dest = target_dir / 'stable-diffusion-2-clip' / 'text_encoder',
**{'subfolder':'text_encoder',**kwargs}
)
# VAE
logger.info('Migrating stable diffusion VAE')
self._migrate_pretrained(AutoencoderKL,
repo_id = 'stabilityai/sd-vae-ft-mse',
dest = target_dir / 'sd-vae-ft-mse',
**kwargs)
# safety checking
logger.info('Migrating safety checker')
repo_id = "CompVis/stable-diffusion-safety-checker"
self._migrate_pretrained(AutoFeatureExtractor,
repo_id = repo_id,
dest = target_dir / 'stable-diffusion-safety-checker',
**kwargs)
self._migrate_pretrained(StableDiffusionSafetyChecker,
repo_id = repo_id,
dest = target_dir / 'stable-diffusion-safety-checker',
**kwargs)
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
def write_yaml(self, model_name: str, path:Path, info:ModelProbeInfo, **kwargs):
'''
Write a stanza for a moved model into the new models.yaml file.
'''
name = self.unique_name(model_name, info)
stanza = {
f'{info.base_type.value}/{info.model_type.value}/{name}': {
'name': model_name,
'path': str(path),
'description': f'A {info.base_type.value} {info.model_type.value} model',
'format': info.format,
'image_size': info.image_size,
'base': info.base_type.value,
'variant': info.variant_type.value,
'prediction_type': info.prediction_type.value,
'upcast_attention': info.prediction_type == SchedulerPredictionType.VPrediction,
**kwargs,
}
}
self.dest_yaml.write(yaml.dump(stanza))
self.dest_yaml.flush()
def _model_probe_to_path(self, info: ModelProbeInfo)->Path:
return Path(self.dest_models, info.base_type.value, info.model_type.value)
def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, **kwargs):
if dest.exists():
logger.info(f'Skipping existing {dest}')
return
model = model_class.from_pretrained(repo_id, **kwargs)
self._save_pretrained(model, dest)
def _save_pretrained(self, model, dest: Path):
if dest.exists():
logger.info(f'Skipping existing {dest}')
return
model_name = dest.name
download_path = dest.with_name(f'{model_name}.downloading')
model.save_pretrained(download_path, safe_serialization=True)
download_path.replace(dest)
def _download_vae(self, repo_id: str, subfolder:str=None)->Path:
vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / 'models/hub', subfolder=subfolder)
info = ModelProbe().heuristic_probe(vae)
_, model_name = repo_id.split('/')
dest = self._model_probe_to_path(info) / self.unique_name(model_name, info)
vae.save_pretrained(dest, safe_serialization=True)
return dest
def _vae_path(self, vae: Union[str,dict])->Path:
'''
Convert 2.3 VAE stanza to a straight path.
'''
vae_path = None
# First get a path
if isinstance(vae,str):
vae_path = vae
elif isinstance(vae,DictConfig):
if p := vae.get('path'):
vae_path = p
elif repo_id := vae.get('repo_id'):
if repo_id=='stabilityai/sd-vae-ft-mse': # this guy is already downloaded
vae_path = 'models/core/convert/se-vae-ft-mse'
else:
vae_path = self._download_vae(repo_id, vae.get('subfolder'))
assert vae_path is not None, "Couldn't find VAE for this model"
# if the VAE is in the old models directory, then we must move it into the new
# one. VAEs outside of this directory can stay where they are.
vae_path = Path(vae_path)
if vae_path.is_relative_to(self.src_paths.models):
info = ModelProbe().heuristic_probe(vae_path)
dest = self._model_probe_to_path(info) / vae_path.name
if not dest.exists():
self.copy_dir(vae_path,dest)
vae_path = dest
if vae_path.is_relative_to(self.dest_models):
rel_path = vae_path.relative_to(self.dest_models)
return Path('models',rel_path)
else:
return vae_path
def migrate_repo_id(self, repo_id: str, model_name :str=None, **extra_config):
'''
Migrate a locally-cached diffusers pipeline identified with a repo_id
'''
dest_dir = self.dest_models
cache = self.root_directory / 'models/hub'
kwargs = dict(
cache_dir = cache,
safety_checker = None,
# local_files_only = True,
)
owner,repo_name = repo_id.split('/')
model_name = model_name or repo_name
model = cache / '--'.join(['models',owner,repo_name])
if len(list(model.glob('snapshots/**/model_index.json')))==0:
return
revisions = [x.name for x in model.glob('refs/*')]
# if an fp16 is available we use that
revision = 'fp16' if len(revisions) > 1 and 'fp16' in revisions else revisions[0]
pipeline = StableDiffusionPipeline.from_pretrained(
repo_id,
revision=revision,
**kwargs)
info = ModelProbe().heuristic_probe(pipeline)
if not info:
return
dest = self._model_probe_to_path(info) / repo_name
self._save_pretrained(pipeline, dest)
rel_path = Path('models',dest.relative_to(dest_dir))
self.write_yaml(model_name, path=rel_path, info=info, **extra_config)
def migrate_path(self, location: Path, model_name: str=None, **extra_config):
'''
Migrate a model referred to using 'weights' or 'path'
'''
# handle relative paths
dest_dir = self.dest_models
location = self.root_directory / location
info = ModelProbe().heuristic_probe(location)
if not info:
return
# uh oh, weights is in the old models directory - move it into the new one
if Path(location).is_relative_to(self.src_paths.models):
dest = Path(dest_dir, info.base_type.value, info.model_type.value, location.name)
self.copy_dir(location,dest)
location = Path('models', info.base_type.value, info.model_type.value, location.name)
model_name = model_name or location.stem
model_name = self.unique_name(model_name, info)
self.write_yaml(model_name, path=location, info=info, **extra_config)
def migrate_defined_models(self):
'''
Migrate models defined in models.yaml
'''
# find any models referred to in old models.yaml
conf = OmegaConf.load(self.root_directory / 'configs/models.yaml')
for model_name, stanza in conf.items():
try:
passthru_args = {}
if vae := stanza.get('vae'):
try:
passthru_args['vae'] = str(self._vae_path(vae))
except Exception as e:
logger.warning(f'Could not find a VAE matching "{vae}" for model "{model_name}"')
logger.warning(str(e))
if config := stanza.get('config'):
passthru_args['config'] = config
if repo_id := stanza.get('repo_id'):
logger.info(f'Migrating diffusers model {model_name}')
self.migrate_repo_id(repo_id, model_name, **passthru_args)
elif location := stanza.get('weights'):
logger.info(f'Migrating checkpoint model {model_name}')
self.migrate_path(Path(location), model_name, **passthru_args)
elif location := stanza.get('path'):
logger.info(f'Migrating diffusers model {model_name}')
self.migrate_path(Path(location), model_name, **passthru_args)
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
def migrate(self):
self.create_directory_structure()
# the configure script is doing this
self.migrate_support_models()
self.migrate_conversion_models()
self.migrate_tuning_models()
self.migrate_defined_models()
def _parse_legacy_initfile(root: Path, initfile: Path)->ModelPaths:
'''
Returns tuple of (embedding_path, lora_path, controlnet_path)
'''
parser = argparse.ArgumentParser(fromfile_prefix_chars='@')
parser.add_argument(
'--embedding_directory',
'--embedding_path',
type=Path,
dest='embedding_path',
default=Path('embeddings'),
)
parser.add_argument(
'--lora_directory',
dest='lora_path',
type=Path,
default=Path('loras'),
)
opt,_ = parser.parse_known_args([f'@{str(initfile)}'])
return ModelPaths(
models = root / 'models',
embeddings = root / str(opt.embedding_path).strip('"'),
loras = root / str(opt.lora_path).strip('"'),
controlnets = root / 'controlnets',
)
def _parse_legacy_yamlfile(root: Path, initfile: Path)->ModelPaths:
'''
Returns tuple of (embedding_path, lora_path, controlnet_path)
'''
# Don't use the config object because it is unforgiving of version updates
# Just use omegaconf directly
opt = OmegaConf.load(initfile)
paths = opt.InvokeAI.Paths
models = paths.get('models_dir','models')
embeddings = paths.get('embedding_dir','embeddings')
loras = paths.get('lora_dir','loras')
controlnets = paths.get('controlnet_dir','controlnets')
return ModelPaths(
models = root / models,
embeddings = root / embeddings,
loras = root /loras,
controlnets = root / controlnets,
)
def get_legacy_embeddings(root: Path) -> ModelPaths:
path = root / 'invokeai.init'
if path.exists():
return _parse_legacy_initfile(root, path)
path = root / 'invokeai.yaml'
if path.exists():
return _parse_legacy_yamlfile(root, path)
def do_migrate(src_directory: Path, dest_directory: Path):
dest_models = dest_directory / 'models-3.0'
dest_yaml = dest_directory / 'configs/models.yaml-3.0'
paths = get_legacy_embeddings(src_directory)
with open(dest_yaml,'w') as yaml_file:
migrator = MigrateTo3(src_directory,
dest_models,
yaml_file,
src_paths = paths,
)
migrator.migrate()
shutil.rmtree(dest_directory / 'models.orig', ignore_errors=True)
(dest_directory / 'models').replace(dest_directory / 'models.orig')
dest_models.replace(dest_directory / 'models')
(dest_directory /'configs/models.yaml').replace(dest_directory / 'configs/models.yaml.orig')
dest_yaml.replace(dest_directory / 'configs/models.yaml')
print(f"""Migration successful.
Original models directory moved to {dest_directory}/models.orig
Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig
""")
def main():
parser = argparse.ArgumentParser(prog="invokeai-migrate3",
description="""
This will copy and convert the models directory and the configs/models.yaml from the InvokeAI 2.3 format
'--from-directory' root to the InvokeAI 3.0 '--to-directory' root. These may be abbreviated '--from' and '--to'.a
The old models directory and config file will be renamed 'models.orig' and 'models.yaml.orig' respectively.
It is safe to provide the same directory for both arguments, but it is better to use the invokeai_configure
script, which will perform a full upgrade in place."""
)
parser.add_argument('--from-directory',
dest='root_directory',
type=Path,
required=True,
help='Source InvokeAI 2.3 root directory (containing "invokeai.init" or "invokeai.yaml")'
)
parser.add_argument('--to-directory',
dest='dest_directory',
type=Path,
required=True,
help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")'
)
# TO DO: Implement full directory scanning
# parser.add_argument('--all-models',
# action="store_true",
# help='Migrate all models found in `models` directory, not just those mentioned in models.yaml',
# )
args = parser.parse_args()
root_directory = args.root_directory
assert root_directory.is_dir(), f"{root_directory} is not a valid directory"
assert (root_directory / 'models').is_dir(), f"{root_directory} does not contain a 'models' subdirectory"
assert (root_directory / 'invokeai.init').exists() or (root_directory / 'invokeai.yaml').exists(), f"{root_directory} does not contain an InvokeAI init file."
dest_directory = args.dest_directory
assert dest_directory.is_dir(), f"{dest_directory} is not a valid directory"
assert (dest_directory / 'models').is_dir(), f"{dest_directory} does not contain a 'models' subdirectory"
assert (dest_directory / 'invokeai.yaml').exists(), f"{dest_directory} does not contain an InvokeAI init file."
do_migrate(root_directory,dest_directory)
if __name__ == '__main__':
main()

View File

@ -2,46 +2,36 @@
Utility (backend) functions used by model_install.py Utility (backend) functions used by model_install.py
""" """
import os import os
import re
import shutil import shutil
import sys
import warnings import warnings
from dataclasses import dataclass,field from dataclasses import dataclass,field
from pathlib import Path from pathlib import Path
from tempfile import TemporaryFile from tempfile import TemporaryDirectory
from typing import List, Dict, Callable from typing import List, Dict, Callable, Union, Set
import requests import requests
from diffusers import AutoencoderKL from diffusers import StableDiffusionPipeline
from huggingface_hub import hf_hub_url, HfFolder from huggingface_hub import hf_hub_url, HfFolder, HfApi
from omegaconf import OmegaConf from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from tqdm import tqdm from tqdm import tqdm
import invokeai.configs as configs import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from ..stable_diffusion import StableDiffusionGeneratorPipeline from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType
from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
from invokeai.backend.util import download_with_resume
from ..util.logging import InvokeAILogger from ..util.logging import InvokeAILogger
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
# --------------------------globals----------------------- # --------------------------globals-----------------------
config = InvokeAIAppConfig.get_config() config = InvokeAIAppConfig.get_config()
logger = InvokeAILogger.getLogger(name='InvokeAI')
Model_dir = "models"
Weights_dir = "ldm/stable-diffusion-v1/"
# the initial "configs" dir is now bundled in the `invokeai.configs` package # the initial "configs" dir is now bundled in the `invokeai.configs` package
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml" Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
# initial models omegaconf
Datasets = None
# logger
logger = InvokeAILogger.getLogger(name='InvokeAI')
Config_preamble = """ Config_preamble = """
# This file describes the alternative machine learning models # This file describes the alternative machine learning models
# available to InvokeAI script. # available to InvokeAI script.
@ -52,6 +42,24 @@ Config_preamble = """
# was trained on. # was trained on.
""" """
LEGACY_CONFIGS = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: 'v1-inference.yaml',
ModelVariantType.Inpaint: 'v1-inpainting-inference.yaml',
},
BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: {
SchedulerPredictionType.Epsilon: 'v2-inference.yaml',
SchedulerPredictionType.VPrediction: 'v2-inference-v.yaml',
},
ModelVariantType.Inpaint: {
SchedulerPredictionType.Epsilon: 'v2-inpainting-inference.yaml',
SchedulerPredictionType.VPrediction: 'v2-inpainting-inference-v.yaml',
}
}
}
@dataclass @dataclass
class ModelInstallList: class ModelInstallList:
'''Class for listing models to be installed/removed''' '''Class for listing models to be installed/removed'''
@ -59,133 +67,321 @@ class ModelInstallList:
remove_models: List[str] = field(default_factory=list) remove_models: List[str] = field(default_factory=list)
@dataclass @dataclass
class UserSelections(): class InstallSelections():
install_models: List[str]= field(default_factory=list) install_models: List[str]= field(default_factory=list)
remove_models: List[str]=field(default_factory=list) remove_models: List[str]=field(default_factory=list)
purge_deleted_models: bool=field(default_factory=list) # scan_directory: Path = None
install_cn_models: List[str] = field(default_factory=list) # autoscan_on_startup: bool=False
remove_cn_models: List[str] = field(default_factory=list)
install_lora_models: List[str] = field(default_factory=list)
remove_lora_models: List[str] = field(default_factory=list)
install_ti_models: List[str] = field(default_factory=list)
remove_ti_models: List[str] = field(default_factory=list)
scan_directory: Path = None
autoscan_on_startup: bool=False
import_model_paths: str=None
def default_config_file(): @dataclass
return config.model_conf_path class ModelLoadInfo():
name: str
model_type: ModelType
base_type: BaseModelType
path: Path = None
repo_id: str = None
description: str = ''
installed: bool = False
recommended: bool = False
default: bool = False
def sd_configs(): class ModelInstall(object):
return config.legacy_conf_path def __init__(self,
config:InvokeAIAppConfig,
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
model_manager: ModelManager = None,
access_token:str = None):
self.config = config
self.mgr = model_manager or ModelManager(config.model_conf_path)
self.datasets = OmegaConf.load(Dataset_path)
self.prediction_helper = prediction_type_helper
self.access_token = access_token or HfFolder.get_token()
self.reverse_paths = self._reverse_paths(self.datasets)
def initial_models(): def all_models(self)->Dict[str,ModelLoadInfo]:
global Datasets '''
if Datasets: Return dict of model_key=>ModelLoadInfo objects.
return Datasets This method consolidates and simplifies the entries in both
return (Datasets := OmegaConf.load(Dataset_path)['diffusers']) models.yaml and INITIAL_MODELS.yaml so that they can
be treated uniformly. It also sorts the models alphabetically
by their name, to improve the display somewhat.
'''
model_dict = dict()
def install_requested_models( # first populate with the entries in INITIAL_MODELS.yaml
diffusers: ModelInstallList = None, for key, value in self.datasets.items():
controlnet: ModelInstallList = None, name,base,model_type = ModelManager.parse_key(key)
lora: ModelInstallList = None, value['name'] = name
ti: ModelInstallList = None, value['base_type'] = base
cn_model_map: Dict[str,str] = None, # temporary - move to model manager value['model_type'] = model_type
scan_directory: Path = None, model_dict[key] = ModelLoadInfo(**value)
external_models: List[str] = None,
scan_at_startup: bool = False,
precision: str = "float16",
purge_deleted: bool = False,
config_file_path: Path = None,
model_config_file_callback: Callable[[Path],Path] = None
):
"""
Entry point for installing/deleting starter models, or installing external models.
"""
access_token = HfFolder.get_token()
config_file_path = config_file_path or default_config_file()
if not config_file_path.exists():
open(config_file_path, "w")
# prevent circular import here # supplement with entries in models.yaml
from ..model_management import ModelManager installed_models = self.mgr.list_models()
model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision) for md in installed_models:
if controlnet: base = md['base_model']
model_manager.install_controlnet_models(controlnet.install_models, access_token=access_token) model_type = md['type']
model_manager.delete_controlnet_models(controlnet.remove_models) name = md['name']
key = ModelManager.create_key(name, base, model_type)
if lora: if key in model_dict:
model_manager.install_lora_models(lora.install_models, access_token=access_token) model_dict[key].installed = True
model_manager.delete_lora_models(lora.remove_models)
if ti:
model_manager.install_ti_models(ti.install_models, access_token=access_token)
model_manager.delete_ti_models(ti.remove_models)
if diffusers:
# TODO: Replace next three paragraphs with calls into new model manager
if diffusers.remove_models and len(diffusers.remove_models) > 0:
logger.info("Processing requested deletions")
for model in diffusers.remove_models:
logger.info(f"{model}...")
model_manager.del_model(model, delete_files=purge_deleted)
model_manager.commit(config_file_path)
if diffusers.install_models and len(diffusers.install_models) > 0:
logger.info("Installing requested models")
downloaded_paths = download_weight_datasets(
models=diffusers.install_models,
access_token=None,
precision=precision,
)
successful = {x:v for x,v in downloaded_paths.items() if v is not None}
if len(successful) > 0:
update_config_file(successful, config_file_path)
if len(successful) < len(diffusers.install_models):
unsuccessful = [x for x in downloaded_paths if downloaded_paths[x] is None]
logger.warning(f"Some of the model downloads were not successful: {unsuccessful}")
# due to above, we have to reload the model manager because conf file
# was changed behind its back
model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)
external_models = external_models or list()
if scan_directory:
external_models.append(str(scan_directory))
if len(external_models) > 0:
logger.info("INSTALLING EXTERNAL MODELS")
for path_url_or_repo in external_models:
try:
logger.debug(f'In install_requested_models; callback = {model_config_file_callback}')
model_manager.heuristic_import(
path_url_or_repo,
commit_to_conf=config_file_path,
config_file_callback = model_config_file_callback,
)
except KeyboardInterrupt:
sys.exit(-1)
except Exception:
pass
if scan_at_startup and scan_directory.is_dir():
update_autoconvert_dir(scan_directory)
else: else:
update_autoconvert_dir(None) model_dict[key] = ModelLoadInfo(
name = name,
base_type = base,
model_type = model_type,
path = value.get('path'),
installed = True,
)
return {x : model_dict[x] for x in sorted(model_dict.keys(),key=lambda y: model_dict[y].name.lower())}
def update_autoconvert_dir(autodir: Path): def starter_models(self)->Set[str]:
''' models = set()
Update the "autoconvert_dir" option in invokeai.yaml for key, value in self.datasets.items():
''' name,base,model_type = ModelManager.parse_key(key)
invokeai_config_path = config.init_file_path if model_type==ModelType.Main:
conf = OmegaConf.load(invokeai_config_path) models.add(key)
conf.InvokeAI.Paths.autoconvert_dir = str(autodir) if autodir else None return models
yaml = OmegaConf.to_yaml(conf)
tmpfile = invokeai_config_path.parent / "new_config.tmp"
with open(tmpfile, "w", encoding="utf-8") as outfile:
outfile.write(yaml)
tmpfile.replace(invokeai_config_path)
def recommended_models(self)->Set[str]:
starters = self.starter_models()
return set([x for x in starters if self.datasets[x].get('recommended',False)])
def default_model(self)->str:
starters = self.starter_models()
defaults = [x for x in starters if self.datasets[x].get('default',False)]
return defaults[0]
def install(self, selections: InstallSelections):
job = 1
jobs = len(selections.remove_models) + len(selections.install_models)
# remove requested models
for key in selections.remove_models:
name,base,mtype = self.mgr.parse_key(key)
logger.info(f'Deleting {mtype} model {name} [{job}/{jobs}]')
self.mgr.del_model(name,base,mtype)
job += 1
# add requested models
for path in selections.install_models:
logger.info(f'Installing {path} [{job}/{jobs}]')
self.heuristic_install(path)
job += 1
self.mgr.commit()
def heuristic_install(self,
model_path_id_or_url: Union[str,Path],
models_installed: Set[Path]=None)->Set[Path]:
if not models_installed:
models_installed = set()
# A little hack to allow nested routines to retrieve info on the requested ID
self.current_id = model_path_id_or_url
path = Path(model_path_id_or_url)
try:
# checkpoint file, or similar
if path.is_file():
models_installed.add(self._install_path(path))
# folders style or similar
elif path.is_dir() and any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
models_installed.add(self._install_path(path))
# recursive scan
elif path.is_dir():
for child in path.iterdir():
self.heuristic_install(child, models_installed=models_installed)
# huggingface repo
elif len(str(path).split('/')) == 2:
models_installed.add(self._install_repo(str(path)))
# a URL
elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")):
models_installed.add(self._install_url(model_path_id_or_url))
else:
logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
except ValueError as e:
logger.error(str(e))
return models_installed
# install a model from a local path. The optional info parameter is there to prevent
# the model from being probed twice in the event that it has already been probed.
def _install_path(self, path: Path, info: ModelProbeInfo=None)->Path:
try:
# logger.debug(f'Probing {path}')
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
model_name = path.stem if info.format=='checkpoint' else path.name
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
raise ValueError(f'A model named "{model_name}" is already installed.')
attributes = self._make_attributes(path,info)
self.mgr.add_model(model_name = model_name,
base_model = info.base_type,
model_type = info.model_type,
model_attributes = attributes,
)
except Exception as e:
logger.warning(f'{str(e)} Skipping registration.')
return path
def _install_url(self, url: str)->Path:
# copy to a staging area, probe, import and delete
with TemporaryDirectory(dir=self.config.models_path) as staging:
location = download_with_resume(url,Path(staging))
if not location:
logger.error(f'Unable to download {url}. Skipping.')
info = ModelProbe().heuristic_probe(location)
dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name
models_path = shutil.move(location,dest)
# staged version will be garbage-collected at this time
return self._install_path(Path(models_path), info)
def _install_repo(self, repo_id: str)->Path:
hinfo = HfApi().model_info(repo_id)
# we try to figure out how to download this most economically
# list all the files in the repo
files = [x.rfilename for x in hinfo.siblings]
location = None
with TemporaryDirectory(dir=self.config.models_path) as staging:
staging = Path(staging)
if 'model_index.json' in files:
location = self._download_hf_pipeline(repo_id, staging) # pipeline
else:
for suffix in ['safetensors','bin']:
if f'pytorch_lora_weights.{suffix}' in files:
location = self._download_hf_model(repo_id, ['pytorch_lora_weights.bin'], staging) # LoRA
break
elif self.config.precision=='float16' and f'diffusion_pytorch_model.fp16.{suffix}' in files: # vae, controlnet or some other standalone
files = ['config.json', f'diffusion_pytorch_model.fp16.{suffix}']
location = self._download_hf_model(repo_id, files, staging)
break
elif f'diffusion_pytorch_model.{suffix}' in files:
files = ['config.json', f'diffusion_pytorch_model.{suffix}']
location = self._download_hf_model(repo_id, files, staging)
break
elif f'learned_embeds.{suffix}' in files:
location = self._download_hf_model(repo_id, ['learned_embeds.suffix'], staging)
break
if not location:
logger.warning(f'Could not determine type of repo {repo_id}. Skipping install.')
return
info = ModelProbe().heuristic_probe(location, self.prediction_helper)
if not info:
logger.warning(f'Could not probe {location}. Skipping install.')
return
dest = self.config.models_path / info.base_type.value / info.model_type.value / self._get_model_name(repo_id,location)
if dest.exists():
shutil.rmtree(dest)
shutil.copytree(location,dest)
return self._install_path(dest, info)
def _get_model_name(self,path_name: str, location: Path)->str:
'''
Calculate a name for the model - primitive implementation.
'''
if key := self.reverse_paths.get(path_name):
(name, base, mtype) = ModelManager.parse_key(key)
return name
else:
return location.stem
def _make_attributes(self, path: Path, info: ModelProbeInfo)->dict:
model_name = path.name if path.is_dir() else path.stem
description = f'{info.base_type.value} {info.model_type.value} model {model_name}'
if key := self.reverse_paths.get(self.current_id):
if key in self.datasets:
description = self.datasets[key].get('description') or description
rel_path = self.relative_to_root(path)
attributes = dict(
path = str(rel_path),
description = str(description),
model_format = info.format,
)
if info.model_type == ModelType.Main:
attributes.update(dict(variant = info.variant_type,))
if info.format=="checkpoint":
try:
possible_conf = path.with_suffix('.yaml')
if possible_conf.exists():
legacy_conf = str(self.relative_to_root(possible_conf))
elif info.base_type == BaseModelType.StableDiffusion2:
legacy_conf = Path(self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type])
else:
legacy_conf = Path(self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type])
except KeyError:
legacy_conf = Path(self.config.legacy_conf_dir, 'v1-inference.yaml') # best guess
attributes.update(
dict(
config = str(legacy_conf)
)
)
return attributes
def relative_to_root(self, path: Path)->Path:
root = self.config.root_path
if path.is_relative_to(root):
return path.relative_to(root)
else:
return path
def _download_hf_pipeline(self, repo_id: str, staging: Path)->Path:
'''
This retrieves a StableDiffusion model from cache or remote and then
does a save_pretrained() to the indicated staging area.
'''
_,name = repo_id.split("/")
revisions = ['fp16','main'] if self.config.precision=='float16' else ['main']
model = None
for revision in revisions:
try:
model = StableDiffusionPipeline.from_pretrained(repo_id,revision=revision,safety_checker=None)
except: # most errors are due to fp16 not being present. Fix this to catch other errors
pass
if model:
break
if not model:
logger.error(f'Diffusers model {repo_id} could not be downloaded. Skipping.')
return None
model.save_pretrained(staging / name, safe_serialization=True)
return staging / name
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path)->Path:
_,name = repo_id.split("/")
location = staging / name
paths = list()
for filename in files:
p = hf_download_with_resume(repo_id,
model_dir=location,
model_name=filename,
access_token = self.access_token
)
if p:
paths.append(p)
else:
logger.warning(f'Could not download {filename} from {repo_id}.')
return location if len(paths)>0 else None
@classmethod
def _reverse_paths(cls,datasets)->dict:
'''
Reverse mapping from repo_id/path to destination name.
'''
return {v.get('path') or v.get('repo_id') : k for k, v in datasets.items()}
# ------------------------------------- # -------------------------------------
def yes_or_no(prompt: str, default_yes=True): def yes_or_no(prompt: str, default_yes=True):
@ -197,133 +393,19 @@ def yes_or_no(prompt: str, default_yes=True):
return response[0] in ("y", "Y") return response[0] in ("y", "Y")
# --------------------------------------------- # ---------------------------------------------
def recommended_datasets() -> List['str']: def hf_download_from_pretrained(
datasets = set() model_class: object, model_name: str, destination: Path, **kwargs
for ds in initial_models().keys():
if initial_models()[ds].get("recommended", False):
datasets.add(ds)
return list(datasets)
# ---------------------------------------------
def default_dataset() -> dict:
datasets = set()
for ds in initial_models().keys():
if initial_models()[ds].get("default", False):
datasets.add(ds)
return list(datasets)
# ---------------------------------------------
def all_datasets() -> dict:
datasets = dict()
for ds in initial_models().keys():
datasets[ds] = True
return datasets
# ---------------------------------------------
# look for legacy model.ckpt in models directory and offer to
# normalize its name
def migrate_models_ckpt():
model_path = os.path.join(config.root_dir, Model_dir, Weights_dir)
if not os.path.exists(os.path.join(model_path, "model.ckpt")):
return
new_name = initial_models()["stable-diffusion-1.4"]["file"]
logger.warning(
'The Stable Diffusion v4.1 "model.ckpt" is already installed. The name will be changed to {new_name} to avoid confusion.'
)
logger.warning(f"model.ckpt => {new_name}")
os.replace(
os.path.join(model_path, "model.ckpt"), os.path.join(model_path, new_name)
)
# ---------------------------------------------
def download_weight_datasets(
models: List[str], access_token: str, precision: str = "float32"
):
migrate_models_ckpt()
successful = dict()
for mod in models:
logger.info(f"Downloading {mod}:")
successful[mod] = _download_repo_or_file(
initial_models()[mod], access_token, precision=precision
)
return successful
def _download_repo_or_file(
mconfig: DictConfig, access_token: str, precision: str = "float32"
) -> Path:
path = None
if mconfig["format"] == "ckpt":
path = _download_ckpt_weights(mconfig, access_token)
else:
path = _download_diffusion_weights(mconfig, access_token, precision=precision)
if "vae" in mconfig and "repo_id" in mconfig["vae"]:
_download_diffusion_weights(
mconfig["vae"], access_token, precision=precision
)
return path
def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
repo_id = mconfig["repo_id"]
filename = mconfig["file"]
cache_dir = os.path.join(config.root_dir, Model_dir, Weights_dir)
return hf_download_with_resume(
repo_id=repo_id,
model_dir=cache_dir,
model_name=filename,
access_token=access_token,
)
# ---------------------------------------------
def download_from_hf(
model_class: object, model_name: str, **kwargs
): ):
logger = InvokeAILogger.getLogger('InvokeAI') logger = InvokeAILogger.getLogger('InvokeAI')
logger.addFilter(lambda x: 'fp16 is not a valid' not in x.getMessage()) logger.addFilter(lambda x: 'fp16 is not a valid' not in x.getMessage())
path = config.cache_dir
model = model_class.from_pretrained( model = model_class.from_pretrained(
model_name, model_name,
cache_dir=path,
resume_download=True, resume_download=True,
**kwargs, **kwargs,
) )
model_name = "--".join(("models", *model_name.split("/"))) model.save_pretrained(destination, safe_serialization=True)
return path / model_name if model else None return destination
def _download_diffusion_weights(
mconfig: DictConfig, access_token: str, precision: str = "float32"
):
repo_id = mconfig["repo_id"]
model_class = (
StableDiffusionGeneratorPipeline
if mconfig.get("format", None) == "diffusers"
else AutoencoderKL
)
extra_arg_list = [{"revision": "fp16"}, {}] if precision == "float16" else [{}]
path = None
for extra_args in extra_arg_list:
try:
path = download_from_hf(
model_class,
repo_id,
safety_checker=None,
**extra_args,
)
except OSError as e:
if 'Revision Not Found' in str(e):
pass
else:
logger.error(str(e))
if path:
break
return path
# --------------------------------------------- # ---------------------------------------------
def hf_download_with_resume( def hf_download_with_resume(
@ -383,128 +465,3 @@ def hf_download_with_resume(
return model_dest return model_dest
# ---------------------------------------------
def update_config_file(successfully_downloaded: dict, config_file: Path):
config_file = (
Path(config_file) if config_file is not None else default_config_file()
)
# In some cases (incomplete setup, etc), the default configs directory might be missing.
# Create it if it doesn't exist.
# this check is ignored if opt.config_file is specified - user is assumed to know what they
# are doing if they are passing a custom config file from elsewhere.
if config_file is default_config_file() and not config_file.parent.exists():
configs_src = Dataset_path.parent
configs_dest = default_config_file().parent
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
yaml = new_config_file_contents(successfully_downloaded, config_file)
try:
backup = None
if os.path.exists(config_file):
logger.warning(
f"{config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
)
backup = config_file.with_suffix(".yaml.orig")
## Ugh. Windows is unable to overwrite an existing backup file, raises a WinError 183
if sys.platform == "win32" and backup.is_file():
backup.unlink()
config_file.rename(backup)
with TemporaryFile() as tmp:
tmp.write(Config_preamble.encode())
tmp.write(yaml.encode())
with open(str(config_file.expanduser().resolve()), "wb") as new_config:
tmp.seek(0)
new_config.write(tmp.read())
except Exception as e:
logger.error(f"Error creating config file {config_file}: {str(e)}")
if backup is not None:
logger.info("restoring previous config file")
## workaround, for WinError 183, see above
if sys.platform == "win32" and config_file.is_file():
config_file.unlink()
backup.rename(config_file)
return
logger.info(f"Successfully created new configuration file {config_file}")
# ---------------------------------------------
def new_config_file_contents(
successfully_downloaded: dict,
config_file: Path,
) -> str:
if config_file.exists():
conf = OmegaConf.load(str(config_file.expanduser().resolve()))
else:
conf = OmegaConf.create()
default_selected = None
for model in successfully_downloaded:
# a bit hacky - what we are doing here is seeing whether a checkpoint
# version of the model was previously defined, and whether the current
# model is a diffusers (indicated with a path)
if conf.get(model) and Path(successfully_downloaded[model]).is_dir():
delete_weights(model, conf[model])
stanza = {}
mod = initial_models()[model]
stanza["description"] = mod["description"]
stanza["repo_id"] = mod["repo_id"]
stanza["format"] = mod["format"]
# diffusers don't need width and height (probably .ckpt doesn't either)
# so we no longer require these in INITIAL_MODELS.yaml
if "width" in mod:
stanza["width"] = mod["width"]
if "height" in mod:
stanza["height"] = mod["height"]
if "file" in mod:
stanza["weights"] = os.path.relpath(
successfully_downloaded[model], start=config.root_dir
)
stanza["config"] = os.path.normpath(
os.path.join(sd_configs(), mod["config"])
)
if "vae" in mod:
if "file" in mod["vae"]:
stanza["vae"] = os.path.normpath(
os.path.join(Model_dir, Weights_dir, mod["vae"]["file"])
)
else:
stanza["vae"] = mod["vae"]
if mod.get("default", False):
stanza["default"] = True
default_selected = True
conf[model] = stanza
# if no default model was chosen, then we select the first
# one in the list
if not default_selected:
conf[list(successfully_downloaded.keys())[0]]["default"] = True
return OmegaConf.to_yaml(conf)
# ---------------------------------------------
def delete_weights(model_name: str, conf_stanza: dict):
if not (weights := conf_stanza.get("weights")):
return
if re.match("/VAE/", conf_stanza.get("config")):
return
logger.warning(
f"\nThe checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?"
)
weights = Path(weights)
if not weights.is_absolute():
weights = config.root_dir / weights
try:
weights.unlink()
except OSError as e:
logger.error(str(e))

View File

@ -4,3 +4,4 @@ Initialization file for invokeai.backend.model_management
from .model_manager import ModelManager, ModelInfo from .model_manager import ModelManager, ModelInfo
from .model_cache import ModelCache from .model_cache import ModelCache
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType from .models import BaseModelType, ModelType, SubModelType, ModelVariantType

View File

@ -30,7 +30,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
from .model_manager import ModelManager from .model_manager import ModelManager
from .model_cache import ModelCache from .model_cache import ModelCache
from .models import SchedulerPredictionType, BaseModelType, ModelVariantType from .models import BaseModelType, ModelVariantType
try: try:
from omegaconf import OmegaConf from omegaconf import OmegaConf
@ -73,7 +73,9 @@ from transformers import (
from ..stable_diffusion import StableDiffusionGeneratorPipeline from ..stable_diffusion import StableDiffusionGeneratorPipeline
MODEL_ROOT = None # TODO: redo in future
#CONVERT_MODEL_ROOT = InvokeAIAppConfig.get_config().models_path / "core" / "convert"
CONVERT_MODEL_ROOT = InvokeAIAppConfig.get_config().root_path / "models" / "core" / "convert"
def shave_segments(path, n_shave_prefix_segments=1): def shave_segments(path, n_shave_prefix_segments=1):
""" """
@ -605,7 +607,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
else: else:
vae_state_dict = checkpoint vae_state_dict = checkpoint
new_checkpoint = convert_ldm_vae_state_dict(vae_state_dict,config) new_checkpoint = convert_ldm_vae_state_dict(vae_state_dict, config)
return new_checkpoint return new_checkpoint
def convert_ldm_vae_state_dict(vae_state_dict, config): def convert_ldm_vae_state_dict(vae_state_dict, config):
@ -828,7 +830,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
def convert_ldm_clip_checkpoint(checkpoint): def convert_ldm_clip_checkpoint(checkpoint):
text_model = CLIPTextModel.from_pretrained(MODEL_ROOT / 'clip-vit-large-patch14') text_model = CLIPTextModel.from_pretrained(CONVERT_MODEL_ROOT / 'clip-vit-large-patch14')
keys = list(checkpoint.keys()) keys = list(checkpoint.keys())
text_model_dict = {} text_model_dict = {}
@ -882,7 +884,7 @@ textenc_pattern = re.compile("|".join(protected.keys()))
def convert_open_clip_checkpoint(checkpoint): def convert_open_clip_checkpoint(checkpoint):
text_model = CLIPTextModel.from_pretrained( text_model = CLIPTextModel.from_pretrained(
MODEL_ROOT / 'stable-diffusion-2-clip', CONVERT_MODEL_ROOT / 'stable-diffusion-2-clip',
subfolder='text_encoder', subfolder='text_encoder',
) )
@ -949,7 +951,7 @@ def convert_open_clip_checkpoint(checkpoint):
return text_model return text_model
def replace_checkpoint_vae(checkpoint, vae_path:str): def replace_checkpoint_vae(checkpoint, vae_path: str):
if vae_path.endswith(".safetensors"): if vae_path.endswith(".safetensors"):
vae_ckpt = load_file(vae_path) vae_ckpt = load_file(vae_path)
else: else:
@ -959,7 +961,7 @@ def replace_checkpoint_vae(checkpoint, vae_path:str):
new_key = f'first_stage_model.{vae_key}' new_key = f'first_stage_model.{vae_key}'
checkpoint[new_key] = state_dict[vae_key] checkpoint[new_key] = state_dict[vae_key]
def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int)->AutoencoderKL: def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int) -> AutoencoderKL:
vae_config = create_vae_diffusers_config( vae_config = create_vae_diffusers_config(
vae_config, image_size=image_size vae_config, image_size=image_size
) )
@ -979,8 +981,6 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
original_config_file: str, original_config_file: str,
extract_ema: bool = True, extract_ema: bool = True,
precision: torch.dtype = torch.float32, precision: torch.dtype = torch.float32,
upcast_attention: bool = False,
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon,
scan_needed: bool = True, scan_needed: bool = True,
) -> StableDiffusionPipeline: ) -> StableDiffusionPipeline:
""" """
@ -994,8 +994,6 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
:param checkpoint_path: Path to `.ckpt` file. :param checkpoint_path: Path to `.ckpt` file.
:param original_config_file: Path to `.yaml` config file corresponding to the original architecture. :param original_config_file: Path to `.yaml` config file corresponding to the original architecture.
If `None`, will be automatically inferred by looking for a key that only exists in SD2.0 models. If `None`, will be automatically inferred by looking for a key that only exists in SD2.0 models.
:param prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion
v1.X and Stable Diffusion v2 Base. Use `'v-prediction'` for Stable Diffusion v2.
:param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", :param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler",
"euler-ancestral", "dpm", "ddim"]`. :param model_type: The pipeline type. `None` to automatically infer, or one of "euler-ancestral", "dpm", "ddim"]`. :param model_type: The pipeline type. `None` to automatically infer, or one of
`["FrozenOpenCLIPEmbedder", "FrozenCLIPEmbedder"]`. :param extract_ema: Only relevant for `["FrozenOpenCLIPEmbedder", "FrozenCLIPEmbedder"]`. :param extract_ema: Only relevant for
@ -1003,17 +1001,16 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher
quality images for inference. Non-EMA weights are usually better to continue fine-tuning. quality images for inference. Non-EMA weights are usually better to continue fine-tuning.
:param precision: precision to use - torch.float16, torch.float32 or torch.autocast :param precision: precision to use - torch.float16, torch.float32 or torch.autocast
:param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when
running stable diffusion 2.1.
""" """
config = InvokeAIAppConfig.get_config() if not isinstance(checkpoint_path, Path):
checkpoint_path = Path(checkpoint_path)
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
verbosity = dlogging.get_verbosity() verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error() dlogging.set_verbosity_error()
if str(checkpoint_path).endswith(".safetensors"): if checkpoint_path.suffix == ".safetensors":
checkpoint = load_file(checkpoint_path) checkpoint = load_file(checkpoint_path)
else: else:
if scan_needed: if scan_needed:
@ -1026,9 +1023,13 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
original_config = OmegaConf.load(original_config_file) original_config = OmegaConf.load(original_config_file)
if model_version == BaseModelType.StableDiffusion2 and prediction_type == SchedulerPredictionType.VPrediction: if model_version == BaseModelType.StableDiffusion2 and original_config["model"]["params"]["parameterization"] == "v":
prediction_type = "v_prediction"
upcast_attention = True
image_size = 768 image_size = 768
else: else:
prediction_type = "epsilon"
upcast_attention = False
image_size = 512 image_size = 512
# #
@ -1083,7 +1084,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
if model_type == "FrozenOpenCLIPEmbedder": if model_type == "FrozenOpenCLIPEmbedder":
text_model = convert_open_clip_checkpoint(checkpoint) text_model = convert_open_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained( tokenizer = CLIPTokenizer.from_pretrained(
MODEL_ROOT / 'stable-diffusion-2-clip', CONVERT_MODEL_ROOT / 'stable-diffusion-2-clip',
subfolder='tokenizer', subfolder='tokenizer',
) )
pipe = StableDiffusionPipeline( pipe = StableDiffusionPipeline(
@ -1099,9 +1100,9 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
elif model_type in ["FrozenCLIPEmbedder", "WeightedFrozenCLIPEmbedder"]: elif model_type in ["FrozenCLIPEmbedder", "WeightedFrozenCLIPEmbedder"]:
text_model = convert_ldm_clip_checkpoint(checkpoint) text_model = convert_ldm_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained(MODEL_ROOT / 'clip-vit-large-patch14') tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / 'clip-vit-large-patch14')
safety_checker = StableDiffusionSafetyChecker.from_pretrained(MODEL_ROOT / 'stable-diffusion-safety-checker') safety_checker = StableDiffusionSafetyChecker.from_pretrained(CONVERT_MODEL_ROOT / 'stable-diffusion-safety-checker')
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ROOT / 'stable-diffusion-safety-checker') feature_extractor = AutoFeatureExtractor.from_pretrained(CONVERT_MODEL_ROOT / 'stable-diffusion-safety-checker')
pipe = StableDiffusionPipeline( pipe = StableDiffusionPipeline(
vae=vae.to(precision), vae=vae.to(precision),
text_encoder=text_model.to(precision), text_encoder=text_model.to(precision),
@ -1115,7 +1116,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
else: else:
text_config = create_ldm_bert_config(original_config) text_config = create_ldm_bert_config(original_config)
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
tokenizer = BertTokenizerFast.from_pretrained(MODEL_ROOT / "bert-base-uncased") tokenizer = BertTokenizerFast.from_pretrained(CONVERT_MODEL_ROOT / "bert-base-uncased")
pipe = LDMTextToImagePipeline( pipe = LDMTextToImagePipeline(
vqvae=vae, vqvae=vae,
bert=text_model, bert=text_model,
@ -1131,7 +1132,6 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
def convert_ckpt_to_diffusers( def convert_ckpt_to_diffusers(
checkpoint_path: Union[str, Path], checkpoint_path: Union[str, Path],
dump_path: Union[str, Path], dump_path: Union[str, Path],
model_root: Union[str, Path],
**kwargs, **kwargs,
): ):
""" """
@ -1139,9 +1139,6 @@ def convert_ckpt_to_diffusers(
and in addition a path-like object indicating the location of the desired diffusers and in addition a path-like object indicating the location of the desired diffusers
model to be written. model to be written.
""" """
# setting global here to avoid massive changes late at night
global MODEL_ROOT
MODEL_ROOT = Path(model_root) / 'core/convert'
pipe = load_pipeline_from_original_stable_diffusion_ckpt(checkpoint_path, **kwargs) pipe = load_pipeline_from_original_stable_diffusion_ckpt(checkpoint_path, **kwargs)
pipe.save_pretrained( pipe.save_pretrained(

View File

@ -1,118 +0,0 @@
"""
Routines for downloading and installing models.
"""
import json
import safetensors
import safetensors.torch
import shutil
import tempfile
import torch
import traceback
from dataclasses import dataclass
from diffusers import ModelMixin
from enum import Enum
from typing import Callable
from pathlib import Path
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from . import ModelManager
from .models import BaseModelType, ModelType, VariantType
from .model_probe import ModelProbe, ModelVariantInfo
from .model_cache import SilenceWarnings
class ModelInstall(object):
'''
This class is able to download and install several different kinds of
InvokeAI models. The helper function, if provided, is called on to distinguish
between v2-base and v2-768 stable diffusion pipelines. This usually involves
asking the user to select the proper type, as there is no way of distinguishing
the two type of v2 file programmatically (as far as I know).
'''
def __init__(self,
config: InvokeAIAppConfig,
model_base_helper: Callable[[Path],BaseModelType]=None,
clobber:bool = False
):
'''
:param config: InvokeAI configuration object
:param model_base_helper: A function call that accepts the Path to a checkpoint model and returns a ModelType enum
:param clobber: If true, models with colliding names will be overwritten
'''
self.config = config
self.clogger = clobber
self.helper = model_base_helper
self.prober = ModelProbe()
def install_checkpoint_file(self, checkpoint: Path)->dict:
'''
Install the checkpoint file at path and return a
configuration entry that can be added to `models.yaml`.
Model checkpoints and VAEs will be converted into
diffusers before installation. Note that the model manager
does not hold entries for anything but diffusers pipelines,
and the configuration file stanzas returned from such models
can be safely ignored.
'''
model_info = self.prober.probe(checkpoint, self.helper)
if not model_info:
raise ValueError(f"Unable to determine type of checkpoint file {checkpoint}")
key = ModelManager.create_key(
model_name = checkpoint.stem,
base_model = model_info.base_type,
model_type = model_info.model_type,
)
destination_path = self._dest_path(model_info) / checkpoint
destination_path.parent.mkdir(parents=True, exist_ok=True)
self._check_for_collision(destination_path)
stanza = {
key: dict(
name = checkpoint.stem,
description = f'{model_info.model_type} model {checkpoint.stem}',
base = model_info.base_model.value,
type = model_info.model_type.value,
variant = model_info.variant_type.value,
path = str(destination_path),
)
}
# non-pipeline; no conversion needed, just copy into right place
if model_info.model_type != ModelType.Pipeline:
shutil.copyfile(checkpoint, destination_path)
stanza[key].update({'format': 'checkpoint'})
# pipeline - conversion needed here
else:
destination_path = self._dest_path(model_info) / checkpoint.stem
config_file = self._pipeline_type_to_config_file(model_info.model_type)
from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
with SilenceWarnings:
convert_ckpt_to_diffusers(
checkpoint,
destination_path,
extract_ema=True,
original_config_file=config_file,
scan_needed=False,
)
stanza[key].update({'format': 'folder',
'path': destination_path, # no suffix on this
})
return stanza
def _check_for_collision(self, path: Path):
if not path.exists():
return
if self.clobber:
shutil.rmtree(path)
else:
raise ValueError(f"Destination {path} already exists. Won't overwrite unless clobber=True.")
def _staging_directory(self)->tempfile.TemporaryDirectory:
return tempfile.TemporaryDirectory(dir=self.config.root_path)

View File

@ -1,53 +1,209 @@
"""This module manages the InvokeAI `models.yaml` file, mapping """This module manages the InvokeAI `models.yaml` file, mapping
symbolic diffusers model names to the paths and repo_ids used symbolic diffusers model names to the paths and repo_ids used by the
by the underlying `from_pretrained()` call. underlying `from_pretrained()` call.
For fetching models, use manager.get_model('symbolic name'). This will SYNOPSIS:
return a ModelInfo object that contains the following attributes:
* context -- a context manager Generator that loads and locks the mgr = ModelManager('/home/phi/invokeai/configs/models.yaml')
model into GPU VRAM and returns the model for use. sd1_5 = mgr.get_model('stable-diffusion-v1-5',
See below for usage. model_type=ModelType.Main,
* name -- symbolic name of the model base_model=BaseModelType.StableDiffusion1,
* type -- SubModelType of the model submodel_type=SubModelType.Unet)
* hash -- unique hash for the model with sd1_5 as unet:
* location -- path or repo_id of the model run_some_inference(unet)
* revision -- revision of the model if coming from a repo id,
e.g. 'fp16'
* precision -- torch precision of the model
Typical usage: FETCHING MODELS:
from invokeai.backend import ModelManager Models are described using four attributes:
manager = ModelManager( 1) model_name -- the symbolic name for the model
config='./configs/models.yaml',
max_cache_size=8
) # gigabytes
model_info = manager.get_model('stable-diffusion-1.5', SubModelType.Diffusers) 2) ModelType -- an enum describing the type of the model. Currently
with model_info.context as my_model: defined types are:
my_model.latents_from_embeddings(...) ModelType.Main -- a full model capable of generating images
ModelType.Vae -- a VAE model
ModelType.Lora -- a LoRA or LyCORIS fine-tune
ModelType.TextualInversion -- a textual inversion embedding
ModelType.ControlNet -- a ControlNet model
The manager uses the underlying ModelCache class to keep 3) BaseModelType -- an enum indicating the stable diffusion base model, one of:
frequently-used models in RAM and move them into GPU as needed for BaseModelType.StableDiffusion1
generation operations. The optional `max_cache_size` argument BaseModelType.StableDiffusion2
indicates the maximum size the cache can grow to, in gigabytes. The
underlying ModelCache object can be accessed using the manager's "cache"
attribute.
Because the model manager can return multiple different types of 4) SubModelType (optional) -- an enum that refers to one of the submodels contained
models, you may wish to add additional type checking on the class within the main model. Values are:
of model returned. To do this, provide the option `model_type`
parameter:
model_info = manager.get_model( SubModelType.UNet
'clip-tokenizer', SubModelType.TextEncoder
model_type=SubModelType.Tokenizer SubModelType.Tokenizer
SubModelType.Scheduler
SubModelType.SafetyChecker
To fetch a model, use `manager.get_model()`. This takes the symbolic
name of the model, the ModelType, the BaseModelType and the
SubModelType. The latter is required for ModelType.Main.
get_model() will return a ModelInfo object that can then be used in
context to retrieve the model and move it into GPU VRAM (on GPU
systems).
A typical example is:
sd1_5 = mgr.get_model('stable-diffusion-v1-5',
model_type=ModelType.Main,
base_model=BaseModelType.StableDiffusion1,
submodel_type=SubModelType.Unet)
with sd1_5 as unet:
run_some_inference(unet)
The ModelInfo object provides a number of useful fields describing the
model, including:
name -- symbolic name of the model
base_model -- base model (BaseModelType)
type -- model type (ModelType)
location -- path to the model file
precision -- torch precision of the model
hash -- unique sha256 checksum for this model
SUBMODELS:
When fetching a main model, you must specify the submodel. Retrieval
of full pipelines is not supported.
vae_info = mgr.get_model('stable-diffusion-1.5',
model_type = ModelType.Main,
base_model = BaseModelType.StableDiffusion1,
submodel_type = SubModelType.Vae
) )
with vae_info as vae:
do_something(vae)
This will raise an InvalidModelError if the format defined in the This rule does not apply to controlnets, embeddings, loras and standalone
config file doesn't match the requested model type. VAEs, which do not have submodels.
LISTING MODELS
The model_names() method will return a list of Tuples describing each
model it knows about:
>> mgr.model_names()
[
('stable-diffusion-1.5', <BaseModelType.StableDiffusion1: 'sd-1'>, <ModelType.Main: 'main'>),
('stable-diffusion-2.1', <BaseModelType.StableDiffusion2: 'sd-2'>, <ModelType.Main: 'main'>),
('inpaint', <BaseModelType.StableDiffusion1: 'sd-1'>, <ModelType.ControlNet: 'controlnet'>)
('Ink scenery', <BaseModelType.StableDiffusion1: 'sd-1'>, <ModelType.Lora: 'lora'>)
...
]
The tuple is in the correct order to pass to get_model():
for m in mgr.model_names():
info = get_model(*m)
In contrast, the list_models() method returns a list of dicts, each
providing information about a model defined in models.yaml. For example:
>>> models = mgr.list_models()
>>> json.dumps(models[0])
{"path": "/home/lstein/invokeai-main/models/sd-1/controlnet/canny",
"model_format": "diffusers",
"name": "canny",
"base_model": "sd-1",
"type": "controlnet"
}
You can filter by model type and base model as shown here:
controlnets = mgr.list_models(model_type=ModelType.ControlNet,
base_model=BaseModelType.StableDiffusion1)
for c in controlnets:
name = c['name']
format = c['model_format']
path = c['path']
type = c['type']
# etc
ADDING AND REMOVING MODELS
At startup time, the `models` directory will be scanned for
checkpoints, diffusers pipelines, controlnets, LoRAs and TI
embeddings. New entries will be added to the model manager and defunct
ones removed. Anything that is a main model (ModelType.Main) will be
added to models.yaml. For scanning to succeed, files need to be in
their proper places. For example, a controlnet folder built on the
stable diffusion 2 base, will need to be placed in
`models/sd-2/controlnet`.
Layout of the `models` directory:
models
sd-1
   controlnet
   lora
   main
   embedding
sd-2
   controlnet
   lora
   main
embedding
core
face_reconstruction
codeformer
gfpgan
sd-conversion
clip-vit-large-patch14 - tokenizer, text_encoder subdirs
stable-diffusion-2 - tokenizer, text_encoder subdirs
stable-diffusion-safety-checker
upscaling
esrgan
class ConfigMeta(BaseModel):Loras, textual_inversion and controlnet models are not listed
explicitly in models.yaml, but are added to the in-memory data
structure at initialization time by scanning the models directory. The
in-memory data structure can be resynchronized by calling
`manager.scan_models_directory()`.
Files and folders placed inside the `autoimport` paths (paths
defined in `invokeai.yaml`) will also be scanned for new models at
initialization time and added to `models.yaml`. Files will not be
moved from this location but preserved in-place. These directories
are:
configuration default description
------------- ------- -----------
autoimport_dir autoimport/main main models
lora_dir autoimport/lora LoRA/LyCORIS models
embedding_dir autoimport/embedding TI embeddings
controlnet_dir autoimport/controlnet ControlNet models
In actuality, models located in any of these directories are scanned
to determine their type, so it isn't strictly necessary to organize
the different types in this way. This entry in `invokeai.yaml` will
recursively scan all subdirectories within `autoimport`, scan models
files it finds, and import them if recognized.
Paths:
autoimport_dir: autoimport
A model can be manually added using `add_model()` using the model's
name, base model, type and a dict of model attributes. See
`invokeai/backend/model_management/models` for the attributes required
by each model type.
A model can be deleted using `del_model()`, providing the same
identifying information as `get_model()`
The `heuristic_import()` method will take a set of strings
corresponding to local paths, remote URLs, and repo_ids, probe the
object to determine what type of model it is (if any), and import new
models into the manager. If passed a directory, it will recursively
scan it for models to import. The return value is a set of the models
successfully added.
MODELS.YAML MODELS.YAML
@ -56,93 +212,18 @@ The general format of a models.yaml section is:
type-of-model/name-of-model: type-of-model/name-of-model:
path: /path/to/local/file/or/directory path: /path/to/local/file/or/directory
description: a description description: a description
format: folder|ckpt|safetensors|pt format: diffusers|checkpoint
base: SD-1|SD-2 variant: normal|inpaint|depth
subfolder: subfolder-name
The type of model is given in the stanza key, and is one of The type of model is given in the stanza key, and is one of
{diffusers, ckpt, vae, text_encoder, tokenizer, unet, scheduler, {main, vae, lora, controlnet, textual}
safety_checker, feature_extractor, lora, textual_inversion,
controlnet}, and correspond to items in the SubModelType enum defined
in model_cache.py
The format indicates whether the model is organized as a folder with The format indicates whether the model is organized as a diffusers
model subdirectories, or is contained in a single checkpoint or folder with model subdirectories, or is contained in a single
safetensors file. checkpoint or safetensors file.
One, but not both, of repo_id and path are provided. repo_id is the The path points to a file or directory on disk. If a relative path,
HuggingFace repository ID of the model, and path points to the file or the root is the InvokeAI ROOTDIR.
directory on disk.
If subfolder is provided, then the model exists in a subdirectory of
the main model. These are usually named after the model type, such as
"unet".
This example summarizes the two ways of getting a non-diffuser model:
text_encoder/clip-test-1:
format: folder
path: /path/to/folder
description: Returns standalone CLIPTextModel
text_encoder/clip-test-2:
format: folder
repo_id: /path/to/folder
subfolder: text_encoder
description: Returns the text_encoder in the subfolder of the diffusers model (just the encoder in RAM)
SUBMODELS:
It is also possible to fetch an isolated submodel from a diffusers
model. Use the `submodel` parameter to select which part:
vae = manager.get_model('stable-diffusion-1.5',submodel=SubModelType.Vae)
with vae.context as my_vae:
print(type(my_vae))
# "AutoencoderKL"
DIRECTORY_SCANNING:
Loras, textual_inversion and controlnet models are usually not listed
explicitly in models.yaml, but are added to the in-memory data
structure at initialization time by scanning the models directory. The
in-memory data structure can be resynchronized by calling
`manager.scan_models_directory`.
DISAMBIGUATION:
You may wish to use the same name for a related family of models. To
do this, disambiguate the stanza key with the model and and format
separated by "/". Example:
tokenizer/clip-large:
format: tokenizer
path: /path/to/folder
description: Returns standalone tokenizer
text_encoder/clip-large:
format: text_encoder
path: /path/to/folder
description: Returns standalone text encoder
You can now use the `model_type` argument to indicate which model you
want:
tokenizer = mgr.get('clip-large',model_type=SubModelType.Tokenizer)
encoder = mgr.get('clip-large',model_type=SubModelType.TextEncoder)
OTHER FUNCTIONS:
Other methods provided by ModelManager support importing, editing,
converting and deleting models.
IMPORTANT CHANGES AND LIMITATIONS SINCE 2.3:
1. Only local paths are supported. Repo_ids are no longer accepted. This
simplifies the logic.
2. VAEs can't be swapped in and out at load time. They must be baked
into the model when downloaded or converted.
""" """
from __future__ import annotations from __future__ import annotations
@ -151,13 +232,11 @@ import os
import hashlib import hashlib
import textwrap import textwrap
from dataclasses import dataclass from dataclasses import dataclass
from packaging import version
from pathlib import Path from pathlib import Path
from typing import Dict, Optional, List, Tuple, Union, types from typing import Optional, List, Tuple, Union, Set, Callable, types
from shutil import rmtree from shutil import rmtree
import torch import torch
from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig from omegaconf.dictconfig import DictConfig
@ -165,9 +244,13 @@ from pydantic import BaseModel
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import CUDA_DEVICE, download_with_resume from invokeai.backend.util import CUDA_DEVICE, Chdir
from .model_cache import ModelCache, ModelLocker from .model_cache import ModelCache, ModelLocker
from .models import BaseModelType, ModelType, SubModelType, ModelError, MODEL_CLASSES from .models import (
BaseModelType, ModelType, SubModelType,
ModelError, SchedulerPredictionType, MODEL_CLASSES,
ModelConfigBase,
)
# We are only starting to number the config file with release 3. # We are only starting to number the config file with release 3.
# The config file version doesn't have to start at release version, but it will help # The config file version doesn't have to start at release version, but it will help
@ -183,7 +266,6 @@ class ModelInfo():
hash: str hash: str
location: Union[Path, str] location: Union[Path, str]
precision: torch.dtype precision: torch.dtype
revision: str = None
_cache: ModelCache = None _cache: ModelCache = None
def __enter__(self): def __enter__(self):
@ -199,31 +281,6 @@ class InvalidModelError(Exception):
MAX_CACHE_SIZE = 6.0 # GB MAX_CACHE_SIZE = 6.0 # GB
# layout of the models directory:
# models
# ├── sd-1
# │   ├── controlnet
# │   ├── lora
# │   ├── pipeline
# │   └── textual_inversion
# ├── sd-2
# │   ├── controlnet
# │   ├── lora
# │   ├── pipeline
# │ └── textual_inversion
# └── core
# ├── face_reconstruction
# │ ├── codeformer
# │ └── gfpgan
# ├── sd-conversion
# │ ├── clip-vit-large-patch14 - tokenizer, text_encoder subdirs
# │ ├── stable-diffusion-2 - tokenizer, text_encoder subdirs
# │ └── stable-diffusion-safety-checker
# └── upscaling
# └─── esrgan
class ConfigMeta(BaseModel): class ConfigMeta(BaseModel):
version: str version: str
@ -271,7 +328,7 @@ class ModelManager(object):
self.models[model_key] = model_class.create_config(**model_config) self.models[model_key] = model_class.create_config(**model_config)
# check config version number and update on disk/RAM if necessary # check config version number and update on disk/RAM if necessary
self.globals = InvokeAIAppConfig.get_config() self.app_config = InvokeAIAppConfig.get_config()
self.logger = logger self.logger = logger
self.cache = ModelCache( self.cache = ModelCache(
max_cache_size=max_cache_size, max_cache_size=max_cache_size,
@ -307,7 +364,8 @@ class ModelManager(object):
) -> str: ) -> str:
return f"{base_model}/{model_type}/{model_name}" return f"{base_model}/{model_type}/{model_name}"
def parse_key(self, model_key: str) -> Tuple[str, BaseModelType, ModelType]: @classmethod
def parse_key(cls, model_key: str) -> Tuple[str, BaseModelType, ModelType]:
base_model_str, model_type_str, model_name = model_key.split('/', 2) base_model_str, model_type_str, model_name = model_key.split('/', 2)
try: try:
model_type = ModelType(model_type_str) model_type = ModelType(model_type_str)
@ -321,69 +379,37 @@ class ModelManager(object):
return (model_name, base_model, model_type) return (model_name, base_model, model_type)
def _get_model_cache_path(self, model_path):
return self.app_config.models_path / ".cache" / hashlib.md5(str(model_path).encode()).hexdigest()
def get_model( def get_model(
self, self,
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
submodel_type: Optional[SubModelType] = None submodel_type: Optional[SubModelType] = None
): )->ModelInfo:
"""Given a model named identified in models.yaml, return """Given a model named identified in models.yaml, return
an ModelInfo object describing it. an ModelInfo object describing it.
:param model_name: symbolic name of the model in models.yaml :param model_name: symbolic name of the model in models.yaml
:param model_type: ModelType enum indicating the type of model to return :param model_type: ModelType enum indicating the type of model to return
:param base_model: BaseModelType enum indicating the base model used by this model
:param submode_typel: an ModelType enum indicating the portion of :param submode_typel: an ModelType enum indicating the portion of
the model to retrieve (e.g. ModelType.Vae) the model to retrieve (e.g. ModelType.Vae)
If not provided, the model_type will be read from the `format` field
of the corresponding stanza. If provided, the model_type will be used
to disambiguate stanzas in the configuration file. The default is to
assume a diffusers pipeline. The behavior is illustrated here:
[models.yaml]
diffusers/test1:
repo_id: foo/bar
description: Typical diffusers pipeline
lora/test1:
repo_id: /tmp/loras/test1.safetensors
description: Typical lora file
test1_pipeline = mgr.get_model('test1')
# returns a StableDiffusionGeneratorPipeline
test1_vae1 = mgr.get_model('test1', submodel=ModelType.Vae)
# returns the VAE part of a diffusers model as an AutoencoderKL
test1_vae2 = mgr.get_model('test1', model_type=ModelType.Diffusers, submodel=ModelType.Vae)
# does the same thing as the previous statement. Note that model_type
# is for the parent model, and submodel is for the part
test1_lora = mgr.get_model('test1', model_type=ModelType.Lora)
# returns a LoRA embed (as a 'dict' of tensors)
test1_encoder = mgr.get_modelI('test1', model_type=ModelType.TextEncoder)
# raises an InvalidModelError
""" """
model_class = MODEL_CLASSES[base_model][model_type] model_class = MODEL_CLASSES[base_model][model_type]
model_key = self.create_key(model_name, base_model, model_type) model_key = self.create_key(model_name, base_model, model_type)
# if model not found try to find it (maybe file just pasted) # if model not found try to find it (maybe file just pasted)
if model_key not in self.models: if model_key not in self.models:
# TODO: find by mask or try rescan? self.scan_models_directory(base_model=base_model, model_type=model_type)
path_mask = f"/models/{base_model}/{model_type}/{model_name}*" if model_key not in self.models:
if False: # model_path = next(find_by_mask(path_mask)):
model_path = None # TODO:
model_config = model_class.probe_config(model_path)
self.models[model_key] = model_config
else:
raise Exception(f"Model not found - {model_key}") raise Exception(f"Model not found - {model_key}")
# if it known model check that target path exists (if manualy deleted) model_config = self.models[model_key]
else: model_path = self.app_config.root_path / model_config.path
# logic repeated twice(in rescan too) any way to optimize?
if not os.path.exists(self.models[model_key].path): if not model_path.exists():
if model_class.save_to_config: if model_class.save_to_config:
self.models[model_key].error = ModelError.NotFound self.models[model_key].error = ModelError.NotFound
raise Exception(f"Files for model \"{model_key}\" not found") raise Exception(f"Files for model \"{model_key}\" not found")
@ -392,16 +418,6 @@ class ModelManager(object):
self.models.pop(model_key, None) self.models.pop(model_key, None)
raise Exception(f"Model not found - {model_key}") raise Exception(f"Model not found - {model_key}")
# reset model errors?
model_config = self.models[model_key]
# /models/{base_model}/{model_type}/{name}.ckpt or .safentesors
# /models/{base_model}/{model_type}/{name}/
model_path = model_config.path
# vae/movq override # vae/movq override
# TODO: # TODO:
if submodel_type is not None and hasattr(model_config, submodel_type): if submodel_type is not None and hasattr(model_config, submodel_type):
@ -414,10 +430,10 @@ class ModelManager(object):
# TODO: path # TODO: path
# TODO: is it accurate to use path as id # TODO: is it accurate to use path as id
dst_convert_path = self.globals.models_dir / ".cache" / hashlib.md5(model_path.encode()).hexdigest() dst_convert_path = self._get_model_cache_path(model_path)
model_path = model_class.convert_if_required( model_path = model_class.convert_if_required(
base_model=base_model, base_model=base_model,
model_path=model_path, model_path=str(model_path), # TODO: refactor str/Path types logic
output_path=dst_convert_path, output_path=dst_convert_path,
config=model_config, config=model_config,
) )
@ -476,11 +492,6 @@ class ModelManager(object):
) -> list[dict]: ) -> list[dict]:
""" """
Return a list of models. Return a list of models.
Please use model_manager.models() to get all the model names,
model_manager.model_info('model-name') to get the stanza for the model
named 'model-name', and model_manager.config to get the full OmegaConf
object derived from models.yaml
""" """
models = [] models = []
@ -507,7 +518,7 @@ class ModelManager(object):
def print_models(self) -> None: def print_models(self) -> None:
""" """
Print a table of models, their descriptions Print a table of models and their descriptions. This needs to be redone
""" """
# TODO: redo # TODO: redo
for model_type, model_dict in self.list_models().items(): for model_type, model_dict in self.list_models().items():
@ -515,7 +526,7 @@ class ModelManager(object):
line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}' line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
print(line) print(line)
# TODO: test when ui implemented # Tested - LS
def del_model( def del_model(
self, self,
model_name: str, model_name: str,
@ -525,7 +536,6 @@ class ModelManager(object):
""" """
Delete the named model. Delete the named model.
""" """
raise Exception("TODO: del_model") # TODO: redo
model_key = self.create_key(model_name, base_model, model_type) model_key = self.create_key(model_name, base_model, model_type)
model_cfg = self.models.pop(model_key, None) model_cfg = self.models.pop(model_key, None)
@ -541,14 +551,18 @@ class ModelManager(object):
self.cache.uncache_model(cache_id) self.cache.uncache_model(cache_id)
# if model inside invoke models folder - delete files # if model inside invoke models folder - delete files
if model_cfg.path.startswith("models/") or model_cfg.path.startswith("models\\"): model_path = self.app_config.root_path / model_cfg.path
model_path = self.globals.root_dir / model_cfg.path cache_path = self._get_model_cache_path(model_path)
if model_path.isdir(): if cache_path.exists():
shutil.rmtree(str(model_path)) rmtree(str(cache_path))
if model_path.is_relative_to(self.app_config.models_path):
if model_path.is_dir():
rmtree(str(model_path))
else: else:
model_path.unlink() model_path.unlink()
# TODO: test when ui implemented # LS: tested
def add_model( def add_model(
self, self,
model_name: str, model_name: str,
@ -569,18 +583,30 @@ class ModelManager(object):
model_config = model_class.create_config(**model_attributes) model_config = model_class.create_config(**model_attributes)
model_key = self.create_key(model_name, base_model, model_type) model_key = self.create_key(model_name, base_model, model_type)
assert ( if model_key in self.models and not clobber:
clobber or model_key not in self.models raise Exception(f'Attempt to overwrite existing model definition "{model_key}"')
), f'attempt to overwrite existing model definition "{model_key}"'
self.models[model_key] = model_config old_model = self.models.pop(model_key, None)
if old_model is not None:
# TODO: if path changed and old_model.path inside models folder should we delete this too?
if clobber and model_key in self.cache_keys: # remove conversion cache as config changed
old_model_path = self.app_config.root_path / old_model.path
old_model_cache = self._get_model_cache_path(old_model_path)
if old_model_cache.exists():
if old_model_cache.is_dir():
rmtree(str(old_model_cache))
else:
old_model_cache.unlink()
# remove in-memory cache
# note: it not garantie to release memory(model can has other references) # note: it not garantie to release memory(model can has other references)
cache_ids = self.cache_keys.pop(model_key, []) cache_ids = self.cache_keys.pop(model_key, [])
for cache_id in cache_ids: for cache_id in cache_ids:
self.cache.uncache_model(cache_id) self.cache.uncache_model(cache_id)
self.models[model_key] = model_config
def search_models(self, search_folder): def search_models(self, search_folder):
self.logger.info(f"Finding Models In: {search_folder}") self.logger.info(f"Finding Models In: {search_folder}")
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt") models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
@ -621,7 +647,7 @@ class ModelManager(object):
yaml_str = OmegaConf.to_yaml(data_to_save) yaml_str = OmegaConf.to_yaml(data_to_save)
config_file_path = conf_file or self.config_path config_file_path = conf_file or self.config_path
assert config_file_path is not None,'no config file path to write to' assert config_file_path is not None,'no config file path to write to'
config_file_path = self.globals.root_dir / config_file_path config_file_path = self.app_config.root_path / config_file_path
tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp") tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp")
with open(tmpfile, "w", encoding="utf-8") as outfile: with open(tmpfile, "w", encoding="utf-8") as outfile:
outfile.write(self.preamble()) outfile.write(self.preamble())
@ -644,15 +670,20 @@ class ModelManager(object):
""" """
) )
def scan_models_directory(self): def scan_models_directory(
self,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
):
loaded_files = set() loaded_files = set()
new_models_found = False new_models_found = False
with Chdir(self.app_config.root_path):
for model_key, model_config in list(self.models.items()): for model_key, model_config in list(self.models.items()):
model_name, base_model, model_type = self.parse_key(model_key) model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
model_path = str(self.globals.root / model_config.path) model_path = self.app_config.root_path / model_config.path
if not os.path.exists(model_path): if not model_path.exists():
model_class = MODEL_CLASSES[base_model][model_type] model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
if model_class.save_to_config: if model_class.save_to_config:
model_config.error = ModelError.NotFound model_config.error = ModelError.NotFound
else: else:
@ -660,26 +691,129 @@ class ModelManager(object):
else: else:
loaded_files.add(model_path) loaded_files.add(model_path)
for base_model in BaseModelType: for cur_base_model in BaseModelType:
for model_type in ModelType: if base_model is not None and cur_base_model != base_model:
model_class = MODEL_CLASSES[base_model][model_type] continue
models_dir = os.path.join(self.globals.models_path, base_model, model_type)
if not os.path.exists(models_dir): for cur_model_type in ModelType:
if model_type is not None and cur_model_type != model_type:
continue
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
models_dir = self.app_config.models_path / cur_base_model.value / cur_model_type.value
if not models_dir.exists():
continue # TODO: or create all folders? continue # TODO: or create all folders?
for entry_name in os.listdir(models_dir): for model_path in models_dir.iterdir():
model_path = os.path.join(models_dir, entry_name)
if model_path not in loaded_files: # TODO: check if model_path not in loaded_files: # TODO: check
model_name = Path(model_path).stem model_name = model_path.name if model_path.is_dir() else model_path.stem
model_key = self.create_key(model_name, base_model, model_type) model_key = self.create_key(model_name, cur_base_model, cur_model_type)
if model_key in self.models: if model_key in self.models:
raise Exception(f"Model with key {model_key} added twice") raise Exception(f"Model with key {model_key} added twice")
model_config: ModelConfigBase = model_class.probe_config(model_path) if model_path.is_relative_to(self.app_config.root_path):
model_path = model_path.relative_to(self.app_config.root_path)
try:
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
self.models[model_key] = model_config self.models[model_key] = model_config
new_models_found = True new_models_found = True
except NotImplementedError as e:
self.logger.warning(e)
if new_models_found: imported_models = self.autoimport()
if (new_models_found or imported_models) and self.config_path:
self.commit() self.commit()
def autoimport(self)->set[Path]:
'''
Scan the autoimport directory (if defined) and import new models, delete defunct models.
'''
# avoid circular import
from invokeai.backend.install.model_install_backend import ModelInstall
from invokeai.frontend.install.model_install import ask_user_for_prediction_type
installer = ModelInstall(config = self.app_config,
model_manager = self,
prediction_type_helper = ask_user_for_prediction_type,
)
installed = set()
scanned_dirs = set()
config = self.app_config
known_paths = {(self.app_config.root_path / x['path']) for x in self.list_models()}
for autodir in [config.autoimport_dir,
config.lora_dir,
config.embedding_dir,
config.controlnet_dir]:
if autodir is None:
continue
self.logger.info(f'Scanning {autodir} for models to import')
autodir = self.app_config.root_path / autodir
if not autodir.exists():
continue
items_scanned = 0
new_models_found = set()
for root, dirs, files in os.walk(autodir):
items_scanned += len(dirs) + len(files)
for d in dirs:
path = Path(root) / d
if path in known_paths or path.parent in scanned_dirs:
scanned_dirs.add(path)
continue
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
new_models_found.update(installer.heuristic_install(path))
scanned_dirs.add(path)
for f in files:
path = Path(root) / f
if path in known_paths or path.parent in scanned_dirs:
continue
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
new_models_found.update(installer.heuristic_install(path))
self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
installed.update(new_models_found)
return installed
def heuristic_import(self,
items_to_import: Set[str],
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
)->Set[str]:
'''Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
The prediction type helper is necessary to distinguish between
models based on Stable Diffusion 2 Base (requiring
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
(requiring SchedulerPredictionType.VPrediction). It is
generally impossible to do this programmatically, so the
prediction_type_helper usually asks the user to choose.
'''
# avoid circular import here
from invokeai.backend.install.model_install_backend import ModelInstall
successfully_installed = set()
installer = ModelInstall(config = self.app_config,
prediction_type_helper = prediction_type_helper,
model_manager = self)
for thing in items_to_import:
try:
installed = installer.heuristic_install(thing)
successfully_installed.update(installed)
except Exception as e:
self.logger.warning(f'{thing} could not be imported: {str(e)}')
self.commit()
return successfully_installed

View File

@ -1,27 +1,28 @@
import json import json
import traceback
import torch import torch
import safetensors.torch import safetensors.torch
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum
from diffusers import ModelMixin, ConfigMixin, StableDiffusionPipeline, AutoencoderKL, ControlNetModel from diffusers import ModelMixin, ConfigMixin
from pathlib import Path from pathlib import Path
from typing import Callable, Literal, Union, Dict from typing import Callable, Literal, Union, Dict
from picklescan.scanner import scan_file_path from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger from .models import (
from .models import BaseModelType, ModelType, ModelVariantType, SchedulerPredictionType, SilenceWarnings BaseModelType, ModelType, ModelVariantType,
SchedulerPredictionType, SilenceWarnings,
)
from .models.base import read_checkpoint_meta
@dataclass @dataclass
class ModelVariantInfo(object): class ModelProbeInfo(object):
model_type: ModelType model_type: ModelType
base_type: BaseModelType base_type: BaseModelType
variant_type: ModelVariantType variant_type: ModelVariantType
prediction_type: SchedulerPredictionType prediction_type: SchedulerPredictionType
upcast_attention: bool upcast_attention: bool
format: Literal['folder','checkpoint'] format: Literal['diffusers','checkpoint', 'lycoris']
image_size: int image_size: int
class ProbeBase(object): class ProbeBase(object):
@ -31,19 +32,19 @@ class ProbeBase(object):
class ModelProbe(object): class ModelProbe(object):
PROBES = { PROBES = {
'folder': { }, 'diffusers': { },
'checkpoint': { }, 'checkpoint': { },
} }
CLASS2TYPE = { CLASS2TYPE = {
'StableDiffusionPipeline' : ModelType.Pipeline, 'StableDiffusionPipeline' : ModelType.Main,
'AutoencoderKL' : ModelType.Vae, 'AutoencoderKL' : ModelType.Vae,
'ControlNetModel' : ModelType.ControlNet, 'ControlNetModel' : ModelType.ControlNet,
} }
@classmethod @classmethod
def register_probe(cls, def register_probe(cls,
format: Literal['folder','file'], format: Literal['diffusers','checkpoint'],
model_type: ModelType, model_type: ModelType,
probe_class: ProbeBase): probe_class: ProbeBase):
cls.PROBES[format][model_type] = probe_class cls.PROBES[format][model_type] = probe_class
@ -51,8 +52,8 @@ class ModelProbe(object):
@classmethod @classmethod
def heuristic_probe(cls, def heuristic_probe(cls,
model: Union[Dict, ModelMixin, Path], model: Union[Dict, ModelMixin, Path],
prediction_type_helper: Callable[[Path],BaseModelType]=None, prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
)->ModelVariantInfo: )->ModelProbeInfo:
if isinstance(model,Path): if isinstance(model,Path):
return cls.probe(model_path=model,prediction_type_helper=prediction_type_helper) return cls.probe(model_path=model,prediction_type_helper=prediction_type_helper)
elif isinstance(model,(dict,ModelMixin,ConfigMixin)): elif isinstance(model,(dict,ModelMixin,ConfigMixin)):
@ -64,7 +65,7 @@ class ModelProbe(object):
def probe(cls, def probe(cls,
model_path: Path, model_path: Path,
model: Union[Dict, ModelMixin] = None, model: Union[Dict, ModelMixin] = None,
prediction_type_helper: Callable[[Path],BaseModelType] = None)->ModelVariantInfo: prediction_type_helper: Callable[[Path],SchedulerPredictionType] = None)->ModelProbeInfo:
''' '''
Probe the model at model_path and return sufficient information about it Probe the model at model_path and return sufficient information about it
to place it somewhere in the models directory hierarchy. If the model is to place it somewhere in the models directory hierarchy. If the model is
@ -74,23 +75,24 @@ class ModelProbe(object):
between V2-Base and V2-768 SD models. between V2-Base and V2-768 SD models.
''' '''
if model_path: if model_path:
format = 'folder' if model_path.is_dir() else 'checkpoint' format_type = 'diffusers' if model_path.is_dir() else 'checkpoint'
else: else:
format = 'folder' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint' format_type = 'diffusers' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'
model_info = None model_info = None
try: try:
model_type = cls.get_model_type_from_folder(model_path, model) \ model_type = cls.get_model_type_from_folder(model_path, model) \
if format == 'folder' \ if format_type == 'diffusers' \
else cls.get_model_type_from_checkpoint(model_path, model) else cls.get_model_type_from_checkpoint(model_path, model)
probe_class = cls.PROBES[format].get(model_type) probe_class = cls.PROBES[format_type].get(model_type)
if not probe_class: if not probe_class:
return None return None
probe = probe_class(model_path, model, prediction_type_helper) probe = probe_class(model_path, model, prediction_type_helper)
base_type = probe.get_base_type() base_type = probe.get_base_type()
variant_type = probe.get_variant_type() variant_type = probe.get_variant_type()
prediction_type = probe.get_scheduler_prediction_type() prediction_type = probe.get_scheduler_prediction_type()
model_info = ModelVariantInfo( format = probe.get_format()
model_info = ModelProbeInfo(
model_type = model_type, model_type = model_type,
base_type = base_type, base_type = base_type,
variant_type = variant_type, variant_type = variant_type,
@ -102,32 +104,40 @@ class ModelProbe(object):
and prediction_type==SchedulerPredictionType.VPrediction \ and prediction_type==SchedulerPredictionType.VPrediction \
) else 512, ) else 512,
) )
except Exception as e: except Exception:
return None return None
return model_info return model_info
@classmethod @classmethod
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict)->ModelType: def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType:
if model_path.suffix not in ('.bin','.pt','.ckpt','.safetensors'): if model_path.suffix not in ('.bin','.pt','.ckpt','.safetensors','.pth'):
return None return None
if model_path.name=='learned_embeds.bin':
if model_path.name == "learned_embeds.bin":
return ModelType.TextualInversion return ModelType.TextualInversion
checkpoint = checkpoint or cls._scan_and_load_checkpoint(model_path)
state_dict = checkpoint.get("state_dict") or checkpoint ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
if any([x.startswith("model.diffusion_model") for x in state_dict.keys()]): ckpt = ckpt.get("state_dict", ckpt)
return ModelType.Pipeline
if any([x.startswith("encoder.conv_in") for x in state_dict.keys()]): for key in ckpt.keys():
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
return ModelType.Main
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
return ModelType.Vae return ModelType.Vae
if "string_to_token" in state_dict or "emb_params" in state_dict: elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
return ModelType.TextualInversion
if any([x.startswith("lora") for x in state_dict.keys()]):
return ModelType.Lora return ModelType.Lora
if any([x.startswith("control_model") for x in state_dict.keys()]): elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
return ModelType.ControlNet return ModelType.ControlNet
if any([x.startswith("input_blocks") for x in state_dict.keys()]): elif key in {"emb_params", "string_to_param"}:
return ModelType.ControlNet return ModelType.TextualInversion
return None # give up
else:
# diffusers-ti
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
return ModelType.TextualInversion
raise ValueError("Unable to determine model type")
@classmethod @classmethod
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType: def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
@ -192,11 +202,14 @@ class ProbeBase(object):
def get_scheduler_prediction_type(self)->SchedulerPredictionType: def get_scheduler_prediction_type(self)->SchedulerPredictionType:
pass pass
def get_format(self)->str:
pass
class CheckpointProbeBase(ProbeBase): class CheckpointProbeBase(ProbeBase):
def __init__(self, def __init__(self,
checkpoint_path: Path, checkpoint_path: Path,
checkpoint: dict, checkpoint: dict,
helper: Callable[[Path],BaseModelType] = None helper: Callable[[Path],SchedulerPredictionType] = None
)->BaseModelType: )->BaseModelType:
self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path) self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
self.checkpoint_path = checkpoint_path self.checkpoint_path = checkpoint_path
@ -205,9 +218,12 @@ class CheckpointProbeBase(ProbeBase):
def get_base_type(self)->BaseModelType: def get_base_type(self)->BaseModelType:
pass pass
def get_format(self)->str:
return 'checkpoint'
def get_variant_type(self)-> ModelVariantType: def get_variant_type(self)-> ModelVariantType:
model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path,self.checkpoint) model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path,self.checkpoint)
if model_type != ModelType.Pipeline: if model_type != ModelType.Main:
return ModelVariantType.Normal return ModelVariantType.Normal
state_dict = self.checkpoint.get('state_dict') or self.checkpoint state_dict = self.checkpoint.get('state_dict') or self.checkpoint
in_channels = state_dict[ in_channels = state_dict[
@ -246,7 +262,8 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
return SchedulerPredictionType.Epsilon return SchedulerPredictionType.Epsilon
elif checkpoint["global_step"] == 110000: elif checkpoint["global_step"] == 110000:
return SchedulerPredictionType.VPrediction return SchedulerPredictionType.VPrediction
if self.checkpoint_path and self.helper: if self.checkpoint_path and self.helper \
and not self.checkpoint_path.with_suffix('.yaml').exists(): # if a .yaml config file exists, then this step not needed
return self.helper(self.checkpoint_path) return self.helper(self.checkpoint_path)
else: else:
return None return None
@ -257,6 +274,9 @@ class VaeCheckpointProbe(CheckpointProbeBase):
return BaseModelType.StableDiffusion1 return BaseModelType.StableDiffusion1
class LoRACheckpointProbe(CheckpointProbeBase): class LoRACheckpointProbe(CheckpointProbeBase):
def get_format(self)->str:
return 'lycoris'
def get_base_type(self)->BaseModelType: def get_base_type(self)->BaseModelType:
checkpoint = self.checkpoint checkpoint = self.checkpoint
key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight" key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
@ -276,6 +296,9 @@ class LoRACheckpointProbe(CheckpointProbeBase):
return None return None
class TextualInversionCheckpointProbe(CheckpointProbeBase): class TextualInversionCheckpointProbe(CheckpointProbeBase):
def get_format(self)->str:
return None
def get_base_type(self)->BaseModelType: def get_base_type(self)->BaseModelType:
checkpoint = self.checkpoint checkpoint = self.checkpoint
if 'string_to_token' in checkpoint: if 'string_to_token' in checkpoint:
@ -322,17 +345,16 @@ class FolderProbeBase(ProbeBase):
def get_variant_type(self)->ModelVariantType: def get_variant_type(self)->ModelVariantType:
return ModelVariantType.Normal return ModelVariantType.Normal
def get_format(self)->str:
return 'diffusers'
class PipelineFolderProbe(FolderProbeBase): class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType: def get_base_type(self)->BaseModelType:
if self.model: if self.model:
unet_conf = self.model.unet.config unet_conf = self.model.unet.config
scheduler_conf = self.model.scheduler.config
else: else:
with open(self.folder_path / 'unet' / 'config.json','r') as file: with open(self.folder_path / 'unet' / 'config.json','r') as file:
unet_conf = json.load(file) unet_conf = json.load(file)
with open(self.folder_path / 'scheduler' / 'scheduler_config.json','r') as file:
scheduler_conf = json.load(file)
if unet_conf['cross_attention_dim'] == 768: if unet_conf['cross_attention_dim'] == 768:
return BaseModelType.StableDiffusion1 return BaseModelType.StableDiffusion1
elif unet_conf['cross_attention_dim'] == 1024: elif unet_conf['cross_attention_dim'] == 1024:
@ -381,6 +403,9 @@ class VaeFolderProbe(FolderProbeBase):
return BaseModelType.StableDiffusion1 return BaseModelType.StableDiffusion1
class TextualInversionFolderProbe(FolderProbeBase): class TextualInversionFolderProbe(FolderProbeBase):
def get_format(self)->str:
return None
def get_base_type(self)->BaseModelType: def get_base_type(self)->BaseModelType:
path = self.folder_path / 'learned_embeds.bin' path = self.folder_path / 'learned_embeds.bin'
if not path.exists(): if not path.exists():
@ -401,16 +426,24 @@ class ControlNetFolderProbe(FolderProbeBase):
else BaseModelType.StableDiffusion2 else BaseModelType.StableDiffusion2
class LoRAFolderProbe(FolderProbeBase): class LoRAFolderProbe(FolderProbeBase):
# I've never seen one of these in the wild, so this is a noop def get_base_type(self)->BaseModelType:
pass model_file = None
for suffix in ['safetensors','bin']:
base_file = self.folder_path / f'pytorch_lora_weights.{suffix}'
if base_file.exists():
model_file = base_file
break
if not model_file:
raise Exception('Unknown LoRA format encountered')
return LoRACheckpointProbe(model_file,None).get_base_type()
############## register probe classes ###### ############## register probe classes ######
ModelProbe.register_probe('folder', ModelType.Pipeline, PipelineFolderProbe) ModelProbe.register_probe('diffusers', ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe('folder', ModelType.Vae, VaeFolderProbe) ModelProbe.register_probe('diffusers', ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe('folder', ModelType.Lora, LoRAFolderProbe) ModelProbe.register_probe('diffusers', ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe('folder', ModelType.TextualInversion, TextualInversionFolderProbe) ModelProbe.register_probe('diffusers', ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe('folder', ModelType.ControlNet, ControlNetFolderProbe) ModelProbe.register_probe('diffusers', ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe('checkpoint', ModelType.Pipeline, PipelineCheckpointProbe) ModelProbe.register_probe('checkpoint', ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe('checkpoint', ModelType.Vae, VaeCheckpointProbe) ModelProbe.register_probe('checkpoint', ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe('checkpoint', ModelType.Lora, LoRACheckpointProbe) ModelProbe.register_probe('checkpoint', ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe('checkpoint', ModelType.TextualInversion, TextualInversionCheckpointProbe) ModelProbe.register_probe('checkpoint', ModelType.TextualInversion, TextualInversionCheckpointProbe)

View File

@ -11,21 +11,21 @@ from .textual_inversion import TextualInversionModel
MODEL_CLASSES = { MODEL_CLASSES = {
BaseModelType.StableDiffusion1: { BaseModelType.StableDiffusion1: {
ModelType.Pipeline: StableDiffusion1Model, ModelType.Main: StableDiffusion1Model,
ModelType.Vae: VaeModel, ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel, ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel, ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel, ModelType.TextualInversion: TextualInversionModel,
}, },
BaseModelType.StableDiffusion2: { BaseModelType.StableDiffusion2: {
ModelType.Pipeline: StableDiffusion2Model, ModelType.Main: StableDiffusion2Model,
ModelType.Vae: VaeModel, ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel, ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel, ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel, ModelType.TextualInversion: TextualInversionModel,
}, },
#BaseModelType.Kandinsky2_1: { #BaseModelType.Kandinsky2_1: {
# ModelType.Pipeline: Kandinsky2_1Model, # ModelType.Main: Kandinsky2_1Model,
# ModelType.MoVQ: MoVQModel, # ModelType.MoVQ: MoVQModel,
# ModelType.Lora: LoRAModel, # ModelType.Lora: LoRAModel,
# ModelType.ControlNet: ControlNetModel, # ModelType.ControlNet: ControlNetModel,

View File

@ -1,9 +1,12 @@
import json
import os import os
import sys import sys
import typing import typing
import inspect import inspect
from enum import Enum from enum import Enum
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from pathlib import Path
from picklescan.scanner import scan_file_path
import torch import torch
import safetensors.torch import safetensors.torch
from diffusers import DiffusionPipeline, ConfigMixin from diffusers import DiffusionPipeline, ConfigMixin
@ -18,7 +21,7 @@ class BaseModelType(str, Enum):
#Kandinsky2_1 = "kandinsky-2.1" #Kandinsky2_1 = "kandinsky-2.1"
class ModelType(str, Enum): class ModelType(str, Enum):
Pipeline = "pipeline" Main = "main"
Vae = "vae" Vae = "vae"
Lora = "lora" Lora = "lora"
ControlNet = "controlnet" # used by model_probe ControlNet = "controlnet" # used by model_probe
@ -56,7 +59,6 @@ class ModelConfigBase(BaseModel):
class Config: class Config:
use_enum_values = True use_enum_values = True
class EmptyConfigLoader(ConfigMixin): class EmptyConfigLoader(ConfigMixin):
@classmethod @classmethod
def load_config(cls, *args, **kwargs): def load_config(cls, *args, **kwargs):
@ -124,7 +126,10 @@ class ModelBase(metaclass=ABCMeta):
if not isinstance(value, type) or not issubclass(value, ModelConfigBase): if not isinstance(value, type) or not issubclass(value, ModelConfigBase):
continue continue
if hasattr(inspect,'get_annotations'):
fields = inspect.get_annotations(value) fields = inspect.get_annotations(value)
else:
fields = value.__annotations__
try: try:
field = fields["model_format"] field = fields["model_format"]
except: except:
@ -383,15 +388,18 @@ def _fast_safetensors_reader(path: str):
return checkpoint return checkpoint
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
def read_checkpoint_meta(path: str): if str(path).endswith(".safetensors"):
if path.endswith(".safetensors"):
try: try:
checkpoint = _fast_safetensors_reader(path) checkpoint = _fast_safetensors_reader(path)
except: except:
# TODO: create issue for support "meta"? # TODO: create issue for support "meta"?
checkpoint = safetensors.torch.load_file(path, device="cpu") checkpoint = safetensors.torch.load_file(path, device="cpu")
else: else:
if scan:
scan_result = scan_file_path(path)
if scan_result.infected_files != 0:
raise Exception(f"The model file \"{path}\" is potentially infected by malware. Aborting import.")
checkpoint = torch.load(path, map_location=torch.device("meta")) checkpoint = torch.load(path, map_location=torch.device("meta"))
return checkpoint return checkpoint

View File

@ -34,17 +34,17 @@ class StableDiffusion1Model(DiffusersModel):
class CheckpointConfig(ModelConfigBase): class CheckpointConfig(ModelConfigBase):
model_format: Literal[StableDiffusion1ModelFormat.Checkpoint] model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
config: Optional[str] = Field(None) config: str
variant: ModelVariantType variant: ModelVariantType
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion1 assert base_model == BaseModelType.StableDiffusion1
assert model_type == ModelType.Pipeline assert model_type == ModelType.Main
super().__init__( super().__init__(
model_path=model_path, model_path=model_path,
base_model=BaseModelType.StableDiffusion1, base_model=BaseModelType.StableDiffusion1,
model_type=ModelType.Pipeline, model_type=ModelType.Main,
) )
@classmethod @classmethod
@ -69,7 +69,7 @@ class StableDiffusion1Model(DiffusersModel):
in_channels = unet_config['in_channels'] in_channels = unet_config['in_channels']
else: else:
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)") raise NotImplementedError(f"{path} is not a supported stable diffusion diffusers format")
else: else:
raise NotImplementedError(f"Unknown stable diffusion 1.* format: {model_format}") raise NotImplementedError(f"Unknown stable diffusion 1.* format: {model_format}")
@ -81,6 +81,8 @@ class StableDiffusion1Model(DiffusersModel):
else: else:
raise Exception("Unkown stable diffusion 1.* model format") raise Exception("Unkown stable diffusion 1.* model format")
if ckpt_config_path is None:
ckpt_config_path = _select_ckpt_config(BaseModelType.StableDiffusion1, variant)
return cls.create_config( return cls.create_config(
path=path, path=path,
@ -109,14 +111,12 @@ class StableDiffusion1Model(DiffusersModel):
config: ModelConfigBase, config: ModelConfigBase,
base_model: BaseModelType, base_model: BaseModelType,
) -> str: ) -> str:
assert model_path == config.path
if isinstance(config, cls.CheckpointConfig): if isinstance(config, cls.CheckpointConfig):
return _convert_ckpt_and_cache( return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion1, version=BaseModelType.StableDiffusion1,
model_config=config, model_config=config,
output_path=output_path, output_path=output_path,
) # TODO: args )
else: else:
return model_path return model_path
@ -131,25 +131,20 @@ class StableDiffusion2Model(DiffusersModel):
model_format: Literal[StableDiffusion2ModelFormat.Diffusers] model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
variant: ModelVariantType variant: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
class CheckpointConfig(ModelConfigBase): class CheckpointConfig(ModelConfigBase):
model_format: Literal[StableDiffusion2ModelFormat.Checkpoint] model_format: Literal[StableDiffusion2ModelFormat.Checkpoint]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
config: Optional[str] = Field(None) config: str
variant: ModelVariantType variant: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion2 assert base_model == BaseModelType.StableDiffusion2
assert model_type == ModelType.Pipeline assert model_type == ModelType.Main
super().__init__( super().__init__(
model_path=model_path, model_path=model_path,
base_model=BaseModelType.StableDiffusion2, base_model=BaseModelType.StableDiffusion2,
model_type=ModelType.Pipeline, model_type=ModelType.Main,
) )
@classmethod @classmethod
@ -188,13 +183,8 @@ class StableDiffusion2Model(DiffusersModel):
else: else:
raise Exception("Unkown stable diffusion 2.* model format") raise Exception("Unkown stable diffusion 2.* model format")
if variant == ModelVariantType.Normal: if ckpt_config_path is None:
prediction_type = SchedulerPredictionType.VPrediction ckpt_config_path = _select_ckpt_config(BaseModelType.StableDiffusion2, variant)
upcast_attention = True
else:
prediction_type = SchedulerPredictionType.Epsilon
upcast_attention = False
return cls.create_config( return cls.create_config(
path=path, path=path,
@ -202,8 +192,6 @@ class StableDiffusion2Model(DiffusersModel):
config=ckpt_config_path, config=ckpt_config_path,
variant=variant, variant=variant,
prediction_type=prediction_type,
upcast_attention=upcast_attention,
) )
@classproperty @classproperty
@ -225,14 +213,12 @@ class StableDiffusion2Model(DiffusersModel):
config: ModelConfigBase, config: ModelConfigBase,
base_model: BaseModelType, base_model: BaseModelType,
) -> str: ) -> str:
assert model_path == config.path
if isinstance(config, cls.CheckpointConfig): if isinstance(config, cls.CheckpointConfig):
return _convert_ckpt_and_cache( return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion2, version=BaseModelType.StableDiffusion2,
model_config=config, model_config=config,
output_path=output_path, output_path=output_path,
) # TODO: args )
else: else:
return model_path return model_path
@ -243,18 +229,18 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml", ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
}, },
BaseModelType.StableDiffusion2: { BaseModelType.StableDiffusion2: {
# code further will manually set upcast_attention and v_prediction ModelVariantType.Normal: "v2-inference-v.yaml", # best guess, as we can't differentiate with base(512)
ModelVariantType.Normal: "v2-inference.yaml",
ModelVariantType.Inpaint: "v2-inpainting-inference.yaml", ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
ModelVariantType.Depth: "v2-midas-inference.yaml", ModelVariantType.Depth: "v2-midas-inference.yaml",
} }
} }
app_config = InvokeAIAppConfig.get_config()
try: try:
# TODO: path config_path = app_config.legacy_conf_path / ckpt_configs[version][variant]
#model_config.config = app_config.config_dir / "stable-diffusion" / ckpt_configs[version][model_config.variant] if config_path.is_relative_to(app_config.root_path):
#return InvokeAIAppConfig.get_config().legacy_conf_dir / ckpt_configs[version][variant] config_path = config_path.relative_to(app_config.root_path)
return InvokeAIAppConfig.get_config().root_dir / "configs" / "stable-diffusion" / ckpt_configs[version][variant] return str(config_path)
except: except:
return None return None
@ -273,36 +259,14 @@ def _convert_ckpt_and_cache(
""" """
app_config = InvokeAIAppConfig.get_config() app_config = InvokeAIAppConfig.get_config()
if model_config.config is None:
model_config.config = _select_ckpt_config(version, model_config.variant)
if model_config.config is None:
raise Exception(f"Model variant {model_config.variant} not supported for {version}")
weights = app_config.root_path / model_config.path weights = app_config.root_path / model_config.path
config_file = app_config.root_path / model_config.config config_file = app_config.root_path / model_config.config
output_path = Path(output_path) output_path = Path(output_path)
if version == BaseModelType.StableDiffusion1:
upcast_attention = False
prediction_type = SchedulerPredictionType.Epsilon
elif version == BaseModelType.StableDiffusion2:
upcast_attention = model_config.upcast_attention
prediction_type = model_config.prediction_type
else:
raise Exception(f"Unknown model provided: {version}")
# return cached version if it exists # return cached version if it exists
if output_path.exists(): if output_path.exists():
return output_path return output_path
# TODO: I think that it more correctly to convert with embedded vae
# as if user will delete custom vae he will got not embedded but also custom vae
#vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)
# to avoid circular import errors # to avoid circular import errors
from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
with SilenceWarnings(): with SilenceWarnings():
@ -313,9 +277,6 @@ def _convert_ckpt_and_cache(
model_variant=model_config.variant, model_variant=model_config.variant,
original_config_file=config_file, original_config_file=config_file,
extract_ema=True, extract_ema=True,
upcast_attention=upcast_attention,
prediction_type=prediction_type,
scan_needed=True, scan_needed=True,
model_root=app_config.models_path,
) )
return output_path return output_path

View File

@ -16,6 +16,7 @@ from .util import (
download_with_resume, download_with_resume,
instantiate_from_config, instantiate_from_config,
url_attachment_name, url_attachment_name,
Chdir
) )

View File

@ -381,3 +381,18 @@ def image_to_dataURL(image: Image.Image, image_format: str = "PNG") -> str:
buffered.getvalue() buffered.getvalue()
).decode("UTF-8") ).decode("UTF-8")
return image_base64 return image_base64
class Chdir(object):
'''Context manager to chdir to desired directory and change back after context exits:
Args:
path (Path): The path to the cwd
'''
def __init__(self, path: Path):
self.path = path
self.original = Path().absolute()
def __enter__(self):
os.chdir(self.path)
def __exit__(self,*args):
os.chdir(self.original)

View File

@ -1,107 +1,92 @@
# This file predefines a few models that the user may want to install. # This file predefines a few models that the user may want to install.
diffusers: sd-1/main/stable-diffusion-v1-5:
stable-diffusion-1.5:
description: Stable Diffusion version 1.5 diffusers model (4.27 GB) description: Stable Diffusion version 1.5 diffusers model (4.27 GB)
repo_id: runwayml/stable-diffusion-v1-5 repo_id: runwayml/stable-diffusion-v1-5
format: diffusers
vae:
repo_id: stabilityai/sd-vae-ft-mse
recommended: True recommended: True
default: True default: True
sd-inpainting-1.5: sd-1/main/stable-diffusion-inpainting:
description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB) description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB)
repo_id: runwayml/stable-diffusion-inpainting repo_id: runwayml/stable-diffusion-inpainting
format: diffusers
vae:
repo_id: stabilityai/sd-vae-ft-mse
recommended: True recommended: True
stable-diffusion-2.1: sd-2/main/stable-diffusion-2-1:
description: Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB) description: Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB)
repo_id: stabilityai/stable-diffusion-2-1 repo_id: stabilityai/stable-diffusion-2-1
format: diffusers
recommended: True recommended: True
sd-inpainting-2.0: sd-2/main/stable-diffusion-2-inpainting:
description: Stable Diffusion version 2.0 inpainting model (5.21 GB) description: Stable Diffusion version 2.0 inpainting model (5.21 GB)
repo_id: stabilityai/stable-diffusion-2-inpainting repo_id: stabilityai/stable-diffusion-2-inpainting
format: diffusers
recommended: False recommended: False
analog-diffusion-1.0: sd-1/main/Analog-Diffusion:
description: An SD-1.5 model trained on diverse analog photographs (2.13 GB) description: An SD-1.5 model trained on diverse analog photographs (2.13 GB)
repo_id: wavymulder/Analog-Diffusion repo_id: wavymulder/Analog-Diffusion
format: diffusers
recommended: false recommended: false
deliberate-1.0: sd-1/main/Deliberate:
description: Versatile model that produces detailed images up to 768px (4.27 GB) description: Versatile model that produces detailed images up to 768px (4.27 GB)
format: diffusers
repo_id: XpucT/Deliberate repo_id: XpucT/Deliberate
recommended: False recommended: False
d&d-diffusion-1.0: sd-1/main/Dungeons-and-Diffusion:
description: Dungeons & Dragons characters (2.13 GB) description: Dungeons & Dragons characters (2.13 GB)
format: diffusers
repo_id: 0xJustin/Dungeons-and-Diffusion repo_id: 0xJustin/Dungeons-and-Diffusion
recommended: False recommended: False
dreamlike-photoreal-2.0: sd-1/main/dreamlike-photoreal-2:
description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB) description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB)
format: diffusers
repo_id: dreamlike-art/dreamlike-photoreal-2.0 repo_id: dreamlike-art/dreamlike-photoreal-2.0
recommended: False recommended: False
inkpunk-1.0: sd-1/main/Inkpunk-Diffusion:
description: Stylized illustrations inspired by Gorillaz, FLCL and Shinkawa; prompt with "nvinkpunk" (4.27 GB) description: Stylized illustrations inspired by Gorillaz, FLCL and Shinkawa; prompt with "nvinkpunk" (4.27 GB)
format: diffusers
repo_id: Envvi/Inkpunk-Diffusion repo_id: Envvi/Inkpunk-Diffusion
recommended: False recommended: False
openjourney-4.0: sd-1/main/openjourney:
description: An SD 1.5 model fine tuned on Midjourney; prompt with "mdjrny-v4 style" (2.13 GB) description: An SD 1.5 model fine tuned on Midjourney; prompt with "mdjrny-v4 style" (2.13 GB)
format: diffusers
repo_id: prompthero/openjourney repo_id: prompthero/openjourney
vae:
repo_id: stabilityai/sd-vae-ft-mse
recommended: False recommended: False
portrait-plus-1.0: sd-1/main/portraitplus:
description: An SD-1.5 model trained on close range portraits of people; prompt with "portrait+" (2.13 GB) description: An SD-1.5 model trained on close range portraits of people; prompt with "portrait+" (2.13 GB)
format: diffusers
repo_id: wavymulder/portraitplus repo_id: wavymulder/portraitplus
recommended: False recommended: False
seek-art-mega-1.0: sd-1/main/seek.art_MEGA:
description: A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB)
repo_id: coreco/seek.art_MEGA repo_id: coreco/seek.art_MEGA
format: diffusers description: A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB)
vae:
repo_id: stabilityai/sd-vae-ft-mse
recommended: False recommended: False
trinart-2.0: sd-1/main/trinart_stable_diffusion_v2:
description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB) description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB)
repo_id: naclbit/trinart_stable_diffusion_v2 repo_id: naclbit/trinart_stable_diffusion_v2
format: diffusers
vae:
repo_id: stabilityai/sd-vae-ft-mse
recommended: False recommended: False
waifu-diffusion-1.4: sd-1/main/waifu-diffusion:
description: An SD-1.5 model trained on 680k anime/manga-style images (2.13 GB) description: An SD-1.5 model trained on 680k anime/manga-style images (2.13 GB)
repo_id: hakurei/waifu-diffusion repo_id: hakurei/waifu-diffusion
format: diffusers
vae:
repo_id: stabilityai/sd-vae-ft-mse
recommended: False recommended: False
controlnet: sd-1/controlnet/canny:
canny: lllyasviel/control_v11p_sd15_canny repo_id: lllyasviel/control_v11p_sd15_canny
inpaint: lllyasviel/control_v11p_sd15_inpaint sd-1/controlnet/inpaint:
mlsd: lllyasviel/control_v11p_sd15_mlsd repo_id: lllyasviel/control_v11p_sd15_inpaint
depth: lllyasviel/control_v11f1p_sd15_depth sd-1/controlnet/mlsd:
normal_bae: lllyasviel/control_v11p_sd15_normalbae repo_id: lllyasviel/control_v11p_sd15_mlsd
seg: lllyasviel/control_v11p_sd15_seg sd-1/controlnet/depth:
lineart: lllyasviel/control_v11p_sd15_lineart repo_id: lllyasviel/control_v11f1p_sd15_depth
lineart_anime: lllyasviel/control_v11p_sd15s2_lineart_anime sd-1/controlnet/normal_bae:
scribble: lllyasviel/control_v11p_sd15_scribble repo_id: lllyasviel/control_v11p_sd15_normalbae
softedge: lllyasviel/control_v11p_sd15_softedge sd-1/controlnet/seg:
shuffle: lllyasviel/control_v11e_sd15_shuffle repo_id: lllyasviel/control_v11p_sd15_seg
tile: lllyasviel/control_v11f1e_sd15_tile sd-1/controlnet/lineart:
ip2p: lllyasviel/control_v11e_sd15_ip2p repo_id: lllyasviel/control_v11p_sd15_lineart
textual_inversion: sd-1/controlnet/lineart_anime:
'EasyNegative': https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors repo_id: lllyasviel/control_v11p_sd15s2_lineart_anime
'ahx-beta-453407d': sd-concepts-library/ahx-beta-453407d sd-1/controlnet/scribble:
lora: repo_id: lllyasviel/control_v11p_sd15_scribble
'LowRA': https://civitai.com/api/download/models/63006 sd-1/controlnet/softedge:
'Ink scenery': https://civitai.com/api/download/models/83390 repo_id: lllyasviel/control_v11p_sd15_softedge
'sd-model-finetuned-lora-t4': sayakpaul/sd-model-finetuned-lora-t4 sd-1/controlnet/shuffle:
repo_id: lllyasviel/control_v11e_sd15_shuffle
sd-1/controlnet/tile:
repo_id: lllyasviel/control_v11f1e_sd15_tile
sd-1/controlnet/ip2p:
repo_id: lllyasviel/control_v11e_sd15_ip2p
sd-1/embedding/EasyNegative:
path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
sd-1/embedding/ahx-beta-453407d:
repo_id: sd-concepts-library/ahx-beta-453407d
sd-1/lora/LowRA:
path: https://civitai.com/api/download/models/63006
sd-1/lora/Ink scenery:
path: https://civitai.com/api/download/models/83390

View File

@ -0,0 +1,159 @@
model:
base_learning_rate: 5.0e-05
target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
parameterization: "v"
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: hybrid
scale_factor: 0.18215
monitor: val/loss_simple_ema
finetune_keys: null
use_ema: False
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
image_size: 32 # unused
in_channels: 9
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"
data:
target: ldm.data.laion.WebDataModuleFromConfig
params:
tar_base: null # for concat as in LAION-A
p_unsafe_threshold: 0.1
filter_word_list: "data/filters.yaml"
max_pwatermark: 0.45
batch_size: 8
num_workers: 6
multinode: True
min_size: 512
train:
shards:
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar"
shuffle: 10000
image_key: jpg
image_transforms:
- target: torchvision.transforms.Resize
params:
size: 512
interpolation: 3
- target: torchvision.transforms.RandomCrop
params:
size: 512
postprocess:
target: ldm.data.laion.AddMask
params:
mode: "512train-large"
p_drop: 0.25
# NOTE use enough shards to avoid empty validation loops in workers
validation:
shards:
- "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - "
shuffle: 0
image_key: jpg
image_transforms:
- target: torchvision.transforms.Resize
params:
size: 512
interpolation: 3
- target: torchvision.transforms.CenterCrop
params:
size: 512
postprocess:
target: ldm.data.laion.AddMask
params:
mode: "512train-large"
p_drop: 0.25
lightning:
find_unused_parameters: True
modelcheckpoint:
params:
every_n_train_steps: 5000
callbacks:
metrics_over_trainsteps_checkpoint:
params:
every_n_train_steps: 10000
image_logger:
target: main.ImageLogger
params:
enable_autocast: False
disabled: False
batch_frequency: 1000
max_images: 4
increase_log_steps: False
log_first_step: False
log_images_kwargs:
use_ema_scope: False
inpaint: False
plot_progressive_rows: False
plot_diffusion_rows: False
N: 4
unconditional_guidance_scale: 5.0
unconditional_guidance_label: [""]
ddim_steps: 50 # todo check these out for depth2img,
ddim_eta: 0.0 # todo check these out for depth2img,
trainer:
benchmark: True
val_check_interval: 5000000
num_sanity_val_steps: 0
accumulate_grad_batches: 1

View File

@ -0,0 +1,158 @@
model:
base_learning_rate: 5.0e-05
target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: hybrid
scale_factor: 0.18215
monitor: val/loss_simple_ema
finetune_keys: null
use_ema: False
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
image_size: 32 # unused
in_channels: 9
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"
data:
target: ldm.data.laion.WebDataModuleFromConfig
params:
tar_base: null # for concat as in LAION-A
p_unsafe_threshold: 0.1
filter_word_list: "data/filters.yaml"
max_pwatermark: 0.45
batch_size: 8
num_workers: 6
multinode: True
min_size: 512
train:
shards:
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar"
shuffle: 10000
image_key: jpg
image_transforms:
- target: torchvision.transforms.Resize
params:
size: 512
interpolation: 3
- target: torchvision.transforms.RandomCrop
params:
size: 512
postprocess:
target: ldm.data.laion.AddMask
params:
mode: "512train-large"
p_drop: 0.25
# NOTE use enough shards to avoid empty validation loops in workers
validation:
shards:
- "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - "
shuffle: 0
image_key: jpg
image_transforms:
- target: torchvision.transforms.Resize
params:
size: 512
interpolation: 3
- target: torchvision.transforms.CenterCrop
params:
size: 512
postprocess:
target: ldm.data.laion.AddMask
params:
mode: "512train-large"
p_drop: 0.25
lightning:
find_unused_parameters: True
modelcheckpoint:
params:
every_n_train_steps: 5000
callbacks:
metrics_over_trainsteps_checkpoint:
params:
every_n_train_steps: 10000
image_logger:
target: main.ImageLogger
params:
enable_autocast: False
disabled: False
batch_frequency: 1000
max_images: 4
increase_log_steps: False
log_first_step: False
log_images_kwargs:
use_ema_scope: False
inpaint: False
plot_progressive_rows: False
plot_diffusion_rows: False
N: 4
unconditional_guidance_scale: 5.0
unconditional_guidance_label: [""]
ddim_steps: 50 # todo check these out for depth2img,
ddim_eta: 0.0 # todo check these out for depth2img,
trainer:
benchmark: True
val_check_interval: 5000000
num_sanity_val_steps: 0
accumulate_grad_batches: 1

View File

@ -11,7 +11,6 @@ The work is actually done in backend code in model_install_backend.py.
import argparse import argparse
import curses import curses
import os
import sys import sys
import textwrap import textwrap
import traceback import traceback
@ -20,28 +19,22 @@ from multiprocessing import Process
from multiprocessing.connection import Connection, Pipe from multiprocessing.connection import Connection, Pipe
from pathlib import Path from pathlib import Path
from shutil import get_terminal_size from shutil import get_terminal_size
from typing import List
import logging import logging
import npyscreen import npyscreen
import torch import torch
from npyscreen import widget from npyscreen import widget
from omegaconf import OmegaConf
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.install.model_install_backend import ( from invokeai.backend.install.model_install_backend import (
Dataset_path,
default_config_file,
default_dataset,
install_requested_models,
recommended_datasets,
ModelInstallList, ModelInstallList,
UserSelections, InstallSelections,
ModelInstall,
SchedulerPredictionType,
) )
from invokeai.backend import ModelManager from invokeai.backend.model_management import ModelManager, ModelType
from invokeai.backend.util import choose_precision, choose_torch_device from invokeai.backend.util import choose_precision, choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.widgets import ( from invokeai.frontend.install.widgets import (
CenteredTitleText, CenteredTitleText,
MultiSelectColumns, MultiSelectColumns,
@ -58,6 +51,7 @@ from invokeai.frontend.install.widgets import (
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
config = InvokeAIAppConfig.get_config() config = InvokeAIAppConfig.get_config()
logger = InvokeAILogger.getLogger()
# build a table mapping all non-printable characters to None # build a table mapping all non-printable characters to None
# for stripping control characters # for stripping control characters
@ -71,8 +65,8 @@ def make_printable(s:str)->str:
return s.translate(NOPRINT_TRANS_TABLE) return s.translate(NOPRINT_TRANS_TABLE)
class addModelsForm(CyclingForm, npyscreen.FormMultiPage): class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
# for responsive resizing - disabled # for responsive resizing set to False, but this seems to cause a crash!
# FIX_MINIMUM_SIZE_WHEN_CREATED = False FIX_MINIMUM_SIZE_WHEN_CREATED = True
# for persistence # for persistence
current_tab = 0 current_tab = 0
@ -90,25 +84,10 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
if not config.model_conf_path.exists(): if not config.model_conf_path.exists():
with open(config.model_conf_path,'w') as file: with open(config.model_conf_path,'w') as file:
print('# InvokeAI model configuration file',file=file) print('# InvokeAI model configuration file',file=file)
model_manager = ModelManager(config.model_conf_path) self.installer = ModelInstall(config)
self.all_models = self.installer.all_models()
self.starter_models = OmegaConf.load(Dataset_path)['diffusers'] self.starter_models = self.installer.starter_models()
self.installed_diffusers_models = self.list_additional_diffusers_models( self.model_labels = self._get_model_labels()
model_manager,
self.starter_models,
)
self.installed_cn_models = model_manager.list_controlnet_models()
self.installed_lora_models = model_manager.list_lora_models()
self.installed_ti_models = model_manager.list_ti_models()
try:
self.existing_models = OmegaConf.load(default_config_file())
except:
self.existing_models = dict()
self.starter_model_list = list(self.starter_models.keys())
self.installed_models = dict()
window_width, window_height = get_terminal_size() window_width, window_height = get_terminal_size()
self.nextrely -= 1 self.nextrely -= 1
@ -143,37 +122,35 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
self.tabs.on_changed = self._toggle_tables self.tabs.on_changed = self._toggle_tables
top_of_table = self.nextrely top_of_table = self.nextrely
self.starter_diffusers_models = self.add_starter_diffusers() self.starter_pipelines = self.add_starter_pipelines()
bottom_of_table = self.nextrely bottom_of_table = self.nextrely
self.nextrely = top_of_table self.nextrely = top_of_table
self.diffusers_models = self.add_diffusers_widgets( self.pipeline_models = self.add_pipeline_widgets(
predefined_models=self.installed_diffusers_models, model_type=ModelType.Main,
model_type='Diffusers',
window_width=window_width, window_width=window_width,
exclude = self.starter_models
) )
# self.pipeline_models['autoload_pending'] = True
bottom_of_table = max(bottom_of_table,self.nextrely) bottom_of_table = max(bottom_of_table,self.nextrely)
self.nextrely = top_of_table self.nextrely = top_of_table
self.controlnet_models = self.add_model_widgets( self.controlnet_models = self.add_model_widgets(
predefined_models=self.installed_cn_models, model_type=ModelType.ControlNet,
model_type='ControlNet',
window_width=window_width, window_width=window_width,
) )
bottom_of_table = max(bottom_of_table,self.nextrely) bottom_of_table = max(bottom_of_table,self.nextrely)
self.nextrely = top_of_table self.nextrely = top_of_table
self.lora_models = self.add_model_widgets( self.lora_models = self.add_model_widgets(
predefined_models=self.installed_lora_models, model_type=ModelType.Lora,
model_type="LoRA/LyCORIS",
window_width=window_width, window_width=window_width,
) )
bottom_of_table = max(bottom_of_table,self.nextrely) bottom_of_table = max(bottom_of_table,self.nextrely)
self.nextrely = top_of_table self.nextrely = top_of_table
self.ti_models = self.add_model_widgets( self.ti_models = self.add_model_widgets(
predefined_models=self.installed_ti_models, model_type=ModelType.TextualInversion,
model_type="Textual Inversion Embeddings",
window_width=window_width, window_width=window_width,
) )
bottom_of_table = max(bottom_of_table,self.nextrely) bottom_of_table = max(bottom_of_table,self.nextrely)
@ -184,7 +161,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
BufferBox, BufferBox,
name='Log Messages', name='Log Messages',
editable=False, editable=False,
max_height = 16, max_height = 10,
) )
self.nextrely += 1 self.nextrely += 1
@ -197,6 +174,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
rely=-3, rely=-3,
when_pressed_function=self.on_back, when_pressed_function=self.on_back,
) )
else:
self.ok_button = self.add_widget_intelligent( self.ok_button = self.add_widget_intelligent(
npyscreen.ButtonPress, npyscreen.ButtonPress,
name=done_label, name=done_label,
@ -220,18 +198,15 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
self._toggle_tables([self.current_tab]) self._toggle_tables([self.current_tab])
############# diffusers tab ########## ############# diffusers tab ##########
def add_starter_diffusers(self)->dict[str, npyscreen.widget]: def add_starter_pipelines(self)->dict[str, npyscreen.widget]:
'''Add widgets responsible for selecting diffusers models''' '''Add widgets responsible for selecting diffusers models'''
widgets = dict() widgets = dict()
models = self.all_models
starters = self.starter_models
starter_model_labels = self.model_labels
starter_model_labels = self._get_starter_model_labels()
recommended_models = [
x
for x in self.starter_model_list
if self.starter_models[x].get("recommended", False)
]
self.installed_models = sorted( self.installed_models = sorted(
[x for x in list(self.starter_models.keys()) if x in self.existing_models] [x for x in starters if models[x].installed]
) )
widgets.update( widgets.update(
@ -246,55 +221,46 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
self.nextrely -= 1 self.nextrely -= 1
# if user has already installed some initial models, then don't patronize them # if user has already installed some initial models, then don't patronize them
# by showing more recommendations # by showing more recommendations
show_recommended = not self.existing_models show_recommended = len(self.installed_models)==0
keys = [x for x in models.keys() if x in starters]
widgets.update( widgets.update(
models_selected = self.add_widget_intelligent( models_selected = self.add_widget_intelligent(
MultiSelectColumns, MultiSelectColumns,
columns=1, columns=1,
name="Install Starter Models", name="Install Starter Models",
values=starter_model_labels, values=[starter_model_labels[x] for x in keys],
value=[ value=[
self.starter_model_list.index(x) keys.index(x)
for x in self.starter_model_list for x in keys
if (show_recommended and x in recommended_models)\ if (show_recommended and models[x].recommended) \
or (x in self.existing_models) or (x in self.installed_models)
], ],
max_height=len(starter_model_labels) + 1, max_height=len(starters) + 1,
relx=4, relx=4,
scroll_exit=True, scroll_exit=True,
),
models = keys,
) )
)
widgets.update(
purge_deleted = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Purge unchecked diffusers models from disk",
value=False,
scroll_exit=True,
relx=4,
)
)
widgets['purge_deleted'].when_value_edited = lambda: self.sync_purge_buttons(widgets['purge_deleted'])
self.nextrely += 1 self.nextrely += 1
return widgets return widgets
############# Add a set of model install widgets ######## ############# Add a set of model install widgets ########
def add_model_widgets(self, def add_model_widgets(self,
predefined_models: dict[str,bool], model_type: ModelType,
model_type: str,
window_width: int=120, window_width: int=120,
install_prompt: str=None, install_prompt: str=None,
add_purge_deleted: bool=False, exclude: set=set(),
)->dict[str,npyscreen.widget]: )->dict[str,npyscreen.widget]:
'''Generic code to create model selection widgets''' '''Generic code to create model selection widgets'''
widgets = dict() widgets = dict()
model_list = sorted(predefined_models.keys()) model_list = [x for x in self.all_models if self.all_models[x].model_type==model_type and not x in exclude]
model_labels = [self.model_labels[x] for x in model_list]
if len(model_list) > 0: if len(model_list) > 0:
max_width = max([len(x) for x in model_list]) max_width = max([len(x) for x in model_labels])
columns = window_width // (max_width+8) # 8 characters for "[x] " and padding columns = window_width // (max_width+8) # 8 characters for "[x] " and padding
columns = min(len(model_list),columns) or 1 columns = min(len(model_list),columns) or 1
prompt = install_prompt or f"Select the desired {model_type} models to install. Unchecked models will be purged from disk." prompt = install_prompt or f"Select the desired {model_type.value.title()} models to install. Unchecked models will be purged from disk."
widgets.update( widgets.update(
label1 = self.add_widget_intelligent( label1 = self.add_widget_intelligent(
@ -310,30 +276,18 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
MultiSelectColumns, MultiSelectColumns,
columns=columns, columns=columns,
name=f"Install {model_type} Models", name=f"Install {model_type} Models",
values=model_list, values=model_labels,
value=[ value=[
model_list.index(x) model_list.index(x)
for x in model_list for x in model_list
if predefined_models[x] if self.all_models[x].installed
], ],
max_height=len(model_list)//columns + 1, max_height=len(model_list)//columns + 1,
relx=4, relx=4,
scroll_exit=True, scroll_exit=True,
),
models = model_list,
) )
)
if add_purge_deleted:
self.nextrely += 1
widgets.update(
purge_deleted = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Purge unchecked diffusers models from disk",
value=False,
scroll_exit=True,
relx=4,
)
)
widgets['purge_deleted'].when_value_edited = lambda: self.sync_purge_buttons(widgets['purge_deleted'])
self.nextrely += 1 self.nextrely += 1
widgets.update( widgets.update(
@ -348,63 +302,33 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
return widgets return widgets
### Tab for arbitrary diffusers widgets ### ### Tab for arbitrary diffusers widgets ###
def add_diffusers_widgets(self, def add_pipeline_widgets(self,
predefined_models: dict[str,bool], model_type: ModelType=ModelType.Main,
model_type: str='Diffusers',
window_width: int=120, window_width: int=120,
**kwargs,
)->dict[str,npyscreen.widget]: )->dict[str,npyscreen.widget]:
'''Similar to add_model_widgets() but adds some additional widgets at the bottom '''Similar to add_model_widgets() but adds some additional widgets at the bottom
to support the autoload directory''' to support the autoload directory'''
widgets = self.add_model_widgets( widgets = self.add_model_widgets(
predefined_models, model_type = model_type,
'Diffusers', window_width = window_width,
window_width, install_prompt=f"Additional {model_type.value.title()} models already installed.",
install_prompt="Additional diffusers models already installed.", **kwargs,
add_purge_deleted=True
) )
label = "Directory to scan for models to automatically import (<tab> autocompletes):"
self.nextrely += 1
widgets.update(
autoload_directory = self.add_widget_intelligent(
FileBox,
max_height=3,
name=label,
value=str(config.autoconvert_dir) if config.autoconvert_dir else None,
select_dir=True,
must_exist=True,
use_two_lines=False,
labelColor="DANGER",
begin_entry_at=len(label)+1,
scroll_exit=True,
)
)
widgets.update(
autoscan_on_startup = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Scan and import from this directory each time InvokeAI starts",
value=config.autoconvert_dir is not None,
relx=4,
scroll_exit=True,
)
)
return widgets return widgets
def sync_purge_buttons(self,checkbox):
value = checkbox.value
self.starter_diffusers_models['purge_deleted'].value = value
self.diffusers_models['purge_deleted'].value = value
def resize(self): def resize(self):
super().resize() super().resize()
if (s := self.starter_diffusers_models.get("models_selected")): if (s := self.starter_pipelines.get("models_selected")):
s.values = self._get_starter_model_labels() keys = [x for x in self.all_models.keys() if x in self.starter_models]
s.values = [self.model_labels[x] for x in keys]
def _toggle_tables(self, value=None): def _toggle_tables(self, value=None):
selected_tab = value[0] selected_tab = value[0]
widgets = [ widgets = [
self.starter_diffusers_models, self.starter_pipelines,
self.diffusers_models, self.pipeline_models,
self.controlnet_models, self.controlnet_models,
self.lora_models, self.lora_models,
self.ti_models, self.ti_models,
@ -412,34 +336,38 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
for group in widgets: for group in widgets:
for k,v in group.items(): for k,v in group.items():
try:
v.hidden = True v.hidden = True
v.editable = False v.editable = False
except:
pass
for k,v in widgets[selected_tab].items(): for k,v in widgets[selected_tab].items():
try:
v.hidden = False v.hidden = False
if not isinstance(v,(npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)): if not isinstance(v,(npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)):
v.editable = True v.editable = True
except:
pass
self.__class__.current_tab = selected_tab # for persistence self.__class__.current_tab = selected_tab # for persistence
self.display() self.display()
def _get_starter_model_labels(self) -> List[str]: def _get_model_labels(self) -> dict[str,str]:
window_width, window_height = get_terminal_size() window_width, window_height = get_terminal_size()
label_width = 25
checkbox_width = 4 checkbox_width = 4
spacing_width = 2 spacing_width = 2
description_width = window_width - label_width - checkbox_width - spacing_width
im = self.starter_models
names = self.starter_model_list
descriptions = [
im[x].description[0 : description_width - 3] + "..."
if len(im[x].description) > description_width
else im[x].description
for x in names
]
return [
f"%-{label_width}s %s" % (names[x], descriptions[x])
for x in range(0, len(names))
]
models = self.all_models
label_width = max([len(models[x].name) for x in models])
description_width = window_width - label_width - checkbox_width - spacing_width
result = dict()
for x in models.keys():
description = models[x].description
description = description[0 : description_width - 3] + "..." \
if description and len(description) > description_width \
else description if description else ''
result[x] = f"%-{label_width}s %s" % (models[x].name, description)
return result
def _get_columns(self) -> int: def _get_columns(self) -> int:
window_width, window_height = get_terminal_size() window_width, window_height = get_terminal_size()
@ -467,7 +395,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
target = process_and_execute, target = process_and_execute,
kwargs=dict( kwargs=dict(
opt = app.program_opts, opt = app.program_opts,
selections = app.user_selections, selections = app.install_selections,
conn_out = child_conn, conn_out = child_conn,
) )
) )
@ -475,8 +403,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
child_conn.close() child_conn.close()
self.subprocess_connection = parent_conn self.subprocess_connection = parent_conn
self.subprocess = p self.subprocess = p
app.user_selections = UserSelections() app.install_selections = InstallSelections()
# process_and_execute(app.opt, app.user_selections) # process_and_execute(app.opt, app.install_selections)
def on_back(self): def on_back(self):
self.parentApp.switchFormPrevious() self.parentApp.switchFormPrevious()
@ -548,8 +476,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
# rebuild the form, saving and restoring some of the fields that need to be preserved. # rebuild the form, saving and restoring some of the fields that need to be preserved.
saved_messages = self.monitor.entry_widget.values saved_messages = self.monitor.entry_widget.values
autoload_dir = self.diffusers_models['autoload_directory'].value # autoload_dir = str(config.root_path / self.pipeline_models['autoload_directory'].value)
autoscan = self.diffusers_models['autoscan_on_startup'].value # autoscan = self.pipeline_models['autoscan_on_startup'].value
app.main_form = app.addForm( app.main_form = app.addForm(
"MAIN", addModelsForm, name="Install Stable Diffusion Models", multipage=self.multipage, "MAIN", addModelsForm, name="Install Stable Diffusion Models", multipage=self.multipage,
@ -558,23 +486,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
app.main_form.monitor.entry_widget.values = saved_messages app.main_form.monitor.entry_widget.values = saved_messages
app.main_form.monitor.entry_widget.buffer([''],scroll_end=True) app.main_form.monitor.entry_widget.buffer([''],scroll_end=True)
app.main_form.diffusers_models['autoload_directory'].value = autoload_dir # app.main_form.pipeline_models['autoload_directory'].value = autoload_dir
app.main_form.diffusers_models['autoscan_on_startup'].value = autoscan # app.main_form.pipeline_models['autoscan_on_startup'].value = autoscan
###############################################################
def list_additional_diffusers_models(self,
manager: ModelManager,
starters:dict
)->dict[str,bool]:
'''Return a dict of all the currently installed models that are not on the starter list'''
model_info = manager.list_models()
additional_models = {
x:True for x in model_info \
if model_info[x]['format']=='diffusers' \
and x not in starters
}
return additional_models
def marshall_arguments(self): def marshall_arguments(self):
""" """
@ -586,89 +499,40 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
.autoscan_on_startup: True if invokeai should scan and import at startup time .autoscan_on_startup: True if invokeai should scan and import at startup time
.import_model_paths: list of URLs, repo_ids and file paths to import .import_model_paths: list of URLs, repo_ids and file paths to import
""" """
# we're using a global here rather than storing the result in the parentapp selections = self.parentApp.install_selections
# due to some bug in npyscreen that is causing attributes to be lost all_models = self.all_models
selections = self.parentApp.user_selections
# Starter models to install/remove # Defined models (in INITIAL_CONFIG.yaml or models.yaml) to add/remove
starter_models = dict( ui_sections = [self.starter_pipelines, self.pipeline_models,
map( self.controlnet_models, self.lora_models, self.ti_models]
lambda x: (self.starter_model_list[x], True), for section in ui_sections:
self.starter_diffusers_models['models_selected'].value, if not 'models_selected' in section:
) continue
) selected = set([section['models'][x] for x in section['models_selected'].value])
selections.purge_deleted_models = self.starter_diffusers_models['purge_deleted'].value or \ models_to_install = [x for x in selected if not self.all_models[x].installed]
self.diffusers_models['purge_deleted'].value models_to_remove = [x for x in section['models'] if x not in selected and self.all_models[x].installed]
selections.remove_models.extend(models_to_remove)
selections.install_models.extend(all_models[x].path or all_models[x].repo_id \
for x in models_to_install if all_models[x].path or all_models[x].repo_id)
selections.install_models = [x for x in starter_models if x not in self.existing_models] # models located in the 'download_ids" section
selections.remove_models = [x for x in self.starter_model_list if x in self.existing_models and x not in starter_models] for section in ui_sections:
if downloads := section.get('download_ids'):
# "More" models selections.install_models.extend(downloads.value.split())
selections.import_model_paths = self.diffusers_models['download_ids'].value.split()
if diffusers_selected := self.diffusers_models.get('models_selected'):
selections.remove_models.extend([x
for x in diffusers_selected.values
if self.installed_diffusers_models[x]
and diffusers_selected.values.index(x) not in diffusers_selected.value
]
)
# TODO: REFACTOR THIS REPETITIVE CODE
if cn_models_selected := self.controlnet_models.get('models_selected'):
selections.install_cn_models = [cn_models_selected.values[x]
for x in cn_models_selected.value
if not self.installed_cn_models[cn_models_selected.values[x]]
]
selections.remove_cn_models = [x
for x in cn_models_selected.values
if self.installed_cn_models[x]
and cn_models_selected.values.index(x) not in cn_models_selected.value
]
if (additional_cns := self.controlnet_models['download_ids'].value.split()):
valid_cns = [x for x in additional_cns if '/' in x]
selections.install_cn_models.extend(valid_cns)
# same thing, for LoRAs
if loras_selected := self.lora_models.get('models_selected'):
selections.install_lora_models = [loras_selected.values[x]
for x in loras_selected.value
if not self.installed_lora_models[loras_selected.values[x]]
]
selections.remove_lora_models = [x
for x in loras_selected.values
if self.installed_lora_models[x]
and loras_selected.values.index(x) not in loras_selected.value
]
if (additional_loras := self.lora_models['download_ids'].value.split()):
selections.install_lora_models.extend(additional_loras)
# same thing, for TIs
# TODO: refactor
if tis_selected := self.ti_models.get('models_selected'):
selections.install_ti_models = [tis_selected.values[x]
for x in tis_selected.value
if not self.installed_ti_models[tis_selected.values[x]]
]
selections.remove_ti_models = [x
for x in tis_selected.values
if self.installed_ti_models[x]
and tis_selected.values.index(x) not in tis_selected.value
]
if (additional_tis := self.ti_models['download_ids'].value.split()):
selections.install_ti_models.extend(additional_tis)
# load directory and whether to scan on startup # load directory and whether to scan on startup
selections.scan_directory = self.diffusers_models['autoload_directory'].value # if self.parentApp.autoload_pending:
selections.autoscan_on_startup = self.diffusers_models['autoscan_on_startup'].value # selections.scan_directory = str(config.root_path / self.pipeline_models['autoload_directory'].value)
# self.parentApp.autoload_pending = False
# selections.autoscan_on_startup = self.pipeline_models['autoscan_on_startup'].value
class AddModelApplication(npyscreen.NPSAppManaged): class AddModelApplication(npyscreen.NPSAppManaged):
def __init__(self,opt): def __init__(self,opt):
super().__init__() super().__init__()
self.program_opts = opt self.program_opts = opt
self.user_cancelled = False self.user_cancelled = False
self.user_selections = UserSelections() # self.autoload_pending = True
self.install_selections = InstallSelections()
def onStart(self): def onStart(self):
npyscreen.setTheme(npyscreen.Themes.DefaultTheme) npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
@ -687,26 +551,22 @@ class StderrToMessage():
pass pass
# -------------------------------------------------------- # --------------------------------------------------------
def ask_user_for_config_file(model_path: Path, def ask_user_for_prediction_type(model_path: Path,
tui_conn: Connection=None tui_conn: Connection=None
)->Path: )->SchedulerPredictionType:
if tui_conn: if tui_conn:
logger.debug('Waiting for user response...') logger.debug('Waiting for user response...')
return _ask_user_for_cf_tui(model_path, tui_conn) return _ask_user_for_pt_tui(model_path, tui_conn)
else: else:
return _ask_user_for_cf_cmdline(model_path) return _ask_user_for_pt_cmdline(model_path)
def _ask_user_for_cf_cmdline(model_path): def _ask_user_for_pt_cmdline(model_path: Path)->SchedulerPredictionType:
choices = [ choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None]
config.legacy_conf_path / x
for x in ['v2-inference.yaml','v2-inference-v.yaml']
]
choices.extend([None])
print( print(
f""" f"""
Please select the type of the V2 checkpoint named {model_path.name}: Please select the type of the V2 checkpoint named {model_path.name}:
[1] A Stable Diffusion v2.x base model (512 pixels; there should be no 'parameterization:' line in its yaml file) [1] A model based on Stable Diffusion v2 trained on 512 pixel images (SD-2-base)
[2] A Stable Diffusion v2.x v-predictive model (768 pixels; look for a 'parameterization: "v"' line in its yaml file) [2] A model based on Stable Diffusion v2 trained on 768 pixel images (SD-2-768)
[3] Skip this model and come back later. [3] Skip this model and come back later.
""" """
) )
@ -723,7 +583,7 @@ Please select the type of the V2 checkpoint named {model_path.name}:
return return
return choice return choice
def _ask_user_for_cf_tui(model_path: Path, tui_conn: Connection)->Path: def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection)->SchedulerPredictionType:
try: try:
tui_conn.send_bytes(f'*need v2 config for:{model_path}'.encode('utf-8')) tui_conn.send_bytes(f'*need v2 config for:{model_path}'.encode('utf-8'))
# note that we don't do any status checking here # note that we don't do any status checking here
@ -731,20 +591,20 @@ def _ask_user_for_cf_tui(model_path: Path, tui_conn: Connection)->Path:
if response is None: if response is None:
return None return None
elif response == 'epsilon': elif response == 'epsilon':
return config.legacy_conf_path / 'v2-inference.yaml' return SchedulerPredictionType.epsilon
elif response == 'v': elif response == 'v':
return config.legacy_conf_path / 'v2-inference-v.yaml' return SchedulerPredictionType.VPrediction
elif response == 'abort': elif response == 'abort':
logger.info('Conversion aborted') logger.info('Conversion aborted')
return None return None
else: else:
return Path(response) return response
except: except:
return None return None
# -------------------------------------------------------- # --------------------------------------------------------
def process_and_execute(opt: Namespace, def process_and_execute(opt: Namespace,
selections: UserSelections, selections: InstallSelections,
conn_out: Connection=None, conn_out: Connection=None,
): ):
# set up so that stderr is sent to conn_out # set up so that stderr is sent to conn_out
@ -756,33 +616,13 @@ def process_and_execute(opt: Namespace,
logger.handlers.clear() logger.handlers.clear()
logger.addHandler(logging.StreamHandler(translator)) logger.addHandler(logging.StreamHandler(translator))
models_to_install = selections.install_models installer = ModelInstall(config, prediction_type_helper=lambda x: ask_user_for_prediction_type(x,conn_out))
models_to_remove = selections.remove_models installer.install(selections)
directory_to_scan = selections.scan_directory
scan_at_startup = selections.autoscan_on_startup
potential_models_to_install = selections.import_model_paths
install_requested_models(
diffusers = ModelInstallList(models_to_install, models_to_remove),
controlnet = ModelInstallList(selections.install_cn_models, selections.remove_cn_models),
lora = ModelInstallList(selections.install_lora_models, selections.remove_lora_models),
ti = ModelInstallList(selections.install_ti_models, selections.remove_ti_models),
scan_directory=Path(directory_to_scan) if directory_to_scan else None,
external_models=potential_models_to_install,
scan_at_startup=scan_at_startup,
precision="float32"
if opt.full_precision
else choose_precision(torch.device(choose_torch_device())),
purge_deleted=selections.purge_deleted_models,
config_file_path=Path(opt.config_file) if opt.config_file else config.model_conf_path,
model_config_file_callback = lambda x: ask_user_for_config_file(x,conn_out)
)
if conn_out: if conn_out:
conn_out.send_bytes('*done*'.encode('utf-8')) conn_out.send_bytes('*done*'.encode('utf-8'))
conn_out.close() conn_out.close()
def do_listings(opt)->bool: def do_listings(opt)->bool:
"""List installed models of various sorts, and return """List installed models of various sorts, and return
True if any were requested.""" True if any were requested."""
@ -813,39 +653,34 @@ def select_and_download_models(opt: Namespace):
if opt.full_precision if opt.full_precision
else choose_precision(torch.device(choose_torch_device())) else choose_precision(torch.device(choose_torch_device()))
) )
config.precision = precision
helper = lambda x: ask_user_for_prediction_type(x)
# if do_listings(opt):
# pass
if do_listings(opt): installer = ModelInstall(config, prediction_type_helper=helper)
pass if opt.add or opt.delete:
# this processes command line additions/removals selections = InstallSelections(
elif opt.diffusers or opt.controlnets or opt.textual_inversions or opt.loras: install_models = opt.add or [],
action = 'remove_models' if opt.delete else 'install_models' remove_models = opt.delete or []
diffusers_args = {'diffusers':ModelInstallList(remove_models=opt.diffusers or [])} \
if opt.delete \
else {'external_models':opt.diffusers or []}
install_requested_models(
**diffusers_args,
controlnet=ModelInstallList(**{action:opt.controlnets or []}),
ti=ModelInstallList(**{action:opt.textual_inversions or []}),
lora=ModelInstallList(**{action:opt.loras or []}),
precision=precision,
purge_deleted=True,
model_config_file_callback=lambda x: ask_user_for_config_file(x),
) )
installer.install(selections)
elif opt.default_only: elif opt.default_only:
install_requested_models( selections = InstallSelections(
diffusers=ModelInstallList(install_models=default_dataset()), install_models = installer.default_model()
precision=precision,
) )
installer.install(selections)
elif opt.yes_to_all: elif opt.yes_to_all:
install_requested_models( selections = InstallSelections(
diffusers=ModelInstallList(install_models=recommended_datasets()), install_models = installer.recommended_models()
precision=precision,
) )
installer.install(selections)
# this is where the TUI is called # this is where the TUI is called
else: else:
# needed because the torch library is loaded, even though we don't use it # needed because the torch library is loaded, even though we don't use it
torch.multiprocessing.set_start_method("spawn") # currently commented out because it has started generating errors (?)
# torch.multiprocessing.set_start_method("spawn")
# the third argument is needed in the Windows 11 environment in # the third argument is needed in the Windows 11 environment in
# order to launch and resize a console window running this program # order to launch and resize a console window running this program
@ -861,35 +696,20 @@ def select_and_download_models(opt: Namespace):
installApp.main_form.subprocess.terminate() installApp.main_form.subprocess.terminate()
installApp.main_form.subprocess = None installApp.main_form.subprocess = None
raise e raise e
process_and_execute(opt, installApp.user_selections) process_and_execute(opt, installApp.install_selections)
# ------------------------------------- # -------------------------------------
def main(): def main():
parser = argparse.ArgumentParser(description="InvokeAI model downloader") parser = argparse.ArgumentParser(description="InvokeAI model downloader")
parser.add_argument( parser.add_argument(
"--diffusers", "--add",
nargs="*", nargs="*",
help="List of URLs or repo_ids of diffusers to install/delete", help="List of URLs, local paths or repo_ids of models to install",
)
parser.add_argument(
"--loras",
nargs="*",
help="List of URLs or repo_ids of LoRA/LyCORIS models to install/delete",
)
parser.add_argument(
"--controlnets",
nargs="*",
help="List of URLs or repo_ids of controlnet models to install/delete",
)
parser.add_argument(
"--textual-inversions",
nargs="*",
help="List of URLs or repo_ids of textual inversion embeddings to install/delete",
) )
parser.add_argument( parser.add_argument(
"--delete", "--delete",
action="store_true", nargs="*",
help="Delete models listed on command line rather than installing them", help="List of names of models to idelete",
) )
parser.add_argument( parser.add_argument(
"--full-precision", "--full-precision",
@ -909,7 +729,7 @@ def main():
parser.add_argument( parser.add_argument(
"--default_only", "--default_only",
action="store_true", action="store_true",
help="only install the default model", help="Only install the default model",
) )
parser.add_argument( parser.add_argument(
"--list-models", "--list-models",

View File

@ -17,8 +17,8 @@ from shutil import get_terminal_size
from curses import BUTTON2_CLICKED,BUTTON3_CLICKED from curses import BUTTON2_CLICKED,BUTTON3_CLICKED
# minimum size for UIs # minimum size for UIs
MIN_COLS = 120 MIN_COLS = 130
MIN_LINES = 50 MIN_LINES = 40
# ------------------------------------- # -------------------------------------
def set_terminal_size(columns: int, lines: int, launch_command: str=None): def set_terminal_size(columns: int, lines: int, launch_command: str=None):
@ -73,6 +73,12 @@ def _set_terminal_size_unix(width: int, height: int):
import fcntl import fcntl
import termios import termios
# These terminals accept the size command and report that the
# size changed, but they lie!!!
for bad_terminal in ['TERMINATOR_UUID', 'ALACRITTY_WINDOW_ID']:
if os.environ.get(bad_terminal):
return
winsize = struct.pack("HHHH", height, width, 0, 0) winsize = struct.pack("HHHH", height, width, 0, 0)
fcntl.ioctl(sys.stdout.fileno(), termios.TIOCSWINSZ, winsize) fcntl.ioctl(sys.stdout.fileno(), termios.TIOCSWINSZ, winsize)
sys.stdout.write("\x1b[8;{height};{width}t".format(height=height, width=width)) sys.stdout.write("\x1b[8;{height};{width}t".format(height=height, width=width))
@ -87,6 +93,12 @@ def set_min_terminal_size(min_cols: int, min_lines: int, launch_command: str=Non
lines = max(term_lines, min_lines) lines = max(term_lines, min_lines)
set_terminal_size(cols, lines, launch_command) set_terminal_size(cols, lines, launch_command)
# did it work?
term_cols, term_lines = get_terminal_size()
if term_cols < cols or term_lines < lines:
print(f'This window is too small for optimal display. For best results please enlarge it.')
input('After resizing, press any key to continue...')
class IntSlider(npyscreen.Slider): class IntSlider(npyscreen.Slider):
def translate_value(self): def translate_value(self):
stri = "%2d / %2d" % (self.value, self.out_of) stri = "%2d / %2d" % (self.value, self.out_of)
@ -390,13 +402,12 @@ def select_stable_diffusion_config_file(
wrap:bool =True, wrap:bool =True,
model_name:str='Unknown', model_name:str='Unknown',
): ):
message = "Please select the correct base model for the V2 checkpoint named {model_name}. Press <CANCEL> to skip installation." message = f"Please select the correct base model for the V2 checkpoint named '{model_name}'. Press <CANCEL> to skip installation."
title = "CONFIG FILE SELECTION" title = "CONFIG FILE SELECTION"
options=[ options=[
"An SD v2.x base model (512 pixels; no 'parameterization:' line in its yaml file)", "An SD v2.x base model (512 pixels; no 'parameterization:' line in its yaml file)",
"An SD v2.x v-predictive model (768 pixels; 'parameterization: \"v\"' line in its yaml file)", "An SD v2.x v-predictive model (768 pixels; 'parameterization: \"v\"' line in its yaml file)",
"Skip installation for now and come back later", "Skip installation for now and come back later",
"Enter config file path manually",
] ]
F = ConfirmCancelPopup( F = ConfirmCancelPopup(
@ -418,35 +429,17 @@ def select_stable_diffusion_config_file(
mlw.values = message mlw.values = message
choice = F.add( choice = F.add(
SingleSelectWithChanged, npyscreen.SelectOne,
values = options, values = options,
value = [0], value = [0],
max_height = len(options)+1, max_height = len(options)+1,
scroll_exit=True, scroll_exit=True,
) )
file = F.add(
FileBox,
name='Path to config file',
max_height=3,
hidden=True,
must_exist=True,
scroll_exit=True
)
def toggle_visible(value):
value = value[0]
if value==3:
file.hidden=False
else:
file.hidden=True
F.display()
choice.on_changed = toggle_visible
F.editw = 1 F.editw = 1
F.edit() F.edit()
if not F.value: if not F.value:
return None return None
assert choice.value[0] in range(0,4),'invalid choice' assert choice.value[0] in range(0,3),'invalid choice'
choices = ['epsilon','v','abort',file.value] choices = ['epsilon','v','abort']
return choices[choice.value[0]] return choices[choice.value[0]]

View File

@ -48,7 +48,7 @@ const App = ({
const isApplicationReady = useIsApplicationReady(); const isApplicationReady = useIsApplicationReady();
const { data: pipelineModels } = useListModelsQuery({ const { data: pipelineModels } = useListModelsQuery({
model_type: 'pipeline', model_type: 'main',
}); });
const { data: controlnetModels } = useListModelsQuery({ const { data: controlnetModels } = useListModelsQuery({
model_type: 'controlnet', model_type: 'controlnet',

View File

@ -35,8 +35,8 @@ const ParamDynamicPromptsCollapse = () => {
withSwitch withSwitch
> >
<Flex sx={{ gap: 2, flexDir: 'column' }}> <Flex sx={{ gap: 2, flexDir: 'column' }}>
<ParamDynamicPromptsMaxPrompts />
<ParamDynamicPromptsCombinatorial /> <ParamDynamicPromptsCombinatorial />
<ParamDynamicPromptsMaxPrompts />
</Flex> </Flex>
</IAICollapse> </IAICollapse>
); );

View File

@ -9,17 +9,18 @@ import { stateSelector } from 'app/store/store';
const selector = createSelector( const selector = createSelector(
stateSelector, stateSelector,
(state) => { (state) => {
const { maxPrompts } = state.dynamicPrompts; const { maxPrompts, combinatorial } = state.dynamicPrompts;
const { min, sliderMax, inputMax } = const { min, sliderMax, inputMax } =
state.config.sd.dynamicPrompts.maxPrompts; state.config.sd.dynamicPrompts.maxPrompts;
return { maxPrompts, min, sliderMax, inputMax }; return { maxPrompts, min, sliderMax, inputMax, combinatorial };
}, },
defaultSelectorOptions defaultSelectorOptions
); );
const ParamDynamicPromptsMaxPrompts = () => { const ParamDynamicPromptsMaxPrompts = () => {
const { maxPrompts, min, sliderMax, inputMax } = useAppSelector(selector); const { maxPrompts, min, sliderMax, inputMax, combinatorial } =
useAppSelector(selector);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const handleChange = useCallback( const handleChange = useCallback(
@ -36,6 +37,7 @@ const ParamDynamicPromptsMaxPrompts = () => {
return ( return (
<IAISlider <IAISlider
label="Max Prompts" label="Max Prompts"
isDisabled={!combinatorial}
min={min} min={min}
max={sliderMax} max={sliderMax}
value={maxPrompts} value={maxPrompts}

View File

@ -23,7 +23,7 @@ const ModelInputFieldComponent = (
const { t } = useTranslation(); const { t } = useTranslation();
const { data: pipelineModels } = useListModelsQuery({ const { data: pipelineModels } = useListModelsQuery({
model_type: 'pipeline', model_type: 'main',
}); });
const data = useMemo(() => { const data = useMemo(() => {

View File

@ -37,7 +37,7 @@ export const addDynamicPromptsToGraph = (
const dynamicPromptNode: DynamicPromptInvocation = { const dynamicPromptNode: DynamicPromptInvocation = {
id: DYNAMIC_PROMPT, id: DYNAMIC_PROMPT,
type: 'dynamic_prompt', type: 'dynamic_prompt',
max_prompts: maxPrompts, max_prompts: combinatorial ? maxPrompts : iterations,
combinatorial, combinatorial,
prompt: positivePrompt, prompt: positivePrompt,
}; };

View File

@ -16,7 +16,8 @@ const selector = createSelector([stateSelector], (state) => {
state.config.sd.iterations; state.config.sd.iterations;
const { iterations } = state.generation; const { iterations } = state.generation;
const { shouldUseSliders } = state.ui; const { shouldUseSliders } = state.ui;
const isDisabled = state.dynamicPrompts.isEnabled; const isDisabled =
state.dynamicPrompts.isEnabled && state.dynamicPrompts.combinatorial;
const step = state.hotkeys.shift ? fineStep : coarseStep; const step = state.hotkeys.shift ? fineStep : coarseStep;

View File

@ -24,7 +24,7 @@ const ModelSelect = () => {
); );
const { data: pipelineModels } = useListModelsQuery({ const { data: pipelineModels } = useListModelsQuery({
model_type: 'pipeline', model_type: 'main',
}); });
const data = useMemo(() => { const data = useMemo(() => {

View File

@ -1030,7 +1030,7 @@ export type components = {
* @description The nodes in this graph * @description The nodes in this graph
*/ */
nodes?: { nodes?: {
[key: string]: (components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined; [key: string]: (components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
}; };
/** /**
* Edges * Edges
@ -1073,7 +1073,7 @@ export type components = {
* @description The results of node executions * @description The results of node executions
*/ */
results: { results: {
[key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined; [key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined;
}; };
/** /**
* Errors * Errors
@ -2917,7 +2917,7 @@ export type components = {
/** ModelsList */ /** ModelsList */
ModelsList: { ModelsList: {
/** Models */ /** Models */
models: (components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[]; models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"])[];
}; };
/** /**
* MultiplyInvocation * MultiplyInvocation
@ -2993,6 +2993,18 @@ export type components = {
* @default 512 * @default 512
*/ */
height?: number; height?: number;
/**
* Perlin
* @description The amount of perlin noise to add to the noise
* @default 0
*/
perlin?: number;
/**
* Use Cpu
* @description Use CPU for noise generation (for reproducible results across platforms)
* @default true
*/
use_cpu?: boolean;
}; };
/** /**
* NoiseOutput * NoiseOutput
@ -4177,18 +4189,18 @@ export type components = {
*/ */
image?: components["schemas"]["ImageField"]; image?: components["schemas"]["ImageField"];
}; };
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/** /**
* StableDiffusion2ModelFormat * StableDiffusion2ModelFormat
* @description An enumeration. * @description An enumeration.
* @enum {string} * @enum {string}
*/ */
StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
}; };
responses: never; responses: never;
parameters: never; parameters: never;
@ -4299,7 +4311,7 @@ export type operations = {
}; };
requestBody: { requestBody: {
content: { content: {
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]; "application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
}; };
}; };
responses: { responses: {
@ -4336,7 +4348,7 @@ export type operations = {
}; };
requestBody: { requestBody: {
content: { content: {
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]; "application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
}; };
}; };
responses: { responses: {

View File

@ -4,91 +4,89 @@ import { components } from './schema';
type schemas = components['schemas']; type schemas = components['schemas'];
/** /**
* Helper type to extract the invocation type from the schema. * Extracts the schema type from the schema.
* Also flags the `type` property as required.
*/ */
type Invocation<T extends keyof schemas> = O.Required<schemas[T], 'type'>; type S<T extends keyof components['schemas']> = components['schemas'][T];
/** /**
* Types from the API, re-exported from the types generated by `openapi-typescript`. * Extracts the node type from the schema.
* Also flags the `type` property as required.
*/ */
type N<T extends keyof components['schemas']> = O.Required<
components['schemas'][T],
'type'
>;
// Images // Images
export type ImageDTO = schemas['ImageDTO']; export type ImageDTO = S<'ImageDTO'>;
export type BoardDTO = schemas['BoardDTO']; export type BoardDTO = S<'BoardDTO'>;
export type BoardChanges = schemas['BoardChanges']; export type BoardChanges = S<'BoardChanges'>;
export type ImageChanges = schemas['ImageRecordChanges']; export type ImageChanges = S<'ImageRecordChanges'>;
export type ImageCategory = schemas['ImageCategory']; export type ImageCategory = S<'ImageCategory'>;
export type ResourceOrigin = schemas['ResourceOrigin']; export type ResourceOrigin = S<'ResourceOrigin'>;
export type ImageField = schemas['ImageField']; export type ImageField = S<'ImageField'>;
export type OffsetPaginatedResults_BoardDTO_ = export type OffsetPaginatedResults_BoardDTO_ =
schemas['OffsetPaginatedResults_BoardDTO_']; S<'OffsetPaginatedResults_BoardDTO_'>;
export type OffsetPaginatedResults_ImageDTO_ = export type OffsetPaginatedResults_ImageDTO_ =
schemas['OffsetPaginatedResults_ImageDTO_']; S<'OffsetPaginatedResults_ImageDTO_'>;
// Models // Models
export type ModelType = schemas['ModelType']; export type ModelType = S<'ModelType'>;
export type BaseModelType = schemas['BaseModelType']; export type BaseModelType = S<'BaseModelType'>;
export type PipelineModelField = schemas['PipelineModelField']; export type PipelineModelField = S<'PipelineModelField'>;
export type ModelsList = schemas['ModelsList']; export type ModelsList = S<'ModelsList'>;
// Graphs // Graphs
export type Graph = schemas['Graph']; export type Graph = S<'Graph'>;
export type Edge = schemas['Edge']; export type Edge = S<'Edge'>;
export type GraphExecutionState = schemas['GraphExecutionState']; export type GraphExecutionState = S<'GraphExecutionState'>;
// General nodes // General nodes
export type CollectInvocation = Invocation<'CollectInvocation'>; export type CollectInvocation = N<'CollectInvocation'>;
export type IterateInvocation = Invocation<'IterateInvocation'>; export type IterateInvocation = N<'IterateInvocation'>;
export type RangeInvocation = Invocation<'RangeInvocation'>; export type RangeInvocation = N<'RangeInvocation'>;
export type RandomRangeInvocation = Invocation<'RandomRangeInvocation'>; export type RandomRangeInvocation = N<'RandomRangeInvocation'>;
export type RangeOfSizeInvocation = Invocation<'RangeOfSizeInvocation'>; export type RangeOfSizeInvocation = N<'RangeOfSizeInvocation'>;
export type InpaintInvocation = Invocation<'InpaintInvocation'>; export type InpaintInvocation = N<'InpaintInvocation'>;
export type ImageResizeInvocation = Invocation<'ImageResizeInvocation'>; export type ImageResizeInvocation = N<'ImageResizeInvocation'>;
export type RandomIntInvocation = Invocation<'RandomIntInvocation'>; export type RandomIntInvocation = N<'RandomIntInvocation'>;
export type CompelInvocation = Invocation<'CompelInvocation'>; export type CompelInvocation = N<'CompelInvocation'>;
export type DynamicPromptInvocation = Invocation<'DynamicPromptInvocation'>; export type DynamicPromptInvocation = N<'DynamicPromptInvocation'>;
export type NoiseInvocation = Invocation<'NoiseInvocation'>; export type NoiseInvocation = N<'NoiseInvocation'>;
export type TextToLatentsInvocation = Invocation<'TextToLatentsInvocation'>; export type TextToLatentsInvocation = N<'TextToLatentsInvocation'>;
export type LatentsToLatentsInvocation = export type LatentsToLatentsInvocation = N<'LatentsToLatentsInvocation'>;
Invocation<'LatentsToLatentsInvocation'>; export type ImageToLatentsInvocation = N<'ImageToLatentsInvocation'>;
export type ImageToLatentsInvocation = Invocation<'ImageToLatentsInvocation'>; export type LatentsToImageInvocation = N<'LatentsToImageInvocation'>;
export type LatentsToImageInvocation = Invocation<'LatentsToImageInvocation'>; export type PipelineModelLoaderInvocation = N<'PipelineModelLoaderInvocation'>;
export type PipelineModelLoaderInvocation =
Invocation<'PipelineModelLoaderInvocation'>;
// ControlNet Nodes // ControlNet Nodes
export type ControlNetInvocation = Invocation<'ControlNetInvocation'>; export type ControlNetInvocation = N<'ControlNetInvocation'>;
export type CannyImageProcessorInvocation = export type CannyImageProcessorInvocation = N<'CannyImageProcessorInvocation'>;
Invocation<'CannyImageProcessorInvocation'>;
export type ContentShuffleImageProcessorInvocation = export type ContentShuffleImageProcessorInvocation =
Invocation<'ContentShuffleImageProcessorInvocation'>; N<'ContentShuffleImageProcessorInvocation'>;
export type HedImageProcessorInvocation = export type HedImageProcessorInvocation = N<'HedImageProcessorInvocation'>;
Invocation<'HedImageProcessorInvocation'>;
export type LineartAnimeImageProcessorInvocation = export type LineartAnimeImageProcessorInvocation =
Invocation<'LineartAnimeImageProcessorInvocation'>; N<'LineartAnimeImageProcessorInvocation'>;
export type LineartImageProcessorInvocation = export type LineartImageProcessorInvocation =
Invocation<'LineartImageProcessorInvocation'>; N<'LineartImageProcessorInvocation'>;
export type MediapipeFaceProcessorInvocation = export type MediapipeFaceProcessorInvocation =
Invocation<'MediapipeFaceProcessorInvocation'>; N<'MediapipeFaceProcessorInvocation'>;
export type MidasDepthImageProcessorInvocation = export type MidasDepthImageProcessorInvocation =
Invocation<'MidasDepthImageProcessorInvocation'>; N<'MidasDepthImageProcessorInvocation'>;
export type MlsdImageProcessorInvocation = export type MlsdImageProcessorInvocation = N<'MlsdImageProcessorInvocation'>;
Invocation<'MlsdImageProcessorInvocation'>;
export type NormalbaeImageProcessorInvocation = export type NormalbaeImageProcessorInvocation =
Invocation<'NormalbaeImageProcessorInvocation'>; N<'NormalbaeImageProcessorInvocation'>;
export type OpenposeImageProcessorInvocation = export type OpenposeImageProcessorInvocation =
Invocation<'OpenposeImageProcessorInvocation'>; N<'OpenposeImageProcessorInvocation'>;
export type PidiImageProcessorInvocation = export type PidiImageProcessorInvocation = N<'PidiImageProcessorInvocation'>;
Invocation<'PidiImageProcessorInvocation'>;
export type ZoeDepthImageProcessorInvocation = export type ZoeDepthImageProcessorInvocation =
Invocation<'ZoeDepthImageProcessorInvocation'>; N<'ZoeDepthImageProcessorInvocation'>;
// Node Outputs // Node Outputs
export type ImageOutput = schemas['ImageOutput']; export type ImageOutput = S<'ImageOutput'>;
export type MaskOutput = schemas['MaskOutput']; export type MaskOutput = S<'MaskOutput'>;
export type PromptOutput = schemas['PromptOutput']; export type PromptOutput = S<'PromptOutput'>;
export type IterateInvocationOutput = schemas['IterateInvocationOutput']; export type IterateInvocationOutput = S<'IterateInvocationOutput'>;
export type CollectInvocationOutput = schemas['CollectInvocationOutput']; export type CollectInvocationOutput = S<'CollectInvocationOutput'>;
export type LatentsOutput = schemas['LatentsOutput']; export type LatentsOutput = S<'LatentsOutput'>;
export type GraphInvocationOutput = schemas['GraphInvocationOutput']; export type GraphInvocationOutput = S<'GraphInvocationOutput'>;

View File

@ -39,7 +39,7 @@ dependencies = [
"click", "click",
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel>=1.2.1", "compel>=1.2.1",
"controlnet-aux>=0.0.4", "controlnet-aux>=0.0.6",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26 "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"datasets", "datasets",
"diffusers[torch]~=0.17.1", "diffusers[torch]~=0.17.1",
@ -120,6 +120,7 @@ dependencies = [
"invokeai-merge" = "invokeai.frontend.merge:invokeai_merge_diffusers" "invokeai-merge" = "invokeai.frontend.merge:invokeai_merge_diffusers"
"invokeai-ti" = "invokeai.frontend.training:invokeai_textual_inversion" "invokeai-ti" = "invokeai.frontend.training:invokeai_textual_inversion"
"invokeai-model-install" = "invokeai.frontend.install:invokeai_model_install" "invokeai-model-install" = "invokeai.frontend.install:invokeai_model_install"
"invokeai-migrate3" = "invokeai.backend.install.migrate_to_3:main"
"invokeai-update" = "invokeai.frontend.install:invokeai_update" "invokeai-update" = "invokeai.frontend.install:invokeai_update"
"invokeai-metadata" = "invokeai.frontend.CLI.sd_metadata:print_metadata" "invokeai-metadata" = "invokeai.frontend.CLI.sd_metadata:print_metadata"
"invokeai-node-cli" = "invokeai.app.cli_app:invoke_cli" "invokeai-node-cli" = "invokeai.app.cli_app:invoke_cli"

View File

@ -0,0 +1,4 @@
from invokeai.backend.install.migrate_to_3 import main
if __name__=='__main__':
main()

View File

@ -0,0 +1,3 @@
from invokeai.frontend.install.model_install import main
main()

View File

@ -1,278 +0,0 @@
'''
Migrate the models directory and models.yaml file from an existing
InvokeAI 2.3 installation to 3.0.0.
'''
import io
import os
import argparse
import shutil
import yaml
import transformers
import diffusers
import warnings
from pathlib import Path
from omegaconf import OmegaConf
from diffusers import StableDiffusionPipeline, AutoencoderKL
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import (
CLIPTextModel,
CLIPTokenizer,
AutoFeatureExtractor,
BertTokenizerFast,
)
import invokeai.backend.util.logging as logger
from invokeai.backend.model_management.model_probe import (
ModelProbe, ModelType, BaseModelType
)
warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error()
diffusers.logging.set_verbosity_error()
def create_directory_structure(dest: Path):
for model_base in [BaseModelType.StableDiffusion1,BaseModelType.StableDiffusion2]:
for model_type in [ModelType.Pipeline, ModelType.Vae, ModelType.Lora,
ModelType.ControlNet,ModelType.TextualInversion]:
path = dest / model_base.value / model_type.value
path.mkdir(parents=True, exist_ok=True)
path = dest / 'core'
path.mkdir(parents=True, exist_ok=True)
def copy_file(src:Path,dest:Path):
logger.info(f'Copying {str(src)} to {str(dest)}')
try:
shutil.copy(src, dest)
except Exception as e:
logger.error(f'COPY FAILED: {str(e)}')
def copy_dir(src:Path,dest:Path):
logger.info(f'Copying {str(src)} to {str(dest)}')
try:
shutil.copytree(src, dest)
except Exception as e:
logger.error(f'COPY FAILED: {str(e)}')
def migrate_models(src_dir: Path, dest_dir: Path):
for root, dirs, files in os.walk(src_dir):
for f in files:
# hack - don't copy raw learned_embeds.bin, let them
# be copied as part of a tree copy operation
if f == 'learned_embeds.bin':
continue
try:
model = Path(root,f)
info = ModelProbe().heuristic_probe(model)
if not info:
continue
dest = Path(dest_dir, info.base_type.value, info.model_type.value, f)
copy_file(model, dest)
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
for d in dirs:
try:
model = Path(root,d)
info = ModelProbe().heuristic_probe(model)
if not info:
continue
dest = Path(dest_dir, info.base_type.value, info.model_type.value, model.name)
copy_dir(model, dest)
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
def migrate_support_models(dest_directory: Path):
if Path('./models/clipseg').exists():
copy_dir(Path('./models/clipseg'),dest_directory / 'core/misc/clipseg')
if Path('./models/realesrgan').exists():
copy_dir(Path('./models/realesrgan'),dest_directory / 'core/upscaling/realesrgan')
for d in ['codeformer','gfpgan']:
path = Path('./models',d)
if path.exists():
copy_dir(path,dest_directory / f'core/face_restoration/{d}')
def migrate_conversion_models(dest_directory: Path):
# These are needed for the conversion script
kwargs = dict(
cache_dir = Path('./models/hub'),
#local_files_only = True
)
try:
logger.info('Migrating core tokenizers and text encoders')
target_dir = dest_directory / 'core' / 'convert'
# bert
bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs)
bert.save_pretrained(target_dir / 'bert-base-uncased', safe_serialization=True)
# sd-1
repo_id = 'openai/clip-vit-large-patch14'
pipeline = CLIPTokenizer.from_pretrained(repo_id, **kwargs)
pipeline.save_pretrained(target_dir / 'clip-vit-large-patch14', safe_serialization=True)
pipeline = CLIPTextModel.from_pretrained(repo_id, **kwargs)
pipeline.save_pretrained(target_dir / 'clip-vit-large-patch14', safe_serialization=True)
# sd-2
repo_id = "stabilityai/stable-diffusion-2"
pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs)
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'tokenizer', safe_serialization=True)
pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'text_encoder', safe_serialization=True)
# VAE
logger.info('Migrating stable diffusion VAE')
vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', **kwargs)
vae.save_pretrained(target_dir / 'sd-vae-ft-mse', safe_serialization=True)
# safety checking
logger.info('Migrating safety checker')
repo_id = "CompVis/stable-diffusion-safety-checker"
pipeline = AutoFeatureExtractor.from_pretrained(repo_id,**kwargs)
pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id,**kwargs)
pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
def migrate_tuning_models(dest: Path):
for subdir in ['embeddings','loras','controlnets']:
src = Path('.',subdir)
if not src.is_dir():
logger.info(f'{subdir} directory not found; skipping')
continue
logger.info(f'Scanning {subdir}')
migrate_models(src, dest)
def migrate_pipelines(dest_dir: Path, dest_yaml: io.TextIOBase):
cache = Path('./models/hub')
kwargs = dict(
cache_dir = cache,
local_files_only = True,
safety_checker = None,
)
for model in cache.glob('models--*'):
if len(list(model.glob('snapshots/**/model_index.json')))==0:
continue
_,owner,repo_name=model.name.split('--')
repo_id = f'{owner}/{repo_name}'
revisions = [x.name for x in model.glob('refs/*')]
for revision in revisions:
logger.info(f'Migrating {repo_id}, revision {revision}')
try:
pipeline = StableDiffusionPipeline.from_pretrained(
repo_id,
revision=revision,
**kwargs)
info = ModelProbe().heuristic_probe(pipeline)
if not info:
continue
dest = Path(dest_dir, info.base_type.value, info.model_type.value, f'{repo_name}-{revision}')
pipeline.save_pretrained(dest, safe_serialization=True)
rel_path = Path('models',dest.relative_to(dest_dir))
stanza = {
f'{info.base_type.value}/{info.model_type.value}/{repo_name}-{revision}':
{
'name': repo_name,
'path': str(rel_path),
'description': f'diffusers model {repo_id}',
'format': 'diffusers',
'image_size': info.image_size,
'base': info.base_type.value,
'variant': info.variant_type.value,
'prediction_type': info.prediction_type.value,
}
}
print(yaml.dump(stanza),file=dest_yaml,end="")
dest_yaml.flush()
except KeyboardInterrupt:
raise
except Exception as e:
logger.warning(f'Could not load the "{revision}" version of {repo_id}. Skipping.')
def migrate_checkpoints(dest_dir: Path, dest_yaml: io.TextIOBase):
# find any checkpoints referred to in old models.yaml
conf = OmegaConf.load('./configs/models.yaml')
orig_models_dir = Path.cwd() / 'models'
for model_name, stanza in conf.items():
if stanza.get('format') and stanza['format'] == 'ckpt':
try:
logger.info(f'Migrating checkpoint model {model_name}')
weights = orig_models_dir.parent / stanza['weights']
config = stanza['config']
info = ModelProbe().heuristic_probe(weights)
if not info:
continue
# uh oh, weights is in the old models directory - move it into the new one
if Path(weights).is_relative_to(orig_models_dir):
dest = Path(dest_dir, info.base_type.value, info.model_type.value,weights.name)
copy_file(weights,dest)
weights = Path('models', info.base_type.value, info.model_type.value,weights.name)
stanza = {
f'{info.base_type.value}/{info.model_type.value}/{model_name}':
{
'name': model_name,
'path': str(weights),
'description': f'checkpoint model {model_name}',
'format': 'checkpoint',
'image_size': info.image_size,
'base': info.base_type.value,
'variant': info.variant_type.value,
'config': config
}
}
print(yaml.dump(stanza),file=dest_yaml,end="")
dest_yaml.flush()
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
def main():
parser = argparse.ArgumentParser(description="Model directory migrator")
parser.add_argument('root_directory',
help='Root directory (containing "models", "embeddings", "controlnets" and "loras")'
)
parser.add_argument('--dest-directory',
default='./models-3.0',
help='Destination for new models directory',
)
parser.add_argument('--dest-yaml',
default='./models.yaml-3.0',
help='Destination for new models.yaml file',
)
args = parser.parse_args()
root_directory = Path(args.root_directory)
assert root_directory.is_dir(), f"{root_directory} is not a valid directory"
assert (root_directory / 'models').is_dir(), f"{root_directory} does not contain a 'models' subdirectory"
dest_directory = Path(args.dest_directory).resolve()
dest_yaml = Path(args.dest_yaml).resolve()
os.chdir(root_directory)
with open(dest_yaml,'w') as yaml_file:
print(yaml.dump({'__metadata__':
{'version':'3.0.0'}
}
),file=yaml_file,end=""
)
create_directory_structure(dest_directory)
migrate_support_models(dest_directory)
migrate_conversion_models(dest_directory)
migrate_tuning_models(dest_directory)
migrate_pipelines(dest_directory,yaml_file)
migrate_checkpoints(dest_directory,yaml_file)
if __name__ == '__main__':
main()