Merge branch 'main' into refactor/cleanup-root-detection

This commit is contained in:
Lincoln Stein 2023-08-02 09:46:46 -04:00 committed by GitHub
commit 0db1e97119
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 206 additions and 145 deletions

View File

@ -184,8 +184,9 @@ the command `npm install -g yarn` if needed)
6. Configure InvokeAI and install a starting set of image generation models (you only need to do this once):
```terminal
invokeai-configure
invokeai-configure --root .
```
Don't miss the dot at the end!
7. Launch the web server (do it every time you run InvokeAI):
@ -193,15 +194,9 @@ the command `npm install -g yarn` if needed)
invokeai-web
```
8. Build Node.js assets
8. Point your browser to http://localhost:9090 to bring up the web interface.
```terminal
cd invokeai/frontend/web/
yarn vite build
```
9. Point your browser to http://localhost:9090 to bring up the web interface.
10. Type `banana sushi` in the box on the top left and click `Invoke`.
9. Type `banana sushi` in the box on the top left and click `Invoke`.
Be sure to activate the virtual environment each time before re-launching InvokeAI,
using `source .venv/bin/activate` or `.venv\Scripts\activate`.

View File

@ -192,9 +192,11 @@ manager, please follow these steps:
your outputs.
```terminal
invokeai-configure
invokeai-configure --root .
```
Don't miss the dot at the end of the command!
The script `invokeai-configure` will interactively guide you through the
process of downloading and installing the weights files needed for InvokeAI.
Note that the main Stable Diffusion weights file is protected by a license
@ -225,12 +227,6 @@ manager, please follow these steps:
!!! warning "Make sure that the virtual environment is activated, which should create `(.venv)` in front of your prompt!"
=== "CLI"
```bash
invokeai
```
=== "local Webserver"
```bash
@ -243,6 +239,12 @@ manager, please follow these steps:
invokeai --web --host 0.0.0.0
```
=== "CLI"
```bash
invokeai
```
If you choose the run the web interface, point your browser at
http://localhost:9090 in order to load the GUI.

View File

@ -34,6 +34,10 @@
cudaPackages.cudnn
cudaPackages.cuda_nvrtc
cudatoolkit
pkgconfig
libconfig
cmake
blas
freeglut
glib
gperf
@ -42,6 +46,12 @@
libGLU
linuxPackages.nvidia_x11
python
(opencv4.override {
enableGtk3 = true;
enableFfmpeg = true;
enableCuda = true;
enableUnfree = true;
})
stdenv.cc
stdenv.cc.cc.lib
xorg.libX11

View File

@ -1,4 +1,4 @@
import { ButtonGroup, Flex, Text } from '@chakra-ui/react';
import { ButtonGroup, Flex, Spinner, Text } from '@chakra-ui/react';
import { EntityState } from '@reduxjs/toolkit';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
@ -6,23 +6,23 @@ import { forEach } from 'lodash-es';
import type { ChangeEvent, PropsWithChildren } from 'react';
import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants';
import {
LoRAModelConfigEntity,
MainModelConfigEntity,
OnnxModelConfigEntity,
useGetLoRAModelsQuery,
useGetMainModelsQuery,
useGetOnnxModelsQuery,
useGetLoRAModelsQuery,
LoRAModelConfigEntity,
} from 'services/api/endpoints/models';
import ModelListItem from './ModelListItem';
import { ALL_BASE_MODELS } from 'services/api/constants';
type ModelListProps = {
selectedModelId: string | undefined;
setSelectedModelId: (name: string | undefined) => void;
};
type ModelFormat = 'images' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx';
type ModelFormat = 'all' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx';
type ModelType = 'main' | 'lora' | 'onnx';
@ -33,47 +33,63 @@ const ModelList = (props: ModelListProps) => {
const { t } = useTranslation();
const [nameFilter, setNameFilter] = useState<string>('');
const [modelFormatFilter, setModelFormatFilter] =
useState<CombinedModelFormat>('images');
useState<CombinedModelFormat>('all');
const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({
const { filteredDiffusersModels, isLoadingDiffusersModels } =
useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data, isLoading }) => ({
filteredDiffusersModels: modelsFilter(
data,
'main',
'diffusers',
nameFilter
),
isLoadingDiffusersModels: isLoading,
}),
});
const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({
const { filteredCheckpointModels, isLoadingCheckpointModels } =
useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data, isLoading }) => ({
filteredCheckpointModels: modelsFilter(
data,
'main',
'checkpoint',
nameFilter
),
isLoadingCheckpointModels: isLoading,
}),
});
const { filteredLoraModels } = useGetLoRAModelsQuery(undefined, {
selectFromResult: ({ data }) => ({
const { filteredLoraModels, isLoadingLoraModels } = useGetLoRAModelsQuery(
undefined,
{
selectFromResult: ({ data, isLoading }) => ({
filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter),
isLoadingLoraModels: isLoading,
}),
});
}
);
const { filteredOnnxModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({
const { filteredOnnxModels, isLoadingOnnxModels } = useGetOnnxModelsQuery(
ALL_BASE_MODELS,
{
selectFromResult: ({ data, isLoading }) => ({
filteredOnnxModels: modelsFilter(data, 'onnx', 'onnx', nameFilter),
isLoadingOnnxModels: isLoading,
}),
});
}
);
const { filteredOliveModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({
const { filteredOliveModels, isLoadingOliveModels } = useGetOnnxModelsQuery(
ALL_BASE_MODELS,
{
selectFromResult: ({ data, isLoading }) => ({
filteredOliveModels: modelsFilter(data, 'onnx', 'olive', nameFilter),
isLoadingOliveModels: isLoading,
}),
});
}
);
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setNameFilter(e.target.value);
@ -84,8 +100,8 @@ const ModelList = (props: ModelListProps) => {
<Flex flexDirection="column" gap={4} paddingInlineEnd={4}>
<ButtonGroup isAttached>
<IAIButton
onClick={() => setModelFormatFilter('images')}
isChecked={modelFormatFilter === 'images'}
onClick={() => setModelFormatFilter('all')}
isChecked={modelFormatFilter === 'all'}
size="sm"
>
{t('modelManager.allModels')}
@ -139,95 +155,76 @@ const ModelList = (props: ModelListProps) => {
maxHeight={window.innerHeight - 280}
overflow="scroll"
>
{['images', 'diffusers'].includes(modelFormatFilter) &&
{/* Diffusers List */}
{isLoadingDiffusersModels && (
<FetchingModelsLoader loadingMessage="Loading Diffusers..." />
)}
{['all', 'diffusers'].includes(modelFormatFilter) &&
!isLoadingDiffusersModels &&
filteredDiffusersModels.length > 0 && (
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
Diffusers
</Text>
{filteredDiffusersModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
<ModelListWrapper
title="Diffusers"
modelList={filteredDiffusersModels}
selected={{ selectedModelId, setSelectedModelId }}
key="diffusers"
/>
))}
</Flex>
</StyledModelContainer>
)}
{['images', 'checkpoint'].includes(modelFormatFilter) &&
{/* Checkpoints List */}
{isLoadingCheckpointModels && (
<FetchingModelsLoader loadingMessage="Loading Checkpoints..." />
)}
{['all', 'checkpoint'].includes(modelFormatFilter) &&
!isLoadingCheckpointModels &&
filteredCheckpointModels.length > 0 && (
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
Checkpoints
</Text>
{filteredCheckpointModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
<ModelListWrapper
title="Checkpoints"
modelList={filteredCheckpointModels}
selected={{ selectedModelId, setSelectedModelId }}
key="checkpoints"
/>
))}
</Flex>
</StyledModelContainer>
)}
{['images', 'olive'].includes(modelFormatFilter) &&
filteredOliveModels.length > 0 && (
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
Olives
</Text>
{filteredOliveModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
{/* LoRAs List */}
{isLoadingLoraModels && (
<FetchingModelsLoader loadingMessage="Loading LoRAs..." />
)}
{['images', 'onnx'].includes(modelFormatFilter) &&
filteredOnnxModels.length > 0 && (
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
Onnx
</Text>
{filteredOnnxModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
)}
{['images', 'lora'].includes(modelFormatFilter) &&
{['all', 'lora'].includes(modelFormatFilter) &&
!isLoadingLoraModels &&
filteredLoraModels.length > 0 && (
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
LoRAs
</Text>
{filteredLoraModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
<ModelListWrapper
title="LoRAs"
modelList={filteredLoraModels}
selected={{ selectedModelId, setSelectedModelId }}
key="loras"
/>
)}
{/* Olive List */}
{isLoadingOliveModels && (
<FetchingModelsLoader loadingMessage="Loading Olives..." />
)}
{['all', 'olive'].includes(modelFormatFilter) &&
!isLoadingOliveModels &&
filteredOliveModels.length > 0 && (
<ModelListWrapper
title="Olives"
modelList={filteredOliveModels}
selected={{ selectedModelId, setSelectedModelId }}
key="olive"
/>
)}
{/* Onnx List */}
{isLoadingOnnxModels && (
<FetchingModelsLoader loadingMessage="Loading ONNX..." />
)}
{['all', 'onnx'].includes(modelFormatFilter) &&
!isLoadingOnnxModels &&
filteredOnnxModels.length > 0 && (
<ModelListWrapper
title="ONNX"
modelList={filteredOnnxModels}
selected={{ selectedModelId, setSelectedModelId }}
key="onnx"
/>
))}
</Flex>
</StyledModelContainer>
)}
</Flex>
</Flex>
@ -287,3 +284,52 @@ const StyledModelContainer = (props: PropsWithChildren) => {
</Flex>
);
};
type ModelListWrapperProps = {
title: string;
modelList:
| MainModelConfigEntity[]
| LoRAModelConfigEntity[]
| OnnxModelConfigEntity[];
selected: ModelListProps;
};
function ModelListWrapper(props: ModelListWrapperProps) {
const { title, modelList, selected } = props;
return (
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
{title}
</Text>
{modelList.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selected.selectedModelId === model.id}
setSelectedModelId={selected.setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
);
}
function FetchingModelsLoader({ loadingMessage }: { loadingMessage?: string }) {
return (
<StyledModelContainer>
<Flex
justifyContent="center"
alignItems="center"
flexDirection="column"
p={4}
gap={8}
>
<Spinner />
<Text variant="subtext">
{loadingMessage ? loadingMessage : 'Fetching...'}
</Text>
</Flex>
</StyledModelContainer>
);
}

View File

@ -181,7 +181,7 @@ export const modelsApi = api.injectEndpoints({
},
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ id: 'OnnxModel', type: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
];
if (result) {
@ -266,6 +266,7 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
],
}),
importMainModels: build.mutation<
@ -282,6 +283,7 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
],
}),
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
@ -295,6 +297,7 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
],
}),
deleteMainModels: build.mutation<
@ -310,6 +313,7 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
],
}),
convertMainModels: build.mutation<
@ -326,6 +330,7 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
],
}),
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
@ -339,6 +344,7 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
],
}),
syncModels: build.mutation<SyncModelsResponse, void>({
@ -351,6 +357,7 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
],
}),
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({

View File

@ -13,7 +13,8 @@ import { socketSubscribed, socketUnsubscribed } from './actions';
export const socketMiddleware = () => {
let areListenersSet = false;
let socketUrl = `ws://${window.location.host}`;
const wsProtocol = window.location.protocol === 'https:' ? 'wss' : 'ws';
let socketUrl = `${wsProtocol}://${window.location.host}`;
const socketOptions: Parameters<typeof io>[0] = {
timeout: 60000,