Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
model_cache: add ability to load a diffusers model pipeline
and update associated things in Generate & Generator to not instantly fail when that happens
This commit is contained in:
parent 9b274bd57c
commit 4c3858e079
@@ -18,6 +18,8 @@ import gc
 import hashlib
 import cv2
 import skimage
+from diffusers import DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, \
+    EulerAncestralDiscreteScheduler
 
 from omegaconf import OmegaConf
 from ldm.invoke.generator.base import downsampling
@@ -402,7 +404,10 @@ class Generate:
         width = width or self.width
         height = height or self.height
 
-        configure_model_padding(model, seamless, seamless_axes)
+        if isinstance(model, DiffusionPipeline):
+            configure_model_padding(model.unet, seamless, seamless_axes)
+        else:
+            configure_model_padding(model, seamless, seamless_axes)
 
         assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
         assert threshold >= 0.0, '--threshold must be >=0.0'
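Note on the hunk above: a diffusers DiffusionPipeline is a container of submodels (unet, vae, text_encoder, ...), so seamless-tiling padding has to be applied to the underlying model.unet rather than to the pipeline object itself. As a rough illustration of what a helper like configure_model_padding does, the sketch below rewrites each convolution's padding mode; it is illustrative only, not InvokeAI's actual implementation, and it ignores the per-axis handling implied by seamless_axes:

    import torch.nn as nn

    def configure_model_padding_sketch(model: nn.Module, seamless: bool) -> None:
        # Circular padding makes convolution outputs wrap around at the
        # borders, which is what lets a generated texture tile seamlessly.
        for module in model.modules():
            if isinstance(module, nn.Conv2d):
                module.padding_mode = 'circular' if seamless else 'zeros'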
@@ -962,9 +967,15 @@ class Generate:
     def sample_to_lowres_estimated_image(self, samples):
         return self._make_base().sample_to_lowres_estimated_image(samples)
 
+    def _set_sampler(self):
+        if isinstance(self.model, DiffusionPipeline):
+            return self._set_scheduler()
+        else:
+            return self._set_sampler_legacy()
+
     # very repetitive code - can this be simplified? The KSampler names are
     # consistent, at least
-    def _set_sampler(self):
+    def _set_sampler_legacy(self):
         msg = f'>> Setting Sampler to {self.sampler_name}'
         if self.sampler_name == 'plms':
             self.sampler = PLMSSampler(self.model, device=self.device)
@@ -992,6 +1003,44 @@ class Generate:
 
         print(msg)
 
+    def _set_scheduler(self):
+        msg = f'>> Setting Sampler to {self.sampler_name}'
+        default = self.model.scheduler
+        # TODO: Test me! Not all schedulers take the same args.
+        scheduler_args = dict(
+            num_train_timesteps=default.num_train_timesteps,
+            beta_start=default.beta_start,
+            beta_end=default.beta_end,
+            beta_schedule=default.beta_schedule,
+        )
+        trained_betas = getattr(self.model.scheduler, 'trained_betas')
+        if trained_betas is not None:
+            scheduler_args.update(trained_betas=trained_betas)
+        if self.sampler_name == 'plms':
+            raise NotImplementedError("What's the diffusers implementation of PLMS?")
+        elif self.sampler_name == 'ddim':
+            self.sampler = DDIMScheduler(**scheduler_args)
+        elif self.sampler_name == 'k_dpm_2_a':
+            raise NotImplementedError("no diffusers implementation of dpm_2 samplers")
+        elif self.sampler_name == 'k_dpm_2':
+            raise NotImplementedError("no diffusers implementation of dpm_2 samplers")
+        elif self.sampler_name == 'k_euler_a':
+            self.sampler = EulerAncestralDiscreteScheduler(**scheduler_args)
+        elif self.sampler_name == 'k_euler':
+            self.sampler = EulerDiscreteScheduler(**scheduler_args)
+        elif self.sampler_name == 'k_heun':
+            raise NotImplementedError("no diffusers implementation of Heun's sampler")
+        elif self.sampler_name == 'k_lms':
+            self.sampler = LMSDiscreteScheduler(**scheduler_args)
+        else:
+            msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to {default}'
+
+        print(msg)
+
+        if not hasattr(self.sampler, 'uses_inpainting_model'):
+            # FIXME: terrible kludge!
+            self.sampler.uses_inpainting_model = lambda: False
+
     def _load_img(self, img)->Image:
         if isinstance(img, Image.Image):
             image = img
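Note on _set_scheduler above: its own TODO is right that schedulers differ in constructor signature, and the two-argument getattr(self.model.scheduler, 'trained_betas') raises AttributeError if the attribute is absent (getattr(..., 'trained_betas', None) would be the tolerant form). Modern diffusers sidesteps both issues by rebuilding a scheduler from the incumbent scheduler's config. A hedged sketch, assuming the from_config API of current diffusers releases rather than anything in this commit:

    from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler

    # Build the replacement scheduler from the existing scheduler's config,
    # so every compatible constructor argument carries over automatically.
    pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)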
@@ -9,8 +9,8 @@ import traceback
 
 import numpy as np
 import torch
-from PIL import Image, ImageFilter, ImageChops
-import cv2 as cv
+from PIL import Image, ImageFilter
+from diffusers import DiffusionPipeline
 from einops import rearrange
 from pytorch_lightning import seed_everything
 from tqdm import trange
@@ -26,9 +26,9 @@ class Generator:
     downsampling_factor: int
     latent_channels: int
     precision: str
-    model: DiffusionWrapper
+    model: DiffusionWrapper | DiffusionPipeline
 
-    def __init__(self, model: DiffusionWrapper, precision: str):
+    def __init__(self, model: DiffusionWrapper | DiffusionPipeline, precision: str):
         self.model = model
         self.precision = precision
         self.seed = None
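Note on the annotations above: DiffusionWrapper | DiffusionPipeline uses PEP 604 union syntax. Class-level annotations are evaluated at import time, so on Python 3.9 and earlier this raises TypeError unless the module enables lazy annotations via from __future__ import annotations. An equivalent spelling that also works on older interpreters:

    from typing import Union

    class Generator:
        # Same union, valid pre-3.10; the string forward references also
        # avoid needing DiffusionWrapper imported purely for typing.
        model: Union["DiffusionWrapper", "DiffusionPipeline"]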
@@ -1,4 +1,5 @@
 import secrets
+import warnings
 from dataclasses import dataclass
 from typing import List, Optional, Union, Callable
 
@@ -309,6 +310,28 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
         return text_embeddings
 
+    def get_learned_conditioning(self, c: List[List[str]], return_tokens=True,
+                                 fragment_weights=None, **kwargs):
+        """
+        Compatibility function for ldm.models.diffusion.ddpm.LatentDiffusion.
+        """
+        assert return_tokens == True
+        if fragment_weights:
+            weights = fragment_weights[0]
+            if any(weight != 1.0 for weight in weights):
+                warnings.warn(f"fragment weights not implemented yet {fragment_weights}", stacklevel=2)
+
+        if kwargs:
+            warnings.warn(f"unsupported args {kwargs}", stacklevel=2)
+
+        text_fragments = c[0]
+        text_input = self._tokenize(text_fragments)
+
+        with torch.inference_mode():
+            token_ids = text_input.input_ids.to(self.text_encoder.device)
+            text_embeddings = self.text_encoder(token_ids)[0]
+        return text_embeddings, text_input.input_ids
+
     @torch.inference_mode()
     def _tokenize(self, prompt: Union[str, List[str]]):
         return self.tokenizer(
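Note on get_learned_conditioning above: the shim accepts the legacy ldm calling convention (a list of prompt-fragment lists plus optional per-fragment weights) but only honors the first entry and uniform weights. A hypothetical call, assuming pipe is a loaded StableDiffusionGeneratorPipeline:

    # Returns (text_embeddings, token_ids), matching what legacy
    # LatentDiffusion callers expect; non-uniform fragment weights and
    # extra kwargs only emit warnings for now.
    cond, cond_ids = pipe.get_learned_conditioning([["a photograph of an astronaut"]])
    uncond, uncond_ids = pipe.get_learned_conditioning([[""]])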
@@ -319,6 +342,11 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
             return_tensors="pt",
         )
 
+    @property
+    def channels(self) -> int:
+        """Compatible with DiffusionWrapper"""
+        return self.unet.in_channels
+
     def prepare_latents(self, latents, batch_size, height, width, generator, dtype):
         # get the initial random noise unless the user supplied it
         # Unlike in other pipelines, latents need to be generated in the target device
@@ -24,17 +24,8 @@ class Txt2Img(Generator):
         self.perlin = perlin
         uc, c, extra_conditioning_info = conditioning
 
-        # FIXME: this should probably be either passed in to __init__ instead of model & precision,
-        # or be constructed in __init__ from those inputs.
-        pipeline = StableDiffusionGeneratorPipeline.from_pretrained(
-            "runwayml/stable-diffusion-v1-5",
-            revision="fp16", torch_dtype=torch.float16,
-            safety_checker=None,  # TODO
-            # scheduler=sampler + ddim_eta,  # TODO
-            # TODO: local_files_only=True
-        )
-        pipeline.unet.to("cuda")
-        pipeline.vae.to("cuda")
+        pipeline = self.model
+        # TODO: customize a new pipeline for the given sampler (Scheduler)
 
         def make_image(x_T) -> PIL.Image.Image:
             # FIXME: restore free_gpu_mem functionality
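Note on the Txt2Img hunk: previously every generation downloaded and built a fresh fp16 pipeline hard-coded to runwayml/stable-diffusion-v1-5 on CUDA; now the generator reuses whatever pipeline ModelCache handed to its constructor. A minimal sketch of the resulting wiring (hypothetical names, simplified):

    from ldm.invoke.generator.txt2img import Txt2Img

    # 'pipeline' would come from ModelCache._load_diffusers_model (final
    # hunk below); Txt2Img no longer hard-codes a model id or device.
    generator = Txt2Img(model=pipeline, precision='float16')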
@@ -4,6 +4,7 @@ They are moved between GPU and CPU as necessary. If CPU memory falls
 below a preset minimum, the least recently used model will be
 cleared and loaded from disk when next needed.
 '''
+from pathlib import Path
 
 import torch
 import os
@@ -20,6 +21,8 @@ import contextlib
 from typing import Union
 from omegaconf import OmegaConf
 from omegaconf.errors import ConfigAttributeError
+
+from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
 from ldm.util import instantiate_from_config, ask_user
 from ldm.invoke.globals import Globals
 from picklescan.scanner import scan_file_path
@@ -91,7 +94,7 @@ class ModelCache(object):
             assert self.current_model,'** FATAL: no current model to restore to'
             print(f'** restoring {self.current_model}')
             self.get_model(self.current_model)
-            return
+            return None
 
         self.current_model = model_name
         self._push_newest_model(model_name)
@@ -277,7 +280,43 @@ class ModelCache(object):
         return model, width, height, model_hash
 
     def _load_diffusers_model(self, mconfig):
-        raise NotImplementedError()  # return pipeline, width, height, model_hash
+        pipeline_args = {}
+
+        if 'repo_name' in mconfig:
+            name_or_path = mconfig['repo_name']
+            model_hash = "FIXME"
+            # model_hash = huggingface_hub.get_hf_file_metadata(url).commit_hash
+        elif 'path' in mconfig:
+            name_or_path = Path(mconfig['path'])
+            # FIXME: What should the model_hash be? A hash of the unet weights? Of all files of all
+            # the submodels hashed together? The commit ID from the repo?
+            model_hash = "FIXME TOO"
+        else:
+            raise ValueError("Model config must specify either repo_name or path.")
+
+        print(f'>> Loading diffusers model from {name_or_path}')
+
+        if self.precision == 'float16':
+            print('   | Using faster float16 precision')
+            pipeline_args.update(revision="fp16", torch_dtype=torch.float16)
+        else:
+            # TODO: more accurately, "using the model's default precision."
+            # How do we find out what that is?
+            print('   | Using more accurate float32 precision')
+
+        pipeline = StableDiffusionGeneratorPipeline.from_pretrained(
+            name_or_path,
+            safety_checker=None,  # TODO
+            # TODO: alternate VAE
+            # TODO: local_files_only=True
+            **pipeline_args
+        )
+        pipeline.to(self.device)
+
+        width = pipeline.vae.sample_size
+        height = pipeline.vae.sample_size
+
+        return pipeline, width, height, model_hash
+
     def offload_model(self, model_name:str):
         '''
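Note on _load_diffusers_model: it expects a model's config stanza to carry either repo_name (a Hugging Face repo id) or path (a local diffusers-format folder). A hedged sketch of configs exercising both branches, using OmegaConf as the file already does; the stanza contents are illustrative, not a documented schema:

    from omegaconf import OmegaConf

    # Hub branch: 'repo_name' is passed to from_pretrained as a repo id.
    hub_config = OmegaConf.create({'repo_name': 'runwayml/stable-diffusion-v1-5'})

    # Local branch: 'path' should point at a diffusers-format directory
    # (one containing model_index.json and the submodel subfolders).
    local_config = OmegaConf.create({'path': '/path/to/local/diffusers/model'})

    # pipeline, width, height, model_hash = cache._load_diffusers_model(hub_config)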