mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Model Manager Backend Implementation
This commit is contained in:
parent
e66b1a685c
commit
3521557541
@ -9,6 +9,7 @@ import io
|
|||||||
import base64
|
import base64
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
import tkinter as tk
|
||||||
|
|
||||||
from werkzeug.utils import secure_filename
|
from werkzeug.utils import secure_filename
|
||||||
from flask import Flask, redirect, send_from_directory, request, make_response
|
from flask import Flask, redirect, send_from_directory, request, make_response
|
||||||
@ -17,6 +18,7 @@ from PIL import Image, ImageOps
|
|||||||
from PIL.Image import Image as ImageType
|
from PIL.Image import Image as ImageType
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from threading import Event
|
from threading import Event
|
||||||
|
from tkinter import filedialog
|
||||||
|
|
||||||
from ldm.generate import Generate
|
from ldm.generate import Generate
|
||||||
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
|
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
|
||||||
@ -297,6 +299,87 @@ class InvokeAIWebServer:
|
|||||||
config["infill_methods"] = infill_methods()
|
config["infill_methods"] = infill_methods()
|
||||||
socketio.emit("systemConfig", config)
|
socketio.emit("systemConfig", config)
|
||||||
|
|
||||||
|
@socketio.on('searchForModels')
|
||||||
|
def handle_search_models():
|
||||||
|
try:
|
||||||
|
# Using tkinter to get the filepath because JS doesn't allow
|
||||||
|
root = tk.Tk()
|
||||||
|
root.iconify() # for macos
|
||||||
|
root.withdraw()
|
||||||
|
root.wm_attributes('-topmost', 1)
|
||||||
|
root.focus_force()
|
||||||
|
search_folder = filedialog.askdirectory(parent=root, title='Select Checkpoint Folder')
|
||||||
|
root.destroy()
|
||||||
|
|
||||||
|
if not search_folder:
|
||||||
|
socketio.emit(
|
||||||
|
"foundModels",
|
||||||
|
{'search_folder': None, 'found_models': None},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
search_folder, found_models = self.generate.model_cache.search_models(search_folder)
|
||||||
|
socketio.emit(
|
||||||
|
"foundModels",
|
||||||
|
{'search_folder': search_folder, 'found_models': found_models},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
self.socketio.emit("error", {"message": (str(e))})
|
||||||
|
print("\n")
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
print("\n")
|
||||||
|
|
||||||
|
@socketio.on("addNewModel")
|
||||||
|
def handle_add_model(new_model_config: dict):
|
||||||
|
try:
|
||||||
|
model_name = new_model_config['name']
|
||||||
|
del new_model_config['name']
|
||||||
|
model_attributes = new_model_config
|
||||||
|
update = False
|
||||||
|
current_model_list = self.generate.model_cache.list_models()
|
||||||
|
if model_name in current_model_list:
|
||||||
|
update = True
|
||||||
|
|
||||||
|
print(f">> Adding New Model: {model_name}")
|
||||||
|
|
||||||
|
self.generate.model_cache.add_model(
|
||||||
|
model_name=model_name, model_attributes=model_attributes, clobber=True)
|
||||||
|
self.generate.model_cache.commit(opt.conf)
|
||||||
|
|
||||||
|
new_model_list = self.generate.model_cache.list_models()
|
||||||
|
socketio.emit(
|
||||||
|
"newModelAdded",
|
||||||
|
{"new_model_name": model_name,
|
||||||
|
"model_list": new_model_list, 'update': update},
|
||||||
|
)
|
||||||
|
print(f">> New Model Added: {model_name}")
|
||||||
|
except Exception as e:
|
||||||
|
self.socketio.emit("error", {"message": (str(e))})
|
||||||
|
print("\n")
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
print("\n")
|
||||||
|
|
||||||
|
@socketio.on("deleteModel")
|
||||||
|
def handle_delete_model(model_name: str):
|
||||||
|
try:
|
||||||
|
print(f">> Deleting Model: {model_name}")
|
||||||
|
self.generate.model_cache.del_model(model_name)
|
||||||
|
self.generate.model_cache.commit(opt.conf)
|
||||||
|
updated_model_list = self.generate.model_cache.list_models()
|
||||||
|
socketio.emit(
|
||||||
|
"modelDeleted",
|
||||||
|
{"deleted_model_name": model_name,
|
||||||
|
"model_list": updated_model_list},
|
||||||
|
)
|
||||||
|
print(f">> Model Deleted: {model_name}")
|
||||||
|
except Exception as e:
|
||||||
|
self.socketio.emit("error", {"message": (str(e))})
|
||||||
|
print("\n")
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
print("\n")
|
||||||
|
|
||||||
@socketio.on("requestModelChange")
|
@socketio.on("requestModelChange")
|
||||||
def handle_set_model(model_name: str):
|
def handle_set_model(model_name: str):
|
||||||
try:
|
try:
|
||||||
|
@ -23,6 +23,7 @@ from omegaconf.errors import ConfigAttributeError
|
|||||||
from ldm.util import instantiate_from_config, ask_user
|
from ldm.util import instantiate_from_config, ask_user
|
||||||
from ldm.invoke.globals import Globals
|
from ldm.invoke.globals import Globals
|
||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
DEFAULT_MAX_MODELS=2
|
DEFAULT_MAX_MODELS=2
|
||||||
|
|
||||||
@ -135,8 +136,10 @@ class ModelCache(object):
|
|||||||
for name in self.config:
|
for name in self.config:
|
||||||
try:
|
try:
|
||||||
description = self.config[name].description
|
description = self.config[name].description
|
||||||
|
weights = self.config[name].weights
|
||||||
except ConfigAttributeError:
|
except ConfigAttributeError:
|
||||||
description = '<no description>'
|
description = '<no description>'
|
||||||
|
weights = '<not found>'
|
||||||
|
|
||||||
if self.current_model == name:
|
if self.current_model == name:
|
||||||
status = 'active'
|
status = 'active'
|
||||||
@ -147,7 +150,8 @@ class ModelCache(object):
|
|||||||
|
|
||||||
models[name]={
|
models[name]={
|
||||||
'status' : status,
|
'status' : status,
|
||||||
'description' : description
|
'description' : description,
|
||||||
|
'weights': weights
|
||||||
}
|
}
|
||||||
return models
|
return models
|
||||||
|
|
||||||
@ -186,6 +190,8 @@ class ModelCache(object):
|
|||||||
|
|
||||||
config = omega[model_name] if model_name in omega else {}
|
config = omega[model_name] if model_name in omega else {}
|
||||||
for field in model_attributes:
|
for field in model_attributes:
|
||||||
|
if field == 'weights':
|
||||||
|
field.replace('\\', '/')
|
||||||
config[field] = model_attributes[field]
|
config[field] = model_attributes[field]
|
||||||
|
|
||||||
omega[model_name] = config
|
omega[model_name] = config
|
||||||
@ -311,6 +317,22 @@ class ModelCache(object):
|
|||||||
sys.exit()
|
sys.exit()
|
||||||
else:
|
else:
|
||||||
print('>> Model Scanned. OK!!')
|
print('>> Model Scanned. OK!!')
|
||||||
|
|
||||||
|
def search_models(self, search_folder):
|
||||||
|
|
||||||
|
print(f'>> Finding Models In: {search_folder}')
|
||||||
|
models_folder = Path(search_folder).glob('**/*.ckpt')
|
||||||
|
|
||||||
|
files = [x for x in models_folder if x.is_file()]
|
||||||
|
|
||||||
|
found_models = []
|
||||||
|
for file in files:
|
||||||
|
found_models.append({
|
||||||
|
'name': file.stem,
|
||||||
|
'location': str(file.resolve()).replace('\\', '/')
|
||||||
|
})
|
||||||
|
|
||||||
|
return search_folder, found_models
|
||||||
|
|
||||||
def _make_cache_room(self) -> None:
|
def _make_cache_room(self) -> None:
|
||||||
num_loaded_models = len(self.models)
|
num_loaded_models = len(self.models)
|
||||||
|
Loading…
Reference in New Issue
Block a user