address some of ebr issues

This commit is contained in:
Lincoln Stein 2023-06-20 11:08:27 -04:00
parent 678bb4fe10
commit ac6403f877
5 changed files with 36 additions and 141 deletions

View File

@ -95,8 +95,6 @@ class ModelInstall(object):
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
access_token:str = None):
self.config = config
with open('log.txt','w') as file:
print(config.model_conf_path,file=file)
self.mgr = ModelManager(config.model_conf_path)
self.datasets = OmegaConf.load(Dataset_path)
self.prediction_helper = prediction_type_helper
@ -271,27 +269,36 @@ class ModelInstall(object):
# we try to figure out how to download this most economically
# list all the files in the repo
files = [x.rfilename for x in hinfo.siblings]
location = None
with TemporaryDirectory(dir=self.config.models_path) as staging:
staging = Path(staging)
if 'model_index.json' in files:
location = self._download_hf_pipeline(repo_id, staging) # pipeline
elif 'pytorch_lora_weights.bin' in files:
location = self._download_hf_model(repo_id, ['pytorch_lora_weights.bin'], staging) # LoRA
elif self.config.precision=='float16' and 'diffusion_pytorch_model.fp16.safetensors' in files: # vae, controlnet or some other standalone
files = ['config.json', 'diffusion_pytorch_model.fp16.safetensors']
location = self._download_hf_model(repo_id, files, staging)
elif 'diffusion_pytorch_model.safetensors' in files:
files = ['config.json', 'diffusion_pytorch_model.safetensors']
location = self._download_hf_model(repo_id, files, staging)
elif 'learned_embeds.bin' in files:
location = self._download_hf_model(repo_id, ['learned_embeds.bin'], staging)
else:
for suffix in ['safetensors','bin']:
if f'pytorch_lora_weights.{suffix}' in files:
location = self._download_hf_model(repo_id, ['pytorch_lora_weights.bin'], staging) # LoRA
break
elif self.config.precision=='float16' and f'diffusion_pytorch_model.fp16.{suffix}' in files: # vae, controlnet or some other standalone
files = ['config.json', f'diffusion_pytorch_model.fp16.{suffix}']
location = self._download_hf_model(repo_id, files, staging)
break
elif f'diffusion_pytorch_model.{suffix}' in files:
files = ['config.json', f'diffusion_pytorch_model.{suffix}']
location = self._download_hf_model(repo_id, files, staging)
break
elif f'learned_embeds.{suffix}' in files:
location = self._download_hf_model(repo_id, [f'learned_embeds.suffix'], staging)
break
if not location:
logger.warning(f'Could not determine type of repo {repo_id}. Skipping install.')
return
info = ModelProbe().heuristic_probe(location, self.prediction_helper)
if not info:
logger.warning(f'Could not probe {location}. Skipping install.')
return
dest = self.config.models_path / info.base_type.value / info.model_type.value / self._get_model_name(repo_id,location)
if dest.exists():
shutil.rmtree(dest)

View File

