mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add support for custom config files
This commit is contained in:
parent
6e52ca3307
commit
310501cd8a
@ -43,7 +43,8 @@ if not os.path.isabs(args.outdir):
|
||||
|
||||
# normalize the config directory relative to root
|
||||
if not os.path.isabs(opt.conf):
|
||||
opt.conf = os.path.normpath(os.path.join(Globals.root,opt.conf))
|
||||
opt.conf = os.path.normpath(os.path.join(Globals.root, opt.conf))
|
||||
|
||||
|
||||
class InvokeAIWebServer:
|
||||
def __init__(self, generate: Generate, gfpgan, codeformer, esrgan) -> None:
|
||||
@ -189,7 +190,8 @@ class InvokeAIWebServer:
|
||||
(width, height) = pil_image.size
|
||||
|
||||
thumbnail_path = save_thumbnail(
|
||||
pil_image, os.path.basename(file_path), self.thumbnail_image_path
|
||||
pil_image, os.path.basename(
|
||||
file_path), self.thumbnail_image_path
|
||||
)
|
||||
|
||||
response = {
|
||||
@ -264,14 +266,16 @@ class InvokeAIWebServer:
|
||||
# location for "finished" images
|
||||
self.result_path = args.outdir
|
||||
# temporary path for intermediates
|
||||
self.intermediate_path = os.path.join(self.result_path, "intermediates/")
|
||||
self.intermediate_path = os.path.join(
|
||||
self.result_path, "intermediates/")
|
||||
# path for user-uploaded init images and masks
|
||||
self.init_image_path = os.path.join(self.result_path, "init-images/")
|
||||
self.mask_image_path = os.path.join(self.result_path, "mask-images/")
|
||||
# path for temp images e.g. gallery generations which are not committed
|
||||
self.temp_image_path = os.path.join(self.result_path, "temp-images/")
|
||||
# path for thumbnail images
|
||||
self.thumbnail_image_path = os.path.join(self.result_path, "thumbnails/")
|
||||
self.thumbnail_image_path = os.path.join(
|
||||
self.result_path, "thumbnails/")
|
||||
# txt log
|
||||
self.log_path = os.path.join(self.result_path, "invoke_log.txt")
|
||||
# make all output paths
|
||||
@ -301,14 +305,16 @@ class InvokeAIWebServer:
|
||||
try:
|
||||
if not search_folder:
|
||||
socketio.emit(
|
||||
"foundModels",
|
||||
{'search_folder': None, 'found_models': None},
|
||||
)
|
||||
"foundModels",
|
||||
{'search_folder': None, 'found_models': None},
|
||||
)
|
||||
else:
|
||||
search_folder, found_models = self.generate.model_manager.search_models(search_folder)
|
||||
search_folder, found_models = self.generate.model_manager.search_models(
|
||||
search_folder)
|
||||
socketio.emit(
|
||||
"foundModels",
|
||||
{'search_folder': search_folder, 'found_models': found_models},
|
||||
{'search_folder': search_folder,
|
||||
'found_models': found_models},
|
||||
)
|
||||
except Exception as e:
|
||||
self.socketio.emit("error", {"message": (str(e))})
|
||||
@ -396,7 +402,6 @@ class InvokeAIWebServer:
|
||||
@socketio.on('convertToDiffusers')
|
||||
def convert_to_diffusers(model_to_convert: dict):
|
||||
try:
|
||||
|
||||
if (model_info := self.generate.model_manager.model_info(model_name=model_to_convert['name'])):
|
||||
if 'weights' in model_info:
|
||||
ckpt_path = Path(model_info['weights'])
|
||||
@ -404,15 +409,18 @@ class InvokeAIWebServer:
|
||||
model_name = model_to_convert["name"]
|
||||
model_description = model_info['description']
|
||||
else:
|
||||
self.socketio.emit("error", {"message": "Model is not a valid checkpoint file"})
|
||||
self.socketio.emit(
|
||||
"error", {"message": "Model is not a valid checkpoint file"})
|
||||
else:
|
||||
self.socketio.emit("error", {"message": "Could not retrieve model info."})
|
||||
|
||||
self.socketio.emit(
|
||||
"error", {"message": "Could not retrieve model info."})
|
||||
|
||||
if not ckpt_path.is_absolute():
|
||||
ckpt_path = Path(Globals.root,ckpt_path)
|
||||
|
||||
ckpt_path = Path(Globals.root, ckpt_path)
|
||||
|
||||
if original_config_file and not original_config_file.is_absolute():
|
||||
original_config_file = Path(Globals.root, original_config_file)
|
||||
original_config_file = Path(
|
||||
Globals.root, original_config_file)
|
||||
|
||||
if model_to_convert['is_inpainting']:
|
||||
original_config_file = Path(
|
||||
@ -420,19 +428,24 @@ class InvokeAIWebServer:
|
||||
'stable-diffusion',
|
||||
'v1-inpainting-inference.yaml' if model_to_convert['is_inpainting'] else 'v1-inference.yaml'
|
||||
)
|
||||
|
||||
diffusers_path = Path(f'{ckpt_path.parent.absolute()}\\{model_name}_diffusers')
|
||||
|
||||
|
||||
if model_to_convert['custom_config'] is not None:
|
||||
original_config_file = Path(
|
||||
model_to_convert['custom_config'])
|
||||
|
||||
diffusers_path = Path(
|
||||
f'{ckpt_path.parent.absolute()}\\{model_name}_diffusers')
|
||||
|
||||
if diffusers_path.exists():
|
||||
shutil.rmtree(diffusers_path)
|
||||
|
||||
|
||||
self.generate.model_manager.convert_and_import(
|
||||
ckpt_path,
|
||||
diffusers_path,
|
||||
model_name=model_name,
|
||||
model_description=model_description,
|
||||
vae = None,
|
||||
original_config_file = original_config_file,
|
||||
vae=None,
|
||||
original_config_file=original_config_file,
|
||||
commit_to_conf=opt.conf,
|
||||
)
|
||||
|
||||
@ -440,7 +453,7 @@ class InvokeAIWebServer:
|
||||
socketio.emit(
|
||||
"newModelAdded",
|
||||
{"new_model_name": model_name,
|
||||
"model_list": new_model_list, 'update': True},
|
||||
"model_list": new_model_list, 'update': True},
|
||||
)
|
||||
print(f">> Model Converted: {model_name}")
|
||||
except Exception as e:
|
||||
@ -448,7 +461,7 @@ class InvokeAIWebServer:
|
||||
print("\n")
|
||||
|
||||
traceback.print_exc()
|
||||
print("\n")
|
||||
print("\n")
|
||||
|
||||
@socketio.on("requestEmptyTempFolder")
|
||||
def empty_temp_folder():
|
||||
@ -463,7 +476,8 @@ class InvokeAIWebServer:
|
||||
)
|
||||
os.remove(thumbnail_path)
|
||||
except Exception as e:
|
||||
socketio.emit("error", {"message": f"Unable to delete {f}: {str(e)}"})
|
||||
socketio.emit(
|
||||
"error", {"message": f"Unable to delete {f}: {str(e)}"})
|
||||
pass
|
||||
|
||||
socketio.emit("tempFolderEmptied")
|
||||
@ -478,7 +492,8 @@ class InvokeAIWebServer:
|
||||
def save_temp_image_to_gallery(url):
|
||||
try:
|
||||
image_path = self.get_image_path_from_url(url)
|
||||
new_path = os.path.join(self.result_path, os.path.basename(image_path))
|
||||
new_path = os.path.join(
|
||||
self.result_path, os.path.basename(image_path))
|
||||
shutil.copy2(image_path, new_path)
|
||||
|
||||
if os.path.splitext(new_path)[1] == ".png":
|
||||
@ -491,7 +506,8 @@ class InvokeAIWebServer:
|
||||
(width, height) = pil_image.size
|
||||
|
||||
thumbnail_path = save_thumbnail(
|
||||
pil_image, os.path.basename(new_path), self.thumbnail_image_path
|
||||
pil_image, os.path.basename(
|
||||
new_path), self.thumbnail_image_path
|
||||
)
|
||||
|
||||
image_array = [
|
||||
@ -554,7 +570,8 @@ class InvokeAIWebServer:
|
||||
(width, height) = pil_image.size
|
||||
|
||||
thumbnail_path = save_thumbnail(
|
||||
pil_image, os.path.basename(path), self.thumbnail_image_path
|
||||
pil_image, os.path.basename(
|
||||
path), self.thumbnail_image_path
|
||||
)
|
||||
|
||||
image_array.append(
|
||||
@ -572,7 +589,8 @@ class InvokeAIWebServer:
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
socketio.emit("error", {"message": f"Unable to load {path}: {str(e)}"})
|
||||
socketio.emit(
|
||||
"error", {"message": f"Unable to load {path}: {str(e)}"})
|
||||
pass
|
||||
|
||||
socketio.emit(
|
||||
@ -626,7 +644,8 @@ class InvokeAIWebServer:
|
||||
(width, height) = pil_image.size
|
||||
|
||||
thumbnail_path = save_thumbnail(
|
||||
pil_image, os.path.basename(path), self.thumbnail_image_path
|
||||
pil_image, os.path.basename(
|
||||
path), self.thumbnail_image_path
|
||||
)
|
||||
|
||||
image_array.append(
|
||||
@ -645,7 +664,8 @@ class InvokeAIWebServer:
|
||||
)
|
||||
except Exception as e:
|
||||
print(f">> Unable to load {path}")
|
||||
socketio.emit("error", {"message": f"Unable to load {path}: {str(e)}"})
|
||||
socketio.emit(
|
||||
"error", {"message": f"Unable to load {path}: {str(e)}"})
|
||||
pass
|
||||
|
||||
socketio.emit(
|
||||
@ -683,7 +703,8 @@ class InvokeAIWebServer:
|
||||
printable_parameters["init_mask"][:64] + "..."
|
||||
)
|
||||
|
||||
print(f'\n>> Image Generation Parameters:\n\n{printable_parameters}\n')
|
||||
print(
|
||||
f'\n>> Image Generation Parameters:\n\n{printable_parameters}\n')
|
||||
print(f'>> ESRGAN Parameters: {esrgan_parameters}')
|
||||
print(f'>> Facetool Parameters: {facetool_parameters}')
|
||||
|
||||
@ -726,9 +747,11 @@ class InvokeAIWebServer:
|
||||
if postprocessing_parameters["type"] == "esrgan":
|
||||
progress.set_current_status("common:statusUpscalingESRGAN")
|
||||
elif postprocessing_parameters["type"] == "gfpgan":
|
||||
progress.set_current_status("common:statusRestoringFacesGFPGAN")
|
||||
progress.set_current_status(
|
||||
"common:statusRestoringFacesGFPGAN")
|
||||
elif postprocessing_parameters["type"] == "codeformer":
|
||||
progress.set_current_status("common:statusRestoringFacesCodeFormer")
|
||||
progress.set_current_status(
|
||||
"common:statusRestoringFacesCodeFormer")
|
||||
|
||||
socketio.emit("progressUpdate", progress.to_formatted_dict())
|
||||
eventlet.sleep(0)
|
||||
@ -904,7 +927,8 @@ class InvokeAIWebServer:
|
||||
|
||||
init_img_url = generation_parameters["init_img"]
|
||||
|
||||
original_bounding_box = generation_parameters["bounding_box"].copy()
|
||||
original_bounding_box = generation_parameters["bounding_box"].copy(
|
||||
)
|
||||
|
||||
initial_image = dataURL_to_image(
|
||||
generation_parameters["init_img"]
|
||||
@ -981,7 +1005,8 @@ class InvokeAIWebServer:
|
||||
elif generation_parameters["generation_mode"] == "img2img":
|
||||
init_img_url = generation_parameters["init_img"]
|
||||
init_img_path = self.get_image_path_from_url(init_img_url)
|
||||
generation_parameters["init_img"] = Image.open(init_img_path).convert('RGB')
|
||||
generation_parameters["init_img"] = Image.open(
|
||||
init_img_path).convert('RGB')
|
||||
|
||||
def image_progress(sample, step):
|
||||
if self.canceled.is_set():
|
||||
@ -1040,9 +1065,9 @@ class InvokeAIWebServer:
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
if generation_parameters["progress_latents"]:
|
||||
image = self.generate.sample_to_lowres_estimated_image(sample)
|
||||
image = self.generate.sample_to_lowres_estimated_image(
|
||||
sample)
|
||||
(width, height) = image.size
|
||||
width *= 8
|
||||
height *= 8
|
||||
@ -1061,7 +1086,8 @@ class InvokeAIWebServer:
|
||||
},
|
||||
)
|
||||
|
||||
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
|
||||
self.socketio.emit(
|
||||
"progressUpdate", progress.to_formatted_dict())
|
||||
eventlet.sleep(0)
|
||||
|
||||
def image_done(image, seed, first_seed, attention_maps_image=None):
|
||||
@ -1089,7 +1115,8 @@ class InvokeAIWebServer:
|
||||
|
||||
progress.set_current_status("common:statusGenerationComplete")
|
||||
|
||||
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
|
||||
self.socketio.emit(
|
||||
"progressUpdate", progress.to_formatted_dict())
|
||||
eventlet.sleep(0)
|
||||
|
||||
all_parameters = generation_parameters
|
||||
@ -1100,7 +1127,8 @@ class InvokeAIWebServer:
|
||||
and all_parameters["variation_amount"] > 0
|
||||
):
|
||||
first_seed = first_seed or seed
|
||||
this_variation = [[seed, all_parameters["variation_amount"]]]
|
||||
this_variation = [
|
||||
[seed, all_parameters["variation_amount"]]]
|
||||
all_parameters["with_variations"] = (
|
||||
prior_variations + this_variation
|
||||
)
|
||||
@ -1116,7 +1144,8 @@ class InvokeAIWebServer:
|
||||
if esrgan_parameters:
|
||||
progress.set_current_status("common:statusUpscaling")
|
||||
progress.set_current_status_has_steps(False)
|
||||
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
|
||||
self.socketio.emit(
|
||||
"progressUpdate", progress.to_formatted_dict())
|
||||
eventlet.sleep(0)
|
||||
|
||||
image = self.esrgan.process(
|
||||
@ -1139,12 +1168,15 @@ class InvokeAIWebServer:
|
||||
|
||||
if facetool_parameters:
|
||||
if facetool_parameters["type"] == "gfpgan":
|
||||
progress.set_current_status("common:statusRestoringFacesGFPGAN")
|
||||
progress.set_current_status(
|
||||
"common:statusRestoringFacesGFPGAN")
|
||||
elif facetool_parameters["type"] == "codeformer":
|
||||
progress.set_current_status("common:statusRestoringFacesCodeFormer")
|
||||
progress.set_current_status(
|
||||
"common:statusRestoringFacesCodeFormer")
|
||||
|
||||
progress.set_current_status_has_steps(False)
|
||||
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
|
||||
self.socketio.emit(
|
||||
"progressUpdate", progress.to_formatted_dict())
|
||||
eventlet.sleep(0)
|
||||
|
||||
if facetool_parameters["type"] == "gfpgan":
|
||||
@ -1174,7 +1206,8 @@ class InvokeAIWebServer:
|
||||
all_parameters["facetool_type"] = facetool_parameters["type"]
|
||||
|
||||
progress.set_current_status("common:statusSavingImage")
|
||||
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
|
||||
self.socketio.emit(
|
||||
"progressUpdate", progress.to_formatted_dict())
|
||||
eventlet.sleep(0)
|
||||
|
||||
# restore the stashed URLS and discard the paths, we are about to send the result to client
|
||||
@ -1185,12 +1218,14 @@ class InvokeAIWebServer:
|
||||
)
|
||||
|
||||
if "init_mask" in all_parameters:
|
||||
all_parameters["init_mask"] = "" # TODO: store the mask in metadata
|
||||
# TODO: store the mask in metadata
|
||||
all_parameters["init_mask"] = ""
|
||||
|
||||
if generation_parameters["generation_mode"] == "unifiedCanvas":
|
||||
all_parameters["bounding_box"] = original_bounding_box
|
||||
|
||||
metadata = self.parameters_to_generated_image_metadata(all_parameters)
|
||||
metadata = self.parameters_to_generated_image_metadata(
|
||||
all_parameters)
|
||||
|
||||
command = parameters_to_command(all_parameters)
|
||||
|
||||
@ -1220,15 +1255,18 @@ class InvokeAIWebServer:
|
||||
|
||||
if progress.total_iterations > progress.current_iteration:
|
||||
progress.set_current_step(1)
|
||||
progress.set_current_status("common:statusIterationComplete")
|
||||
progress.set_current_status(
|
||||
"common:statusIterationComplete")
|
||||
progress.set_current_status_has_steps(False)
|
||||
else:
|
||||
progress.mark_complete()
|
||||
|
||||
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
|
||||
self.socketio.emit(
|
||||
"progressUpdate", progress.to_formatted_dict())
|
||||
eventlet.sleep(0)
|
||||
|
||||
parsed_prompt, _ = get_prompt_structure(generation_parameters["prompt"])
|
||||
parsed_prompt, _ = get_prompt_structure(
|
||||
generation_parameters["prompt"])
|
||||
tokens = None if type(parsed_prompt) is Blend else \
|
||||
get_tokens_for_prompt(self.generate.model, parsed_prompt)
|
||||
attention_maps_image_base64_url = None if attention_maps_image is None \
|
||||
@ -1402,7 +1440,8 @@ class InvokeAIWebServer:
|
||||
self, parameters, original_image_path
|
||||
):
|
||||
try:
|
||||
current_metadata = retrieve_metadata(original_image_path)["sd-metadata"]
|
||||
current_metadata = retrieve_metadata(
|
||||
original_image_path)["sd-metadata"]
|
||||
postprocessing_metadata = {}
|
||||
|
||||
"""
|
||||
@ -1442,7 +1481,8 @@ class InvokeAIWebServer:
|
||||
postprocessing_metadata
|
||||
)
|
||||
else:
|
||||
current_metadata["image"]["postprocessing"] = [postprocessing_metadata]
|
||||
current_metadata["image"]["postprocessing"] = [
|
||||
postprocessing_metadata]
|
||||
|
||||
return current_metadata
|
||||
|
||||
@ -1554,7 +1594,8 @@ class InvokeAIWebServer:
|
||||
)
|
||||
elif "thumbnails" in url:
|
||||
return os.path.abspath(
|
||||
os.path.join(self.thumbnail_image_path, os.path.basename(url))
|
||||
os.path.join(self.thumbnail_image_path,
|
||||
os.path.basename(url))
|
||||
)
|
||||
else:
|
||||
return os.path.abspath(
|
||||
@ -1723,10 +1764,12 @@ def dataURL_to_image(dataURL: str) -> ImageType:
|
||||
)
|
||||
return image
|
||||
|
||||
|
||||
"""
|
||||
Converts an image into a base64 image dataURL.
|
||||
"""
|
||||
|
||||
|
||||
def image_to_dataURL(image: ImageType) -> str:
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format="PNG")
|
||||
@ -1736,7 +1779,6 @@ def image_to_dataURL(image: ImageType) -> str:
|
||||
return image_base64
|
||||
|
||||
|
||||
|
||||
"""
|
||||
Converts a base64 image dataURL into bytes.
|
||||
The dataURL is split on the first commma.
|
||||
|
@ -74,5 +74,6 @@
|
||||
"convertToDiffusersHelpText6": "Do you wish to convert this model?",
|
||||
"inpaintingModel": "Inpainting Model",
|
||||
"customConfig": "Custom Config",
|
||||
"pathToCustomConfig": "Path To Custom Config"
|
||||
"pathToCustomConfig": "Path To Custom Config",
|
||||
"statusConverting": "Converting"
|
||||
}
|
||||
|
1
invokeai/frontend/src/app/invokeai.d.ts
vendored
1
invokeai/frontend/src/app/invokeai.d.ts
vendored
@ -222,6 +222,7 @@ export declare type InvokeDiffusersModelConfigProps = {
|
||||
export declare type InvokeModelConversionProps = {
|
||||
name: string;
|
||||
is_inpainting: boolean;
|
||||
custom_config: string | null;
|
||||
};
|
||||
|
||||
/**
|
||||
|
@ -178,8 +178,10 @@ const makeSocketIOEmitters = (
|
||||
emitDeleteModel: (modelName: string) => {
|
||||
socketio.emit('deleteModel', modelName);
|
||||
},
|
||||
emitConvertToDiffusers: (modelName: string) => {
|
||||
socketio.emit('convertToDiffusers', modelName);
|
||||
emitConvertToDiffusers: (
|
||||
modelToConvert: InvokeAI.InvokeModelConversionProps
|
||||
) => {
|
||||
socketio.emit('convertToDiffusers', modelToConvert);
|
||||
},
|
||||
emitRequestModelChange: (modelName: string) => {
|
||||
dispatch(modelChangeRequested());
|
||||
|
@ -25,6 +25,7 @@ export default function ModelConvert(props: ModelConvertProps) {
|
||||
|
||||
const [isInpainting, setIsInpainting] = useState<boolean>(false);
|
||||
const [customConfig, setIsCustomConfig] = useState<boolean>(false);
|
||||
const [pathToConfig, setPathToConfig] = useState<string>('');
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
@ -37,20 +38,40 @@ export default function ModelConvert(props: ModelConvertProps) {
|
||||
(state: RootState) => state.system.isConnected
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
// Need to manually handle local state reset because the component does not re-render.
|
||||
const stateReset = () => {
|
||||
setIsInpainting(false);
|
||||
setIsCustomConfig(false);
|
||||
setPathToConfig('');
|
||||
};
|
||||
|
||||
// Reset local state when model changes
|
||||
useEffect(() => {
|
||||
stateReset();
|
||||
}, [model]);
|
||||
|
||||
// Handle local state reset when user cancels input
|
||||
const modelConvertCancelHandler = () => {
|
||||
stateReset();
|
||||
};
|
||||
|
||||
const modelConvertHandler = () => {
|
||||
const modelConvertData = {
|
||||
name: model,
|
||||
is_inpainting: isInpainting,
|
||||
custom_config: customConfig && pathToConfig !== '' ? pathToConfig : null,
|
||||
};
|
||||
|
||||
dispatch(setIsProcessing(true));
|
||||
dispatch(convertToDiffusers({ name: model, is_inpainting: isInpainting }));
|
||||
dispatch(convertToDiffusers(modelConvertData));
|
||||
stateReset(); // Edge case: Cancel local state when model convert fails
|
||||
};
|
||||
|
||||
return (
|
||||
<IAIAlertDialog
|
||||
title={`${t('modelmanager:convert')} ${model}`}
|
||||
acceptCallback={modelConvertHandler}
|
||||
cancelCallback={modelConvertCancelHandler}
|
||||
acceptButtonText={`${t('modelmanager:convert')}`}
|
||||
triggerComponent={
|
||||
<IAIButton
|
||||
@ -105,7 +126,13 @@ export default function ModelConvert(props: ModelConvertProps) {
|
||||
>
|
||||
{t('modelmanager:pathToCustomConfig')}
|
||||
</Text>
|
||||
<IAIInput width="25rem" />
|
||||
<IAIInput
|
||||
value={pathToConfig}
|
||||
onChange={(e) => {
|
||||
if (e.target.value !== '') setPathToConfig(e.target.value);
|
||||
}}
|
||||
width="25rem"
|
||||
/>
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
|
Loading…
Reference in New Issue
Block a user