mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Model Manager Backend Implementation
This commit is contained in:
parent
e66b1a685c
commit
3521557541
@ -9,6 +9,7 @@ import io
|
||||
import base64
|
||||
import os
|
||||
import json
|
||||
import tkinter as tk
|
||||
|
||||
from werkzeug.utils import secure_filename
|
||||
from flask import Flask, redirect, send_from_directory, request, make_response
|
||||
@ -17,6 +18,7 @@ from PIL import Image, ImageOps
|
||||
from PIL.Image import Image as ImageType
|
||||
from uuid import uuid4
|
||||
from threading import Event
|
||||
from tkinter import filedialog
|
||||
|
||||
from ldm.generate import Generate
|
||||
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
|
||||
@ -297,6 +299,87 @@ class InvokeAIWebServer:
|
||||
config["infill_methods"] = infill_methods()
|
||||
socketio.emit("systemConfig", config)
|
||||
|
||||
@socketio.on('searchForModels')
|
||||
def handle_search_models():
|
||||
try:
|
||||
# Using tkinter to get the filepath because JS doesn't allow
|
||||
root = tk.Tk()
|
||||
root.iconify() # for macos
|
||||
root.withdraw()
|
||||
root.wm_attributes('-topmost', 1)
|
||||
root.focus_force()
|
||||
search_folder = filedialog.askdirectory(parent=root, title='Select Checkpoint Folder')
|
||||
root.destroy()
|
||||
|
||||
if not search_folder:
|
||||
socketio.emit(
|
||||
"foundModels",
|
||||
{'search_folder': None, 'found_models': None},
|
||||
)
|
||||
else:
|
||||
search_folder, found_models = self.generate.model_cache.search_models(search_folder)
|
||||
socketio.emit(
|
||||
"foundModels",
|
||||
{'search_folder': search_folder, 'found_models': found_models},
|
||||
)
|
||||
except Exception as e:
|
||||
self.socketio.emit("error", {"message": (str(e))})
|
||||
print("\n")
|
||||
|
||||
traceback.print_exc()
|
||||
print("\n")
|
||||
|
||||
@socketio.on("addNewModel")
|
||||
def handle_add_model(new_model_config: dict):
|
||||
try:
|
||||
model_name = new_model_config['name']
|
||||
del new_model_config['name']
|
||||
model_attributes = new_model_config
|
||||
update = False
|
||||
current_model_list = self.generate.model_cache.list_models()
|
||||
if model_name in current_model_list:
|
||||
update = True
|
||||
|
||||
print(f">> Adding New Model: {model_name}")
|
||||
|
||||
self.generate.model_cache.add_model(
|
||||
model_name=model_name, model_attributes=model_attributes, clobber=True)
|
||||
self.generate.model_cache.commit(opt.conf)
|
||||
|
||||
new_model_list = self.generate.model_cache.list_models()
|
||||
socketio.emit(
|
||||
"newModelAdded",
|
||||
{"new_model_name": model_name,
|
||||
"model_list": new_model_list, 'update': update},
|
||||
)
|
||||
print(f">> New Model Added: {model_name}")
|
||||
except Exception as e:
|
||||
self.socketio.emit("error", {"message": (str(e))})
|
||||
print("\n")
|
||||
|
||||
traceback.print_exc()
|
||||
print("\n")
|
||||
|
||||
@socketio.on("deleteModel")
|
||||
def handle_delete_model(model_name: str):
|
||||
try:
|
||||
print(f">> Deleting Model: {model_name}")
|
||||
self.generate.model_cache.del_model(model_name)
|
||||
self.generate.model_cache.commit(opt.conf)
|
||||
updated_model_list = self.generate.model_cache.list_models()
|
||||
socketio.emit(
|
||||
"modelDeleted",
|
||||
{"deleted_model_name": model_name,
|
||||
"model_list": updated_model_list},
|
||||
)
|
||||
print(f">> Model Deleted: {model_name}")
|
||||
except Exception as e:
|
||||
self.socketio.emit("error", {"message": (str(e))})
|
||||
print("\n")
|
||||
|
||||
traceback.print_exc()
|
||||
print("\n")
|
||||
|
||||
@socketio.on("requestModelChange")
|
||||
def handle_set_model(model_name: str):
|
||||
try:
|
||||
|
@ -23,6 +23,7 @@ from omegaconf.errors import ConfigAttributeError
|
||||
from ldm.util import instantiate_from_config, ask_user
|
||||
from ldm.invoke.globals import Globals
|
||||
from picklescan.scanner import scan_file_path
|
||||
from pathlib import Path
|
||||
|
||||
DEFAULT_MAX_MODELS=2
|
||||
|
||||
@ -135,8 +136,10 @@ class ModelCache(object):
|
||||
for name in self.config:
|
||||
try:
|
||||
description = self.config[name].description
|
||||
weights = self.config[name].weights
|
||||
except ConfigAttributeError:
|
||||
description = '<no description>'
|
||||
weights = '<not found>'
|
||||
|
||||
if self.current_model == name:
|
||||
status = 'active'
|
||||
@ -147,7 +150,8 @@ class ModelCache(object):
|
||||
|
||||
models[name]={
|
||||
'status' : status,
|
||||
'description' : description
|
||||
'description' : description,
|
||||
'weights': weights
|
||||
}
|
||||
return models
|
||||
|
||||
@ -186,6 +190,8 @@ class ModelCache(object):
|
||||
|
||||
config = omega[model_name] if model_name in omega else {}
|
||||
for field in model_attributes:
|
||||
if field == 'weights':
|
||||
field.replace('\\', '/')
|
||||
config[field] = model_attributes[field]
|
||||
|
||||
omega[model_name] = config
|
||||
@ -311,6 +317,22 @@ class ModelCache(object):
|
||||
sys.exit()
|
||||
else:
|
||||
print('>> Model Scanned. OK!!')
|
||||
|
||||
def search_models(self, search_folder):
|
||||
|
||||
print(f'>> Finding Models In: {search_folder}')
|
||||
models_folder = Path(search_folder).glob('**/*.ckpt')
|
||||
|
||||
files = [x for x in models_folder if x.is_file()]
|
||||
|
||||
found_models = []
|
||||
for file in files:
|
||||
found_models.append({
|
||||
'name': file.stem,
|
||||
'location': str(file.resolve()).replace('\\', '/')
|
||||
})
|
||||
|
||||
return search_folder, found_models
|
||||
|
||||
def _make_cache_room(self) -> None:
|
||||
num_loaded_models = len(self.models)
|
||||
|
Loading…
Reference in New Issue
Block a user