Merge branch 'main' into set-timestep-mps-fix

This commit is contained in:
Millun Atluri 2023-07-28 16:12:07 +10:00 committed by GitHub
commit 7f81a95b20
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 518 additions and 210 deletions

View File

@ -123,7 +123,7 @@ and go to http://localhost:9090.
### Command-Line Installation (for developers and users familiar with Terminals) ### Command-Line Installation (for developers and users familiar with Terminals)
You must have Python 3.9 or 3.10 installed on your machine. Earlier or You must have Python 3.9 through 3.11 installed on your machine. Earlier or
later versions are not supported. later versions are not supported.
Node.js also needs to be installed along with yarn (can be installed with Node.js also needs to be installed along with yarn (can be installed with
the command `npm install -g yarn` if needed) the command `npm install -g yarn` if needed)

View File

@ -40,10 +40,8 @@ experimental versions later.
this, open up a command-line window ("Terminal" on Linux and this, open up a command-line window ("Terminal" on Linux and
Macintosh, "Command" or "Powershell" on Windows) and type `python Macintosh, "Command" or "Powershell" on Windows) and type `python
--version`. If Python is installed, it will print out the version --version`. If Python is installed, it will print out the version
number. If it is version `3.9.*` or `3.10.*`, you meet number. If it is version `3.9.*`, `3.10.*` or `3.11.*` you meet
requirements. We do not recommend using Python 3.11 or higher, requirements.
as not all the libraries that InvokeAI depends on work properly
with this version.
!!! warning "What to do if you have an unsupported version" !!! warning "What to do if you have an unsupported version"

View File

@ -32,7 +32,7 @@ gaming):
* **Python** * **Python**
version 3.9 or 3.10 (3.11 is not recommended). version 3.9 through 3.11
* **CUDA Tools** * **CUDA Tools**
@ -65,7 +65,7 @@ gaming):
To install InvokeAI with virtual environments and the PIP package To install InvokeAI with virtual environments and the PIP package
manager, please follow these steps: manager, please follow these steps:
1. Please make sure you are using Python 3.9 or 3.10. The rest of the install 1. Please make sure you are using Python 3.9 through 3.11. The rest of the install
procedure depends on this and will not work with other versions: procedure depends on this and will not work with other versions:
```bash ```bash

View File

@ -9,16 +9,20 @@ cd $scriptdir
function version { echo "$@" | awk -F. '{ printf("%d%03d%03d%03d\n", $1,$2,$3,$4); }'; } function version { echo "$@" | awk -F. '{ printf("%d%03d%03d%03d\n", $1,$2,$3,$4); }'; }
MINIMUM_PYTHON_VERSION=3.9.0 MINIMUM_PYTHON_VERSION=3.9.0
MAXIMUM_PYTHON_VERSION=3.11.0 MAXIMUM_PYTHON_VERSION=3.11.100
PYTHON="" PYTHON=""
for candidate in python3.10 python3.9 python3 python ; do for candidate in python3.11 python3.10 python3.9 python3 python ; do
if ppath=`which $candidate`; then if ppath=`which $candidate`; then
# when using `pyenv`, the executable for an inactive Python version will exist but will not be operational
# we check that this found executable can actually run
if [ $($candidate --version &>/dev/null; echo ${PIPESTATUS}) -gt 0 ]; then continue; fi
python_version=$($ppath -V | awk '{ print $2 }') python_version=$($ppath -V | awk '{ print $2 }')
if [ $(version $python_version) -ge $(version "$MINIMUM_PYTHON_VERSION") ]; then if [ $(version $python_version) -ge $(version "$MINIMUM_PYTHON_VERSION") ]; then
if [ $(version $python_version) -lt $(version "$MAXIMUM_PYTHON_VERSION") ]; then if [ $(version $python_version) -le $(version "$MAXIMUM_PYTHON_VERSION") ]; then
PYTHON=$ppath PYTHON=$ppath
break break
fi fi
fi fi
fi fi
done done

View File

@ -90,7 +90,7 @@ async def update_model(
new_name=info.model_name, new_name=info.model_name,
new_base=info.base_model, new_base=info.base_model,
) )
logger.info(f"Successfully renamed {base_model}/{model_name}=>{info.base_model}/{info.model_name}") logger.info(f"Successfully renamed {base_model.value}/{model_name}=>{info.base_model}/{info.model_name}")
# update information to support an update of attributes # update information to support an update of attributes
model_name = info.model_name model_name = info.model_name
base_model = info.base_model base_model = info.base_model

View File

