mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
listing, downloading and deleting LoRAs working; TI support pending
This commit is contained in:
@ -95,6 +95,9 @@ def install_requested_models(
|
||||
model_manager.install_controlnet_models(controlnet.install_models, access_token=access_token)
|
||||
model_manager.delete_controlnet_models(controlnet.remove_models)
|
||||
|
||||
model_manager.install_lora_models(lora.install_models)
|
||||
model_manager.delete_lora_models(lora.remove_models)
|
||||
|
||||
# TODO: Replace next three paragraphs with calls into new model manager
|
||||
if diffusers.remove_models and len(diffusers.remove_models) > 0:
|
||||
logger.info("DELETING UNCHECKED STARTER MODELS")
|
||||
|
@ -20,7 +20,7 @@ import warnings
|
||||
from enum import Enum, auto
|
||||
from pathlib import Path
|
||||
from shutil import move, rmtree
|
||||
from typing import Any, Optional, Union, Callable, types
|
||||
from typing import Any, Optional, Union, Callable, Dict, List, types
|
||||
|
||||
import safetensors
|
||||
import safetensors.torch
|
||||
@ -1322,15 +1322,69 @@ class ModelManager(object):
|
||||
os.getenv("HF_HOME") is not None or os.getenv("XDG_CACHE_HOME") is not None
|
||||
)
|
||||
|
||||
def list_lora_models(self)->Dict[str,bool]:
|
||||
'''Return a dict of installed lora models; key is either the shortname
|
||||
defined in INITIAL_MODELS, or the basename of the file in the LoRA
|
||||
directory. Value is True if installed'''
|
||||
|
||||
models = OmegaConf.load(Dataset_path).get('lora') or {}
|
||||
installed_models = {x: False for x in models.keys()}
|
||||
|
||||
dir = self.globals.lora_path
|
||||
installed_models = dict()
|
||||
for root, dirs, files in os.walk(dir):
|
||||
for name in files:
|
||||
if Path(name).suffix in ['.safetensors','.ckpt','.pt']:
|
||||
installed_models.update({name: True})
|
||||
return installed_models
|
||||
|
||||
def install_lora_models(self, model_names: list[str]):
|
||||
'''Download list of LoRA/LyCORIS models'''
|
||||
short_names = OmegaConf.load(Dataset_path).get('lora') or {}
|
||||
for name in model_names:
|
||||
url = short_names.get(name) or name
|
||||
download_with_resume(url, self.globals.lora_path)
|
||||
|
||||
def delete_lora_models(self, model_names: List[str]):
|
||||
'''Remove the list of lora models'''
|
||||
for name in model_names:
|
||||
path = self.globals.lora_path / name
|
||||
if path.exists():
|
||||
self.logger.info(f'Purging lora model {name}')
|
||||
path.unlink()
|
||||
|
||||
def list_ti_models(self)->Dict[str,bool]:
|
||||
'''Return a dict of installed textual models; key is either the shortname
|
||||
defined in INITIAL_MODELS, or the basename of the file in the LoRA
|
||||
directory. Value is True if installed'''
|
||||
|
||||
models = OmegaConf.load(Dataset_path).get('textual_inversion') or {}
|
||||
installed_models = {x: False for x in models.keys()}
|
||||
|
||||
dir = self.globals.embedding_path
|
||||
installed_models = dict()
|
||||
for root, dirs, files in os.walk(dir):
|
||||
for name in files:
|
||||
if name == 'learned_embeds.bin':
|
||||
name = str(Path(root,name).parent)
|
||||
installed_models.update({name: True})
|
||||
return installed_models
|
||||
|
||||
def install_ti_models(self, model_names: list[str]):
|
||||
'''Download list of textual inversion embeddings'''
|
||||
short_names = OmegaConf.load(Dataset_path).get('textual_inversion') or {}
|
||||
for name in model_names:
|
||||
url = short_names.get(name) or name
|
||||
download_with_resume(url, self.globals.embedding_path)
|
||||
|
||||
def list_controlnet_models(self)->Dict[str,bool]:
|
||||
'''Return a dict of installed controlnet models; key is repo_id or short name
|
||||
of model (defined in INITIAL_MODELS), and valule is True if installed'''
|
||||
of model (defined in INITIAL_MODELS), and value is True if installed'''
|
||||
|
||||
cn_models = OmegaConf.load(Dataset_path).get('controlnet') or {}
|
||||
installed_models = {x: False for x in cn_models.keys()}
|
||||
|
||||
cn_dir = self.globals.controlnet_path
|
||||
installed_cn_models = dict()
|
||||
for root, dirs, files in os.walk(cn_dir):
|
||||
for name in dirs:
|
||||
if Path(root, name, '.download_complete').exists():
|
||||
|
Reference in New Issue
Block a user