@ -1,118 +0,0 @@
"""
Routines for downloading and installing models.
"""
import json
import safetensors
import safetensors.torch
import shutil
import tempfile
import torch
import traceback
from dataclasses import dataclass
from diffusers import ModelMixin
from enum import Enum
from typing import Callable
from pathlib import Path
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from . import ModelManager
from .models import BaseModelType, ModelType, VariantType
from .model_probe import ModelProbe, ModelVariantInfo
from .model_cache import SilenceWarnings
class ModelInstall(object):
'''
This class is able to download and install several different kinds of
InvokeAI models. The helper function, if provided, is called on to distinguish
between v2-base and v2-768 stable diffusion pipelines. This usually involves
asking the user to select the proper type, as there is no way of distinguishing
the two type of v2 file programmatically (as far as I know).
'''
def __init__(self,
config: InvokeAIAppConfig,
model_base_helper: Callable[[Path],BaseModelType]=None,
clobber:bool = False
):
'''
:param config: InvokeAI configuration object
:param model_base_helper: A function call that accepts the Path to a checkpoint model and returns a ModelType enum
:param clobber: If true, models with colliding names will be overwritten
'''
self.config = config
self.clogger = clobber
self.helper = model_base_helper
self.prober = ModelProbe()
def install_checkpoint_file(self, checkpoint: Path)->dict:
'''
Install the checkpoint file at path and return a
configuration entry that can be added to `models.yaml`.
Model checkpoints and VAEs will be converted into
diffusers before installation. Note that the model manager
does not hold entries for anything but diffusers pipelines,
and the configuration file stanzas returned from such models
can be safely ignored.
'''
model_info = self.prober.probe(checkpoint, self.helper)
if not model_info:
raise ValueError(f"Unable to determine type of checkpoint file {checkpoint}")
key = ModelManager.create_key(
model_name = checkpoint.stem,
base_model = model_info.base_type,
model_type = model_info.model_type,
)
destination_path = self._dest_path(model_info) / checkpoint
destination_path.parent.mkdir(parents=True, exist_ok=True)
self._check_for_collision(destination_path)
stanza = {
key: dict(
name = checkpoint.stem,
description = f'{model_info.model_type} model {checkpoint.stem}',
base = model_info.base_model.value,
type = model_info.model_type.value,
variant = model_info.variant_type.value,
path = str(destination_path),
)
}
# non-pipeline; no conversion needed, just copy into right place
if model_info.model_type != ModelType.Pipeline:
shutil.copyfile(checkpoint, destination_path)
stanza[key].update({'format': 'checkpoint'})
# pipeline - conversion needed here
else:
destination_path = self._dest_path(model_info) / checkpoint.stem
config_file = self._pipeline_type_to_config_file(model_info.model_type)
from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
with SilenceWarnings:
convert_ckpt_to_diffusers(
checkpoint,
destination_path,
extract_ema=True,
original_config_file=config_file,
scan_needed=False,
)
stanza[key].update({'format': 'folder',
'path': destination_path, # no suffix on this
})
return stanza
def _check_for_collision(self, path: Path):
if not path.exists():
return
if self.clobber:
shutil.rmtree(path)
else:
raise ValueError(f"Destination {path} already exists. Won't overwrite unless clobber=True.")
def _staging_directory(self)->tempfile.TemporaryDirectory:
return tempfile.TemporaryDirectory(dir=self.config.root_path)

View File

@ -703,7 +703,7 @@ class ModelManager(object):
model_path = os.path.join(models_dir, entry_name)
if model_path not in loaded_files: # TODO: check
model_path = Path(model_path)
model_name = model_path.name if model_path.is_dir else model_path.stem
model_name = model_path.name if model_path.is_dir() else model_path.stem
model_key = self.create_key(model_name, base_model, model_type)
if model_key in self.models:

View File

@ -401,8 +401,16 @@ class ControlNetFolderProbe(FolderProbeBase):
else BaseModelType.StableDiffusion2
class LoRAFolderProbe(FolderProbeBase):
# I've never seen one of these in the wild, so this is a noop
pass
def get_base_type(self)->BaseModelType:
model_file = None
for suffix in ['safetensors','bin']:
base_file = self.folder_path / f'pytorch_lora_weights.{suffix}'
if base_file.exists():
model_file = base_file
break
if not model_file:
raise Exception('Unknown LoRA format encountered')
return LoRACheckpointProbe(model_file,None).get_base_type()
############## register probe classes ######
ModelProbe.register_probe('diffusers', ModelType.Pipeline, PipelineFolderProbe)

View File

@ -87,7 +87,5 @@ sd-1/embedding/ahx-beta-453407d:
repo_id: sd-concepts-library/ahx-beta-453407d
sd-1/lora/LowRA:
path: https://civitai.com/api/download/models/63006
sd-1/lora/Ink Scenery:
sd-1/lora/Ink scenery:
path: https://civitai.com/api/download/models/83390
sd-1/lora/sd-model-finetuned-lora-t4:
repo_id: sayakpaul/sd-model-finetuned-lora-t4