@ -3,6 +3,7 @@ import asyncio
import sys import sys
from inspect import signature from inspect import signature
import logging
import uvicorn import uvicorn
import socket import socket
@ -210,11 +211,25 @@ def invoke_api():
port = find_port(app_config.port) port = find_port(app_config.port)
if port != app_config.port: if port != app_config.port:
logger.warn(f"Port {app_config.port} in use, using port {port}") logger.warn(f"Port {app_config.port} in use, using port {port}")
# Start our own event loop for eventing usage # Start our own event loop for eventing usage
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
config = uvicorn.Config(app=app, host=app_config.host, port=port, loop=loop) config = uvicorn.Config(
# Use access_log to turn off logging app=app,
host=app_config.host,
port=port,
loop=loop,
log_level=app_config.log_level,
)
server = uvicorn.Server(config) server = uvicorn.Server(config)
# replace uvicorn's loggers with InvokeAI's for consistent appearance
for logname in ["uvicorn.access", "uvicorn"]:
l = logging.getLogger(logname)
l.handlers.clear()
for ch in logger.handlers:
l.addHandler(ch)
loop.run_until_complete(server.serve()) loop.run_until_complete(server.serve())

View File

@ -12,7 +12,7 @@ from pydantic import BaseModel, Field, validator
from invokeai.app.invocations.metadata import CoreMetadata from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models.base import ModelType from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
@ -311,70 +311,71 @@ class TextToLatentsInvocation(BaseInvocation):
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name) with SilenceWarnings():
noise = context.services.latents.get(self.noise.latents_name)
# Get the source node id (we are invoking the prepared node) # Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id] source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState): def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state) self.dispatch_progress(context, source_node_id, state)
def _lora_loader(): def _lora_loader():
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}), **lora.dict(exclude={"weight"}),
context=context,
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict(),
context=context,
)
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
unet_info.context.model, _lora_loader()
), unet_info as unet:
noise = noise.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context, context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
) )
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model( pipeline = self.create_pipeline(unet, scheduler)
**self.unet.unet.dict(), conditioning_data = self.get_conditioning_data(context, scheduler, unet)
context=context,
)
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
unet_info.context.model, _lora_loader()
), unet_info as unet:
noise = noise.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler( control_data = self.prep_control_data(
context=context, model=pipeline,
scheduler_info=self.unet.scheduler, context=context,
scheduler_name=self.scheduler, control_input=self.control,
) latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
pipeline = self.create_pipeline(unet, scheduler) # TODO: Verify the noise is the right size
conditioning_data = self.get_conditioning_data(context, scheduler, unet) result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback,
)
control_data = self.prep_control_data( # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
model=pipeline, result_latents = result_latents.to("cpu")
context=context, torch.cuda.empty_cache()
control_input=self.control,
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
# TODO: Verify the noise is the right size name = f"{context.graph_execution_state_id}__{self.id}"
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( context.services.latents.save(name, result_latents)
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)), return build_latents_output(latents_name=name, latents=result_latents)
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents)
class LatentsToLatentsInvocation(TextToLatentsInvocation): class LatentsToLatentsInvocation(TextToLatentsInvocation):
@ -402,82 +403,83 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name) with SilenceWarnings(): # this quenches NSFW nag from diffusers
latent = context.services.latents.get(self.latents.latents_name) noise = context.services.latents.get(self.noise.latents_name)
latent = context.services.latents.get(self.latents.latents_name)
# Get the source node id (we are invoking the prepared node) # Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id] source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState): def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state) self.dispatch_progress(context, source_node_id, state)
def _lora_loader(): def _lora_loader():
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}), **lora.dict(exclude={"weight"}),
context=context,
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict(),
context=context,
)
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
unet_info.context.model, _lora_loader()
), unet_info as unet:
noise = noise.to(device=unet.device, dtype=unet.dtype)
latent = latent.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context, context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
) )
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model( pipeline = self.create_pipeline(unet, scheduler)
**self.unet.unet.dict(), conditioning_data = self.get_conditioning_data(context, scheduler, unet)
context=context,
)
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
unet_info.context.model, _lora_loader()
), unet_info as unet:
noise = noise.to(device=unet.device, dtype=unet.dtype)
latent = latent.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler( control_data = self.prep_control_data(
context=context, model=pipeline,
scheduler_info=self.unet.scheduler, context=context,
scheduler_name=self.scheduler, control_input=self.control,
) latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
pipeline = self.create_pipeline(unet, scheduler) # TODO: Verify the noise is the right size
conditioning_data = self.get_conditioning_data(context, scheduler, unet) initial_latents = (
latent if self.strength < 1.0 else torch.zeros_like(latent, device=unet.device, dtype=latent.dtype)
)
control_data = self.prep_control_data( timesteps, _ = pipeline.get_img2img_timesteps(
model=pipeline, self.steps,
context=context, self.strength,
control_input=self.control, device=unet.device,
latents_shape=noise.shape, )
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
# TODO: Verify the noise is the right size result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
initial_latents = ( latents=initial_latents,
latent if self.strength < 1.0 else torch.zeros_like(latent, device=unet.device, dtype=latent.dtype) timesteps=timesteps,
) noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback,
)
timesteps, _ = pipeline.get_img2img_timesteps( # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
self.steps, result_latents = result_latents.to("cpu")
self.strength, torch.cuda.empty_cache()
device=unet.device,
)
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( name = f"{context.graph_execution_state_id}__{self.id}"
latents=initial_latents, context.services.latents.save(name, result_latents)
timesteps=timesteps,
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents) return build_latents_output(latents_name=name, latents=result_latents)
@ -490,7 +492,7 @@ class LatentsToImageInvocation(BaseInvocation):
# Inputs # Inputs
latents: Optional[LatentsField] = Field(description="The latents to generate an image from") latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
vae: VaeField = Field(default=None, description="Vae submodel") vae: VaeField = Field(default=None, description="Vae submodel")
tiled: bool = Field(default=False, description="Decode latents by overlapping tiles(less memory consumption)") tiled: bool = Field(default=False, description="Decode latents by overlaping tiles (less memory consumption)")
fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision") fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
metadata: Optional[CoreMetadata] = Field( metadata: Optional[CoreMetadata] = Field(
default=None, description="Optional core metadata to be written to the image" default=None, description="Optional core metadata to be written to the image"

View File

@ -401,7 +401,11 @@ class ModelManager(object):
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
) -> str: ) -> str:
return f"{base_model}/{model_type}/{model_name}" # In 3.11, the behavior of (str,enum) when interpolated into a
# string has changed. The next two lines are defensive.
base_model = BaseModelType(base_model)
model_type = ModelType(model_type)
return f"{base_model.value}/{model_type.value}/{model_name}"
@classmethod @classmethod
def parse_key(cls, model_key: str) -> Tuple[str, BaseModelType, ModelType]: def parse_key(cls, model_key: str) -> Tuple[str, BaseModelType, ModelType]:

View File

@ -57,7 +57,7 @@ class LoRAModel(ModelBase):
@classproperty @classproperty
def save_to_config(cls) -> bool: def save_to_config(cls) -> bool:
return False return True
@classmethod @classmethod
def detect_format(cls, path: str): def detect_format(cls, path: str):

View File

@ -1,4 +1,4 @@
import{A as g,fS as Xe,z as x,a4 as Ba,fT as Ea,af as ca,aj as c,fU as b,al as Da,fV as t,fW as Ra,fX as h,fY as ba,fZ as ja,f_ as Ha,aI as Wa,f$ as Va,ad as La,g0 as qa}from"./index-89941396.js";import{n,o as Sr,p as Oa,T as Na,q as Ga,s as Ua,t as Ya,v as Xa,w as Ka,x as Za,y as Ja,z as Qa,A as et,B as rt,D as at,E as tt,F as ot,G as nt,e as it,M as lt}from"./MantineProvider-8184f020.js";var va=String.raw,ua=va` import{A as g,fS as Xe,z as x,a4 as Ba,fT as Ea,af as ca,aj as c,fU as b,al as Da,fV as t,fW as Ra,fX as h,fY as ba,fZ as ja,f_ as Ha,aI as Wa,f$ as Va,ad as La,g0 as qa}from"./index-5a784cdd.js";import{n,o as Sr,p as Oa,T as Na,q as Ga,s as Ua,t as Ya,v as Xa,w as Ka,x as Za,y as Ja,z as Qa,A as et,B as rt,D as at,E as tt,F as ot,G as nt,e as it,M as lt}from"./MantineProvider-ea42d3d1.js";var va=String.raw,ua=va`
:root, :root,
:host { :host {
--chakra-vh: 100vh; --chakra-vh: 100vh;

View File

@ -12,7 +12,7 @@
margin: 0; margin: 0;
} }
</style> </style>
<script type="module" crossorigin src="./assets/index-89941396.js"></script> <script type="module" crossorigin src="./assets/index-5a784cdd.js"></script>
</head> </head>
<body dir="ltr"> <body dir="ltr">

View File

@ -340,6 +340,7 @@
"allModels": "All Models", "allModels": "All Models",
"checkpointModels": "Checkpoints", "checkpointModels": "Checkpoints",
"diffusersModels": "Diffusers", "diffusersModels": "Diffusers",
"loraModels": "LoRAs",
"safetensorModels": "SafeTensors", "safetensorModels": "SafeTensors",
"modelAdded": "Model Added", "modelAdded": "Model Added",
"modelUpdated": "Model Updated", "modelUpdated": "Model Updated",

View File

@ -1,3 +1,5 @@
import { components } from 'services/api/schema';
export const MODEL_TYPE_MAP = { export const MODEL_TYPE_MAP = {
'sd-1': 'Stable Diffusion 1.x', 'sd-1': 'Stable Diffusion 1.x',
'sd-2': 'Stable Diffusion 2.x', 'sd-2': 'Stable Diffusion 2.x',
@ -5,6 +7,13 @@ export const MODEL_TYPE_MAP = {
'sdxl-refiner': 'Stable Diffusion XL Refiner', 'sdxl-refiner': 'Stable Diffusion XL Refiner',
}; };
export const MODEL_TYPE_SHORT_MAP = {
'sd-1': 'SD1',
'sd-2': 'SD2',
sdxl: 'SDXL',
'sdxl-refiner': 'SDXLR',
};
export const clipSkipMap = { export const clipSkipMap = {
'sd-1': { 'sd-1': {
maxClip: 12, maxClip: 12,
@ -23,3 +32,12 @@ export const clipSkipMap = {
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24], markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
}, },
}; };
type LoRAModelFormatMap = {
[key in components['schemas']['LoRAModelFormat']]: string;
};
export const LORA_MODEL_FORMAT_MAP: LoRAModelFormatMap = {
lycoris: 'LyCORIS',
diffusers: 'Diffusers',
};

View File

@ -3,20 +3,31 @@ import { Flex, Text } from '@chakra-ui/react';
import { useState } from 'react'; import { useState } from 'react';
import { import {
MainModelConfigEntity, MainModelConfigEntity,
DiffusersModelConfigEntity,
LoRAModelConfigEntity,
useGetMainModelsQuery, useGetMainModelsQuery,
useGetLoRAModelsQuery,
} from 'services/api/endpoints/models'; } from 'services/api/endpoints/models';
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit'; import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit'; import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
import LoRAModelEdit from './ModelManagerPanel/LoRAModelEdit';
import ModelList from './ModelManagerPanel/ModelList'; import ModelList from './ModelManagerPanel/ModelList';
import { ALL_BASE_MODELS } from 'services/api/constants'; import { ALL_BASE_MODELS } from 'services/api/constants';
export default function ModelManagerPanel() { export default function ModelManagerPanel() {
const [selectedModelId, setSelectedModelId] = useState<string>(); const [selectedModelId, setSelectedModelId] = useState<string>();
const { model } = useGetMainModelsQuery(ALL_BASE_MODELS, { const { mainModel } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({ selectFromResult: ({ data }) => ({
model: selectedModelId ? data?.entities[selectedModelId] : undefined, mainModel: selectedModelId ? data?.entities[selectedModelId] : undefined,
}), }),
}); });
const { loraModel } = useGetLoRAModelsQuery(undefined, {
selectFromResult: ({ data }) => ({
loraModel: selectedModelId ? data?.entities[selectedModelId] : undefined,
}),
});
const model = mainModel ? mainModel : loraModel;
return ( return (
<Flex sx={{ gap: 8, w: 'full', h: 'full' }}> <Flex sx={{ gap: 8, w: 'full', h: 'full' }}>
@ -30,7 +41,7 @@ export default function ModelManagerPanel() {
} }
type ModelEditProps = { type ModelEditProps = {
model: MainModelConfigEntity | undefined; model: MainModelConfigEntity | LoRAModelConfigEntity | undefined;
}; };
const ModelEdit = (props: ModelEditProps) => { const ModelEdit = (props: ModelEditProps) => {
@ -41,7 +52,16 @@ const ModelEdit = (props: ModelEditProps) => {
} }
if (model?.model_format === 'diffusers') { if (model?.model_format === 'diffusers') {
return <DiffusersModelEdit key={model.id} model={model} />; return (
<DiffusersModelEdit
key={model.id}
model={model as DiffusersModelConfigEntity}
/>
);
}
if (model?.model_type === 'lora') {
return <LoRAModelEdit key={model.id} model={model} />;
} }
return ( return (

View File

@ -0,0 +1,137 @@
import { Divider, Flex, Text } from '@chakra-ui/react';
import { useForm } from '@mantine/form';
import { makeToast } from 'features/system/util/makeToast';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import {
LORA_MODEL_FORMAT_MAP,
MODEL_TYPE_MAP,
} from 'features/parameters/types/constants';
import {
LoRAModelConfigEntity,
useUpdateLoRAModelsMutation,
} from 'services/api/endpoints/models';
import { LoRAModelConfig } from 'services/api/types';
import BaseModelSelect from '../shared/BaseModelSelect';
type LoRAModelEditProps = {
model: LoRAModelConfigEntity;
};
export default function LoRAModelEdit(props: LoRAModelEditProps) {
const isBusy = useAppSelector(selectIsBusy);
const { model } = props;
const [updateLoRAModel, { isLoading }] = useUpdateLoRAModelsMutation();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const loraEditForm = useForm<LoRAModelConfig>({
initialValues: {
model_name: model.model_name ? model.model_name : '',
base_model: model.base_model,
model_type: 'lora',
path: model.path ? model.path : '',
description: model.description ? model.description : '',
model_format: model.model_format,
},
validate: {
path: (value) =>
value.trim().length === 0 ? 'Must provide a path' : null,
},
});
const editModelFormSubmitHandler = useCallback(
(values: LoRAModelConfig) => {
const responseBody = {
base_model: model.base_model,
model_name: model.model_name,
body: values,
};
updateLoRAModel(responseBody)
.unwrap()
.then((payload) => {
loraEditForm.setValues(payload as LoRAModelConfig);
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
})
.catch((_) => {
loraEditForm.reset();
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'error',
})
)
);
});
},
[
dispatch,
loraEditForm,
model.base_model,
model.model_name,
t,
updateLoRAModel,
]
);
return (
<Flex flexDirection="column" rowGap={4} width="100%">
<Flex flexDirection="column">
<Text fontSize="lg" fontWeight="bold">
{model.model_name}
</Text>
<Text fontSize="sm" color="base.400">
{MODEL_TYPE_MAP[model.base_model]} Model {' '}
{LORA_MODEL_FORMAT_MAP[model.model_format]} format
</Text>
</Flex>
<Divider />
<form
onSubmit={loraEditForm.onSubmit((values) =>
editModelFormSubmitHandler(values)
)}
>
<Flex flexDirection="column" overflowY="scroll" gap={4}>
<IAIMantineTextInput
label={t('modelManager.name')}
{...loraEditForm.getInputProps('model_name')}
/>
<IAIMantineTextInput
label={t('modelManager.description')}
{...loraEditForm.getInputProps('description')}
/>
<BaseModelSelect {...loraEditForm.getInputProps('base_model')} />
<IAIMantineTextInput
label={t('modelManager.modelLocation')}
{...loraEditForm.getInputProps('path')}
/>
<IAIButton
type="submit"
isDisabled={isBusy || isLoading}
isLoading={isLoading}
>
{t('modelManager.updateModel')}
</IAIButton>
</Flex>
</form>
</Flex>
);
}

View File

@ -9,6 +9,8 @@ import { useTranslation } from 'react-i18next';
import { import {
MainModelConfigEntity, MainModelConfigEntity,
useGetMainModelsQuery, useGetMainModelsQuery,
useGetLoRAModelsQuery,
LoRAModelConfigEntity,
} from 'services/api/endpoints/models'; } from 'services/api/endpoints/models';
import ModelListItem from './ModelListItem'; import ModelListItem from './ModelListItem';
import { ALL_BASE_MODELS } from 'services/api/constants'; import { ALL_BASE_MODELS } from 'services/api/constants';
@ -20,22 +22,42 @@ type ModelListProps = {
type ModelFormat = 'images' | 'checkpoint' | 'diffusers'; type ModelFormat = 'images' | 'checkpoint' | 'diffusers';
type ModelType = 'main' | 'lora';
type CombinedModelFormat = ModelFormat | 'lora';
const ModelList = (props: ModelListProps) => { const ModelList = (props: ModelListProps) => {
const { selectedModelId, setSelectedModelId } = props; const { selectedModelId, setSelectedModelId } = props;
const { t } = useTranslation(); const { t } = useTranslation();
const [nameFilter, setNameFilter] = useState<string>(''); const [nameFilter, setNameFilter] = useState<string>('');
const [modelFormatFilter, setModelFormatFilter] = const [modelFormatFilter, setModelFormatFilter] =
useState<ModelFormat>('images'); useState<CombinedModelFormat>('images');
const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, { const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({ selectFromResult: ({ data }) => ({
filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter), filteredDiffusersModels: modelsFilter(
data,
'main',
'diffusers',
nameFilter
),
}), }),
}); });
const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, { const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({ selectFromResult: ({ data }) => ({
filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter), filteredCheckpointModels: modelsFilter(
data,
'main',
'checkpoint',
nameFilter
),
}),
});
const { filteredLoraModels } = useGetLoRAModelsQuery(undefined, {
selectFromResult: ({ data }) => ({
filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter),
}), }),
}); });
@ -68,6 +90,13 @@ const ModelList = (props: ModelListProps) => {
> >
{t('modelManager.checkpointModels')} {t('modelManager.checkpointModels')}
</IAIButton> </IAIButton>
<IAIButton
size="sm"
onClick={() => setModelFormatFilter('lora')}
isChecked={modelFormatFilter === 'lora'}
>
{t('modelManager.loraModels')}
</IAIButton>
</ButtonGroup> </ButtonGroup>
<IAIInput <IAIInput
@ -118,6 +147,24 @@ const ModelList = (props: ModelListProps) => {
</Flex> </Flex>
</StyledModelContainer> </StyledModelContainer>
)} )}
{['images', 'lora'].includes(modelFormatFilter) &&
filteredLoraModels.length > 0 && (
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
LoRAs
</Text>
{filteredLoraModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
)}
</Flex> </Flex>
</Flex> </Flex>
</Flex> </Flex>
@ -126,12 +173,13 @@ const ModelList = (props: ModelListProps) => {
export default ModelList; export default ModelList;
const modelsFilter = ( const modelsFilter = <T extends MainModelConfigEntity | LoRAModelConfigEntity>(
data: EntityState<MainModelConfigEntity> | undefined, data: EntityState<T> | undefined,
model_format: ModelFormat, model_type: ModelType,
model_format: ModelFormat | undefined,
nameFilter: string nameFilter: string
) => { ) => {
const filteredModels: MainModelConfigEntity[] = []; const filteredModels: T[] = [];
forEach(data?.entities, (model) => { forEach(data?.entities, (model) => {
if (!model) { if (!model) {
return; return;
@ -141,9 +189,11 @@ const modelsFilter = (
.toLowerCase() .toLowerCase()
.includes(nameFilter.toLowerCase()); .includes(nameFilter.toLowerCase());
const matchesFormat = model.model_format === model_format; const matchesFormat =
model_format === undefined || model.model_format === model_format;
const matchesType = model.model_type === model_type;
if (matchesFilter && matchesFormat) { if (matchesFilter && matchesFormat && matchesType) {
filteredModels.push(model); filteredModels.push(model);
} }
}); });

View File

@ -9,29 +9,26 @@ import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
import { import {
MainModelConfigEntity, MainModelConfigEntity,
LoRAModelConfigEntity,
useDeleteMainModelsMutation, useDeleteMainModelsMutation,
useDeleteLoRAModelsMutation,
} from 'services/api/endpoints/models'; } from 'services/api/endpoints/models';
type ModelListItemProps = { type ModelListItemProps = {
model: MainModelConfigEntity; model: MainModelConfigEntity | LoRAModelConfigEntity;
isSelected: boolean; isSelected: boolean;
setSelectedModelId: (v: string | undefined) => void; setSelectedModelId: (v: string | undefined) => void;
}; };
const modelBaseTypeMap = {
'sd-1': 'SD1',
'sd-2': 'SD2',
sdxl: 'SDXL',
'sdxl-refiner': 'SDXLR',
};
export default function ModelListItem(props: ModelListItemProps) { export default function ModelListItem(props: ModelListItemProps) {
const isBusy = useAppSelector(selectIsBusy); const isBusy = useAppSelector(selectIsBusy);
const { t } = useTranslation(); const { t } = useTranslation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const [deleteMainModel] = useDeleteMainModelsMutation(); const [deleteMainModel] = useDeleteMainModelsMutation();
const [deleteLoRAModel] = useDeleteLoRAModelsMutation();
const { model, isSelected, setSelectedModelId } = props; const { model, isSelected, setSelectedModelId } = props;
@ -40,7 +37,10 @@ export default function ModelListItem(props: ModelListItemProps) {
}, [model.id, setSelectedModelId]); }, [model.id, setSelectedModelId]);
const handleModelDelete = useCallback(() => { const handleModelDelete = useCallback(() => {
deleteMainModel(model) const method = { main: deleteMainModel, lora: deleteLoRAModel }[
model.model_type
];
method(model)
.unwrap() .unwrap()
.then((_) => { .then((_) => {
dispatch( dispatch(
@ -60,14 +60,21 @@ export default function ModelListItem(props: ModelListItemProps) {
title: `${t('modelManager.modelDeleteFailed')}: ${ title: `${t('modelManager.modelDeleteFailed')}: ${
model.model_name model.model_name
}`, }`,
status: 'success', status: 'error',
}) })
) )
); );
} }
}); });
setSelectedModelId(undefined); setSelectedModelId(undefined);
}, [deleteMainModel, model, setSelectedModelId, dispatch, t]); }, [
deleteMainModel,
deleteLoRAModel,
model,
setSelectedModelId,
dispatch,
t,
]);
return ( return (
<Flex sx={{ gap: 2, alignItems: 'center', w: 'full' }}> <Flex sx={{ gap: 2, alignItems: 'center', w: 'full' }}>
@ -100,8 +107,8 @@ export default function ModelListItem(props: ModelListItemProps) {
<Flex gap={4} alignItems="center"> <Flex gap={4} alignItems="center">
<Badge minWidth={14} p={0.5} fontSize="sm" variant="solid"> <Badge minWidth={14} p={0.5} fontSize="sm" variant="solid">
{ {
modelBaseTypeMap[ MODEL_TYPE_SHORT_MAP[
model.base_model as keyof typeof modelBaseTypeMap model.base_model as keyof typeof MODEL_TYPE_SHORT_MAP
] ]
} }
</Badge> </Badge>

View File

@ -52,9 +52,17 @@ type UpdateMainModelArg = {
body: MainModelConfig; body: MainModelConfig;
}; };
type UpdateLoRAModelArg = {
base_model: BaseModelType;
model_name: string;
body: LoRAModelConfig;
};
type UpdateMainModelResponse = type UpdateMainModelResponse =
paths['/api/v1/models/{base_model}/{model_type}/{model_name}']['patch']['responses']['200']['content']['application/json']; paths['/api/v1/models/{base_model}/{model_type}/{model_name}']['patch']['responses']['200']['content']['application/json'];
type UpdateLoRAModelResponse = UpdateMainModelResponse;
type DeleteMainModelArg = { type DeleteMainModelArg = {
base_model: BaseModelType; base_model: BaseModelType;
model_name: string; model_name: string;
@ -62,6 +70,10 @@ type DeleteMainModelArg = {
type DeleteMainModelResponse = void; type DeleteMainModelResponse = void;
type DeleteLoRAModelArg = DeleteMainModelArg;
type DeleteLoRAModelResponse = void;
type ConvertMainModelArg = { type ConvertMainModelArg = {
base_model: BaseModelType; base_model: BaseModelType;
model_name: string; model_name: string;
@ -320,6 +332,31 @@ export const modelsApi = api.injectEndpoints({
); );
}, },
}), }),
updateLoRAModels: build.mutation<
UpdateLoRAModelResponse,
UpdateLoRAModelArg
>({
query: ({ base_model, model_name, body }) => {
return {
url: `models/${base_model}/lora/${model_name}`,
method: 'PATCH',
body: body,
};
},
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
}),
deleteLoRAModels: build.mutation<
DeleteLoRAModelResponse,
DeleteLoRAModelArg
>({
query: ({ base_model, model_name }) => {
return {
url: `models/${base_model}/lora/${model_name}`,
method: 'DELETE',
};
},
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
}),
getControlNetModels: build.query< getControlNetModels: build.query<
EntityState<ControlNetModelConfigEntity>, EntityState<ControlNetModelConfigEntity>,
void void
@ -467,6 +504,8 @@ export const {
useAddMainModelsMutation, useAddMainModelsMutation,
useConvertMainModelsMutation, useConvertMainModelsMutation,
useMergeMainModelsMutation, useMergeMainModelsMutation,
useDeleteLoRAModelsMutation,
useUpdateLoRAModelsMutation,
useSyncModelsMutation, useSyncModelsMutation,
useGetModelsInFolderQuery, useGetModelsInFolderQuery,
useGetCheckpointConfigsQuery, useGetCheckpointConfigsQuery,

View File

@ -5562,12 +5562,6 @@ export type components = {
* @enum {string} * @enum {string}
*/ */
StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/** /**
* StableDiffusionXLModelFormat * StableDiffusionXLModelFormat
* @description An enumeration. * @description An enumeration.
@ -5580,6 +5574,12 @@ export type components = {
* @enum {string} * @enum {string}
*/ */
ControlNetModelFormat: "checkpoint" | "diffusers"; ControlNetModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
}; };
responses: never; responses: never;
parameters: never; parameters: never;

View File

@ -42,8 +42,13 @@ export type ControlField = components['schemas']['ControlField'];
// Model Configs // Model Configs
export type LoRAModelConfig = components['schemas']['LoRAModelConfig']; export type LoRAModelConfig = components['schemas']['LoRAModelConfig'];
export type VaeModelConfig = components['schemas']['VaeModelConfig']; export type VaeModelConfig = components['schemas']['VaeModelConfig'];
export type ControlNetModelCheckpointConfig =
components['schemas']['ControlNetModelCheckpointConfig'];
export type ControlNetModelDiffusersConfig =
components['schemas']['ControlNetModelDiffusersConfig'];
export type ControlNetModelConfig = export type ControlNetModelConfig =
components['schemas']['ControlNetModelConfig']; | ControlNetModelCheckpointConfig
| ControlNetModelDiffusersConfig;
export type TextualInversionModelConfig = export type TextualInversionModelConfig =
components['schemas']['TextualInversionModelConfig']; components['schemas']['TextualInversionModelConfig'];
export type DiffusersModelConfig = export type DiffusersModelConfig =

View File

@ -1 +1 @@
__version__ = "3.0.1rc1" __version__ = "3.0.1rc2"

View File

@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "InvokeAI" name = "InvokeAI"
description = "An implementation of Stable Diffusion which provides various new features and options to aid the image generation process" description = "An implementation of Stable Diffusion which provides various new features and options to aid the image generation process"
requires-python = ">=3.9, <3.11" requires-python = ">=3.9, <3.12"
readme = { content-type = "text/markdown", file = "README.md" } readme = { content-type = "text/markdown", file = "README.md" }
keywords = ["stable-diffusion", "AI"] keywords = ["stable-diffusion", "AI"]
dynamic = ["version"] dynamic = ["version"]
@ -32,16 +32,16 @@ classifiers = [
'Topic :: Scientific/Engineering :: Image Processing', 'Topic :: Scientific/Engineering :: Image Processing',
] ]
dependencies = [ dependencies = [
"accelerate~=0.16", "accelerate~=0.21.0",
"albumentations", "albumentations",
"click", "click",
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel==2.0.0", "compel~=2.0.0",
"controlnet-aux>=0.0.6", "controlnet-aux>=0.0.6",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26 "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"datasets", "datasets",
"diffusers[torch]~=0.18.1", "diffusers[torch]~=0.19.0",
"dnspython==2.2.1", "dnspython~=2.4.0",
"dynamicprompts", "dynamicprompts",
"easing-functions", "easing-functions",
"einops", "einops",
@ -54,37 +54,37 @@ dependencies = [
"flask_cors==3.0.10", "flask_cors==3.0.10",
"flask_socketio==5.3.0", "flask_socketio==5.3.0",
"flaskwebgui==1.0.3", "flaskwebgui==1.0.3",
"gfpgan==1.3.8",
"huggingface-hub>=0.11.1", "huggingface-hub>=0.11.1",
"invisible-watermark>=0.2.0", # needed to install SDXL base and refiner using their repo_ids "invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids
"matplotlib", # needed for plotting of Penner easing functions "matplotlib", # needed for plotting of Penner easing functions
"mediapipe", # needed for "mediapipeface" controlnet model "mediapipe", # needed for "mediapipeface" controlnet model
"npyscreen", "npyscreen",
"numpy<1.24", "numpy==1.24.4",
"omegaconf", "omegaconf",
"opencv-python", "opencv-python",
"picklescan", "picklescan",
"pillow", "pillow",
"prompt-toolkit", "prompt-toolkit",
"pympler==1.0.1", "pydantic==1.10.10",
"pympler~=1.0.1",
"pypatchmatch", "pypatchmatch",
'pyperclip', 'pyperclip',
"pyreadline3", "pyreadline3",
"python-multipart==0.0.6", "python-multipart",
"pytorch-lightning==1.7.7", "pytorch-lightning",
"realesrgan", "realesrgan",
"requests==2.28.2", "requests~=2.28.2",
"rich~=13.3", "rich~=13.3",
"safetensors~=0.3.0", "safetensors~=0.3.0",
"scikit-image>=0.19", "scikit-image~=0.21.0",
"send2trash", "send2trash",
"test-tube>=0.7.5", "test-tube~=0.7.5",
"torch~=2.0.0", "torch~=2.0.1",
"torchvision>=0.14.1", "torchvision~=0.15.2",
"torchmetrics==0.11.4", "torchmetrics~=1.0.1",
"torchsde==0.2.5", "torchsde~=0.2.5",
"transformers~=4.31.0", "transformers~=4.31.0",
"uvicorn[standard]==0.21.1", "uvicorn[standard]~=0.21.1",
"windows-curses; sys_platform=='win32'", "windows-curses; sys_platform=='win32'",
] ]

View File

@ -1,8 +1,16 @@
#!/bin/env python #!/bin/env python
import argparse
import sys import sys
from pathlib import Path from pathlib import Path
from invokeai.backend.model_management.model_probe import ModelProbe from invokeai.backend.model_management.model_probe import ModelProbe
info = ModelProbe().probe(Path(sys.argv[1])) parser = argparse.ArgumentParser(description="Probe model type")
parser.add_argument(
"model_path",
type=Path,
)
args = parser.parse_args()
info = ModelProbe().probe(args.model_path)
print(info) print(info)