mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
model_cache: add ability to load a diffusers model pipeline
and update associated things in Generate & Generator to not instantly fail when that happens
This commit is contained in:
parent
9b274bd57c
commit
4c3858e079
@ -18,6 +18,8 @@ import gc
|
||||
import hashlib
|
||||
import cv2
|
||||
import skimage
|
||||
from diffusers import DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, \
|
||||
EulerAncestralDiscreteScheduler
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from ldm.invoke.generator.base import downsampling
|
||||
@ -402,7 +404,10 @@ class Generate:
|
||||
width = width or self.width
|
||||
height = height or self.height
|
||||
|
||||
configure_model_padding(model, seamless, seamless_axes)
|
||||
if isinstance(model, DiffusionPipeline):
|
||||
configure_model_padding(model.unet, seamless, seamless_axes)
|
||||
else:
|
||||
configure_model_padding(model, seamless, seamless_axes)
|
||||
|
||||
assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
|
||||
assert threshold >= 0.0, '--threshold must be >=0.0'
|
||||
@ -962,9 +967,15 @@ class Generate:
|
||||
def sample_to_lowres_estimated_image(self, samples):
|
||||
return self._make_base().sample_to_lowres_estimated_image(samples)
|
||||
|
||||
def _set_sampler(self):
|
||||
if isinstance(self.model, DiffusionPipeline):
|
||||
return self._set_scheduler()
|
||||
else:
|
||||
return self._set_sampler_legacy()
|
||||
|
||||
# very repetitive code - can this be simplified? The KSampler names are
|
||||
# consistent, at least
|
||||
def _set_sampler(self):
|
||||
def _set_sampler_legacy(self):
|
||||
msg = f'>> Setting Sampler to {self.sampler_name}'
|
||||
if self.sampler_name == 'plms':
|
||||
self.sampler = PLMSSampler(self.model, device=self.device)
|
||||
@ -992,6 +1003,44 @@ class Generate:
|
||||
|
||||
print(msg)
|
||||
|
||||
def _set_scheduler(self):
|
||||
msg = f'>> Setting Sampler to {self.sampler_name}'
|
||||
default = self.model.scheduler
|
||||
# TODO: Test me! Not all schedulers take the same args.
|
||||
scheduler_args = dict(
|
||||
num_train_timesteps=default.num_train_timesteps,
|
||||
beta_start=default.beta_start,
|
||||
beta_end=default.beta_end,
|
||||
beta_schedule=default.beta_schedule,
|
||||
)
|
||||
trained_betas = getattr(self.model.scheduler, 'trained_betas')
|
||||
if trained_betas is not None:
|
||||
scheduler_args.update(trained_betas=trained_betas)
|
||||
if self.sampler_name == 'plms':
|
||||
raise NotImplementedError("What's the diffusers implementation of PLMS?")
|
||||
elif self.sampler_name == 'ddim':
|
||||
self.sampler = DDIMScheduler(**scheduler_args)
|
||||
elif self.sampler_name == 'k_dpm_2_a':
|
||||
raise NotImplementedError("no diffusers implementation of dpm_2 samplers")
|
||||
elif self.sampler_name == 'k_dpm_2':
|
||||
raise NotImplementedError("no diffusers implementation of dpm_2 samplers")
|
||||
elif self.sampler_name == 'k_euler_a':
|
||||
self.sampler = EulerAncestralDiscreteScheduler(**scheduler_args)
|
||||
elif self.sampler_name == 'k_euler':
|
||||
self.sampler = EulerDiscreteScheduler(**scheduler_args)
|
||||
elif self.sampler_name == 'k_heun':
|
||||
raise NotImplementedError("no diffusers implementation of Heun's sampler")
|
||||
elif self.sampler_name == 'k_lms':
|
||||
self.sampler = LMSDiscreteScheduler(**scheduler_args)
|
||||
else:
|
||||
msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to {default}'
|
||||
|
||||
print(msg)
|
||||
|
||||
if not hasattr(self.sampler, 'uses_inpainting_model'):
|
||||
# FIXME: terrible kludge!
|
||||
self.sampler.uses_inpainting_model = lambda: False
|
||||
|
||||
def _load_img(self, img)->Image:
|
||||
if isinstance(img, Image.Image):
|
||||
image = img
|
||||
|
@ -9,8 +9,8 @@ import traceback
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageFilter, ImageChops
|
||||
import cv2 as cv
|
||||
from PIL import Image, ImageFilter
|
||||
from diffusers import DiffusionPipeline
|
||||
from einops import rearrange
|
||||
from pytorch_lightning import seed_everything
|
||||
from tqdm import trange
|
||||
@ -26,9 +26,9 @@ class Generator:
|
||||
downsampling_factor: int
|
||||
latent_channels: int
|
||||
precision: str
|
||||
model: DiffusionWrapper
|
||||
model: DiffusionWrapper | DiffusionPipeline
|
||||
|
||||
def __init__(self, model: DiffusionWrapper, precision: str):
|
||||
def __init__(self, model: DiffusionWrapper | DiffusionPipeline, precision: str):
|
||||
self.model = model
|
||||
self.precision = precision
|
||||
self.seed = None
|
||||
|
@ -1,4 +1,5 @@
|
||||
import secrets
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union, Callable
|
||||
|
||||
@ -309,6 +310,28 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
return text_embeddings
|
||||
|
||||
def get_learned_conditioning(self, c: List[List[str]], return_tokens=True,
|
||||
fragment_weights=None, **kwargs):
|
||||
"""
|
||||
Compatibility function for ldm.models.diffusion.ddpm.LatentDiffusion.
|
||||
"""
|
||||
assert return_tokens == True
|
||||
if fragment_weights:
|
||||
weights = fragment_weights[0]
|
||||
if any(weight != 1.0 for weight in weights):
|
||||
warnings.warn(f"fragment weights not implemented yet {fragment_weights}", stacklevel=2)
|
||||
|
||||
if kwargs:
|
||||
warnings.warn(f"unsupported args {kwargs}", stacklevel=2)
|
||||
|
||||
text_fragments = c[0]
|
||||
text_input = self._tokenize(text_fragments)
|
||||
|
||||
with torch.inference_mode():
|
||||
token_ids = text_input.input_ids.to(self.text_encoder.device)
|
||||
text_embeddings = self.text_encoder(token_ids)[0]
|
||||
return text_embeddings, text_input.input_ids
|
||||
|
||||
@torch.inference_mode()
|
||||
def _tokenize(self, prompt: Union[str, List[str]]):
|
||||
return self.tokenizer(
|
||||
@ -319,6 +342,11 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
@property
|
||||
def channels(self) -> int:
|
||||
"""Compatible with DiffusionWrapper"""
|
||||
return self.unet.in_channels
|
||||
|
||||
def prepare_latents(self, latents, batch_size, height, width, generator, dtype):
|
||||
# get the initial random noise unless the user supplied it
|
||||
# Unlike in other pipelines, latents need to be generated in the target device
|
||||
|
@ -24,17 +24,8 @@ class Txt2Img(Generator):
|
||||
self.perlin = perlin
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
|
||||
# FIXME: this should probably be either passed in to __init__ instead of model & precision,
|
||||
# or be constructed in __init__ from those inputs.
|
||||
pipeline = StableDiffusionGeneratorPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
revision="fp16", torch_dtype=torch.float16,
|
||||
safety_checker=None, # TODO
|
||||
# scheduler=sampler + ddim_eta, # TODO
|
||||
# TODO: local_files_only=True
|
||||
)
|
||||
pipeline.unet.to("cuda")
|
||||
pipeline.vae.to("cuda")
|
||||
pipeline = self.model
|
||||
# TODO: customize a new pipeline for the given sampler (Scheduler)
|
||||
|
||||
def make_image(x_T) -> PIL.Image.Image:
|
||||
# FIXME: restore free_gpu_mem functionality
|
||||
|
@ -4,6 +4,7 @@ They are moved between GPU and CPU as necessary. If CPU memory falls
|
||||
below a preset minimum, the least recently used model will be
|
||||
cleared and loaded from disk when next needed.
|
||||
'''
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import os
|
||||
@ -20,6 +21,8 @@ import contextlib
|
||||
from typing import Union
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.errors import ConfigAttributeError
|
||||
|
||||
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from ldm.util import instantiate_from_config, ask_user
|
||||
from ldm.invoke.globals import Globals
|
||||
from picklescan.scanner import scan_file_path
|
||||
@ -91,7 +94,7 @@ class ModelCache(object):
|
||||
assert self.current_model,'** FATAL: no current model to restore to'
|
||||
print(f'** restoring {self.current_model}')
|
||||
self.get_model(self.current_model)
|
||||
return
|
||||
return None
|
||||
|
||||
self.current_model = model_name
|
||||
self._push_newest_model(model_name)
|
||||
@ -277,7 +280,43 @@ class ModelCache(object):
|
||||
return model, width, height, model_hash
|
||||
|
||||
def _load_diffusers_model(self, mconfig):
|
||||
raise NotImplementedError() # return pipeline, width, height, model_hash
|
||||
pipeline_args = {}
|
||||
|
||||
if 'repo_name' in mconfig:
|
||||
name_or_path = mconfig['repo_name']
|
||||
model_hash = "FIXME"
|
||||
# model_hash = huggingface_hub.get_hf_file_metadata(url).commit_hash
|
||||
elif 'path' in mconfig:
|
||||
name_or_path = Path(mconfig['path'])
|
||||
# FIXME: What should the model_hash be? A hash of the unet weights? Of all files of all
|
||||
# the submodels hashed together? The commit ID from the repo?
|
||||
model_hash = "FIXME TOO"
|
||||
else:
|
||||
raise ValueError("Model config must specify either repo_name or path.")
|
||||
|
||||
print(f'>> Loading diffusers model from {name_or_path}')
|
||||
|
||||
if self.precision == 'float16':
|
||||
print(' | Using faster float16 precision')
|
||||
pipeline_args.update(revision="fp16", torch_dtype=torch.float16)
|
||||
else:
|
||||
# TODO: more accurately, "using the model's default precision."
|
||||
# How do we find out what that is?
|
||||
print(' | Using more accurate float32 precision')
|
||||
|
||||
pipeline = StableDiffusionGeneratorPipeline.from_pretrained(
|
||||
name_or_path,
|
||||
safety_checker=None, # TODO
|
||||
# TODO: alternate VAE
|
||||
# TODO: local_files_only=True
|
||||
**pipeline_args
|
||||
)
|
||||
pipeline.to(self.device)
|
||||
|
||||
width = pipeline.vae.sample_size
|
||||
height = pipeline.vae.sample_size
|
||||
|
||||
return pipeline, width, height, model_hash
|
||||
|
||||
def offload_model(self, model_name:str):
|
||||
'''
|
||||
|
Loading…
Reference in New Issue
Block a user