feat: generate image by Stability AI / openAI (#3642)

This commit is contained in:
Lucas.Xu 2023-10-09 23:14:24 +08:00 committed by GitHub
parent 41d4351176
commit dace02d34d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 640 additions and 40 deletions

View File

@ -0,0 +1,105 @@
import 'dart:async';
import 'package:appflowy/generated/locale_keys.g.dart';
import 'package:appflowy/plugins/document/presentation/editor_plugins/openai/service/error.dart';
import 'package:appflowy/plugins/document/presentation/editor_plugins/openai/service/openai_client.dart';
import 'package:appflowy/startup/startup.dart';
import 'package:dartz/dartz.dart' hide State;
import 'package:easy_localization/easy_localization.dart';
import 'package:flowy_infra_ui/flowy_infra_ui.dart';
import 'package:flutter/material.dart';
class OpenAIImageWidget extends StatefulWidget {
const OpenAIImageWidget({
super.key,
required this.onSelectNetworkImage,
});
final void Function(String url) onSelectNetworkImage;
@override
State<OpenAIImageWidget> createState() => _OpenAIImageWidgetState();
}
class _OpenAIImageWidgetState extends State<OpenAIImageWidget> {
Future<Either<OpenAIError, List<String>>>? future;
String query = '';
@override
Widget build(BuildContext context) {
return Column(
mainAxisSize: MainAxisSize.min,
children: [
Row(
mainAxisSize: MainAxisSize.min,
children: [
Expanded(
child: FlowyTextField(
autoFocus: true,
hintText: LocaleKeys.document_imageBlock_ai_placeholder.tr(),
onChanged: (value) => query = value,
onEditingComplete: _search,
),
),
const HSpace(4.0),
FlowyButton(
useIntrinsicWidth: true,
text: FlowyText(
LocaleKeys.search_label.tr(),
),
onTap: _search,
),
],
),
const VSpace(12.0),
if (future != null)
Expanded(
child: FutureBuilder(
future: future,
builder: (context, value) {
final data = value.data;
if (!value.hasData ||
value.connectionState != ConnectionState.done ||
data == null) {
return const CircularProgressIndicator.adaptive();
}
return data.fold(
(l) => Center(
child: FlowyText(
l.message,
maxLines: 3,
textAlign: TextAlign.center,
),
),
(r) => GridView.count(
crossAxisCount: 3,
mainAxisSpacing: 16.0,
crossAxisSpacing: 10.0,
childAspectRatio: 4 / 3,
children: r
.map(
(e) => GestureDetector(
onTap: () => widget.onSelectNetworkImage(e),
child: Image.network(e),
),
)
.toList(),
),
);
},
),
)
],
);
}
void _search() async {
final openAI = await getIt.getAsync<OpenAIRepository>();
setState(() {
future = openAI.generateImage(
prompt: query,
n: 6,
);
});
}
}

View File

@ -0,0 +1,121 @@
import 'dart:async';
import 'dart:convert';
import 'dart:io';
import 'package:appflowy/generated/locale_keys.g.dart';
import 'package:appflowy/plugins/document/presentation/editor_plugins/stability_ai/stability_ai_client.dart';
import 'package:appflowy/plugins/document/presentation/editor_plugins/stability_ai/stability_ai_error.dart';
import 'package:appflowy/startup/startup.dart';
import 'package:dartz/dartz.dart' hide State;
import 'package:easy_localization/easy_localization.dart';
import 'package:flowy_infra/uuid.dart';
import 'package:flowy_infra_ui/flowy_infra_ui.dart';
import 'package:flutter/material.dart';
import 'package:path/path.dart' as p;
import 'package:path_provider/path_provider.dart';
class StabilityAIImageWidget extends StatefulWidget {
const StabilityAIImageWidget({
super.key,
required this.onSelectImage,
});
final void Function(String url) onSelectImage;
@override
State<StabilityAIImageWidget> createState() => _StabilityAIImageWidgetState();
}
class _StabilityAIImageWidgetState extends State<StabilityAIImageWidget> {
Future<Either<StabilityAIRequestError, List<String>>>? future;
String query = '';
@override
Widget build(BuildContext context) {
return Column(
mainAxisSize: MainAxisSize.min,
children: [
Row(
mainAxisSize: MainAxisSize.min,
children: [
Expanded(
child: FlowyTextField(
autoFocus: true,
hintText: LocaleKeys
.document_imageBlock_stability_ai_placeholder
.tr(),
onChanged: (value) => query = value,
onEditingComplete: _search,
),
),
const HSpace(4.0),
FlowyButton(
useIntrinsicWidth: true,
text: FlowyText(
LocaleKeys.search_label.tr(),
),
onTap: _search,
),
],
),
const VSpace(12.0),
if (future != null)
Expanded(
child: FutureBuilder(
future: future,
builder: (context, value) {
final data = value.data;
if (!value.hasData ||
value.connectionState != ConnectionState.done ||
data == null) {
return const CircularProgressIndicator.adaptive();
}
return data.fold(
(l) => Center(
child: FlowyText(
l.message,
maxLines: 3,
textAlign: TextAlign.center,
),
),
(r) => GridView.count(
crossAxisCount: 3,
mainAxisSpacing: 16.0,
crossAxisSpacing: 10.0,
childAspectRatio: 4 / 3,
children: r.map(
(e) {
final base64Image = base64Decode(e);
return GestureDetector(
onTap: () async {
final tempDirectory = await getTemporaryDirectory();
final path = p.join(
tempDirectory.path,
'${uuid()}.png',
);
File(path).writeAsBytesSync(base64Image);
widget.onSelectImage(path);
},
child: Image.memory(base64Image),
);
},
).toList(),
),
);
},
),
)
],
);
}
void _search() async {
final stabilityAI = await getIt.getAsync<StabilityAIRepository>();
setState(() {
future = stabilityAI.generateImage(
prompt: query,
n: 6,
);
});
}
}

View File

@ -1,7 +1,10 @@
import 'package:appflowy/generated/locale_keys.g.dart';
import 'package:appflowy/plugins/document/presentation/editor_plugins/image/embed_image_url_widget.dart';
import 'package:appflowy/plugins/document/presentation/editor_plugins/image/open_ai_image_widget.dart';
import 'package:appflowy/plugins/document/presentation/editor_plugins/image/stability_ai_image_widget.dart';
import 'package:appflowy/plugins/document/presentation/editor_plugins/image/unsplash_image_widget.dart';
import 'package:appflowy/plugins/document/presentation/editor_plugins/image/upload_image_file_widget.dart';
import 'package:appflowy/user/application/user_service.dart';
import 'package:easy_localization/easy_localization.dart';
import 'package:flowy_infra_ui/flowy_infra_ui.dart';
import 'package:flowy_infra_ui/style_widget/hover.dart';
@ -11,7 +14,8 @@ enum UploadImageType {
local,
url,
unsplash,
ai;
stabilityAI,
openAI;
String get description {
switch (this) {
@ -21,8 +25,10 @@ enum UploadImageType {
return LocaleKeys.document_imageBlock_embedLink_label.tr();
case UploadImageType.unsplash:
return 'Unsplash';
case UploadImageType.ai:
return 'Generate from AI';
case UploadImageType.openAI:
return LocaleKeys.document_imageBlock_ai_label.tr();
case UploadImageType.stabilityAI:
return LocaleKeys.document_imageBlock_stability_ai_label.tr();
}
}
}
@ -43,11 +49,39 @@ class UploadImageMenu extends StatefulWidget {
class _UploadImageMenuState extends State<UploadImageMenu> {
int currentTabIndex = 0;
List<UploadImageType> values = UploadImageType.values;
bool supportOpenAI = false;
bool supportStabilityAI = false;
@override
void initState() {
super.initState();
UserBackendService.getCurrentUserProfile().then(
(value) {
final supportOpenAI = value.fold(
(l) => false,
(r) => r.openaiKey.isNotEmpty,
);
final supportStabilityAI = value.fold(
(l) => false,
(r) => r.stabilityAiKey.isNotEmpty,
);
if (supportOpenAI != this.supportOpenAI ||
supportStabilityAI != this.supportStabilityAI) {
setState(() {
this.supportOpenAI = supportOpenAI;
this.supportStabilityAI = supportStabilityAI;
});
}
},
);
}
@override
Widget build(BuildContext context) {
return DefaultTabController(
length: 3, // UploadImageType.values.length, // ai is not implemented yet
length: values.length,
child: Column(
mainAxisSize: MainAxisSize.min,
children: [
@ -62,10 +96,7 @@ class _UploadImageMenuState extends State<UploadImageMenu> {
),
padding: EdgeInsets.zero,
// splashBorderRadius: BorderRadius.circular(4),
tabs: UploadImageType.values
.where(
(element) => element != UploadImageType.ai,
) // ai is not implemented yet
tabs: values
.map(
(e) => FlowyHover(
style: const HoverStyle(borderRadius: BorderRadius.zero),
@ -115,8 +146,39 @@ class _UploadImageMenuState extends State<UploadImageMenu> {
),
),
);
case UploadImageType.ai:
return const FlowyText.medium('ai');
case UploadImageType.openAI:
return supportOpenAI
? Expanded(
child: Padding(
padding: const EdgeInsets.all(8.0),
child: OpenAIImageWidget(
onSelectNetworkImage: widget.onSubmit,
),
),
)
: Padding(
padding: const EdgeInsets.all(8.0),
child: FlowyText(
LocaleKeys.document_imageBlock_pleaseInputYourOpenAIKey.tr(),
),
);
case UploadImageType.stabilityAI:
return supportStabilityAI
? Expanded(
child: Padding(
padding: const EdgeInsets.all(8.0),
child: StabilityAIImageWidget(
onSelectImage: widget.onPickFile,
),
),
)
: Padding(
padding: const EdgeInsets.all(8.0),
child: FlowyText(
LocaleKeys.document_imageBlock_pleaseInputYourStabilityAIKey
.tr(),
),
);
}
}
}

View File

@ -1,20 +1,20 @@
import 'dart:async';
import 'dart:convert';
import 'package:appflowy/plugins/document/presentation/editor_plugins/openai/service/text_edit.dart';
import 'text_completion.dart';
import 'package:dartz/dartz.dart';
import 'dart:async';
import 'package:http/http.dart' as http;
import 'error.dart';
import 'package:http/http.dart' as http;
import 'text_completion.dart';
// Please fill in your own API key
const apiKey = '';
enum OpenAIRequestType {
textCompletion,
textEdit;
textEdit,
imageGenerations;
Uri get uri {
switch (this) {
@ -22,6 +22,8 @@ enum OpenAIRequestType {
return Uri.parse('https://api.openai.com/v1/completions');
case OpenAIRequestType.textEdit:
return Uri.parse('https://api.openai.com/v1/edits');
case OpenAIRequestType.imageGenerations:
return Uri.parse('https://api.openai.com/v1/images/generations');
}
}
}
@ -64,6 +66,17 @@ abstract class OpenAIRepository {
required String instruction,
double temperature = 0.3,
});
/// Generate image from GPT-3
///
/// [prompt] is the prompt text
/// [n] is the number of images to generate
///
/// the result is a list of urls
Future<Either<OpenAIError, List<String>>> generateImage({
required String prompt,
int n = 1,
});
}
class HttpOpenAIRepository implements OpenAIRepository {
@ -228,4 +241,40 @@ class HttpOpenAIRepository implements OpenAIRepository {
return Left(OpenAIError.fromJson(json.decode(response.body)['error']));
}
}
@override
Future<Either<OpenAIError, List<String>>> generateImage({
required String prompt,
int n = 1,
}) async {
final parameters = {
'prompt': prompt,
'n': n,
'size': '512x512',
};
try {
final response = await client.post(
OpenAIRequestType.imageGenerations.uri,
headers: headers,
body: json.encode(parameters),
);
if (response.statusCode == 200) {
final data = json.decode(
utf8.decode(response.bodyBytes),
)['data'] as List;
final urls = data
.map((e) => e.values)
.expand((e) => e)
.map((e) => e.toString())
.toList();
return Right(urls);
} else {
return Left(OpenAIError.fromJson(json.decode(response.body)['error']));
}
} catch (error) {
return Left(OpenAIError(message: error.toString()));
}
}
}

View File

@ -0,0 +1,95 @@
import 'dart:async';
import 'dart:convert';
import 'package:appflowy/plugins/document/presentation/editor_plugins/stability_ai/stability_ai_error.dart';
import 'package:dartz/dartz.dart';
import 'package:http/http.dart' as http;
enum StabilityAIRequestType {
imageGenerations;
Uri get uri {
switch (this) {
case StabilityAIRequestType.imageGenerations:
return Uri.parse(
'https://api.stability.ai/v1/generation/stable-diffusion-xl-1024-v1-0/text-to-image',
);
}
}
}
abstract class StabilityAIRepository {
/// Generate image from Stability AI
///
/// [prompt] is the prompt text
/// [n] is the number of images to generate
///
/// the return value is a list of base64 encoded images
Future<Either<StabilityAIRequestError, List<String>>> generateImage({
required String prompt,
int n = 1,
});
}
class HttpStabilityAIRepository implements StabilityAIRepository {
const HttpStabilityAIRepository({
required this.client,
required this.apiKey,
});
final http.Client client;
final String apiKey;
Map<String, String> get headers => {
'Authorization': 'Bearer $apiKey',
'Content-Type': 'application/json',
};
@override
Future<Either<StabilityAIRequestError, List<String>>> generateImage({
required String prompt,
int n = 1,
}) async {
final parameters = {
'text_prompts': [
{
'text': prompt,
}
],
'samples': n,
};
try {
final response = await client.post(
StabilityAIRequestType.imageGenerations.uri,
headers: headers,
body: json.encode(parameters),
);
final data = json.decode(
utf8.decode(response.bodyBytes),
);
if (response.statusCode == 200) {
final artifacts = data['artifacts'] as List;
final base64Images = artifacts
.map(
(e) => e['base64'].toString(),
)
.toList();
return Right(base64Images);
} else {
return Left(
StabilityAIRequestError(
data['message'].toString(),
),
);
}
} catch (error) {
return Left(
StabilityAIRequestError(
error.toString(),
),
);
}
}
}

View File

@ -0,0 +1,10 @@
class StabilityAIRequestError {
final String message;
StabilityAIRequestError(this.message);
@override
String toString() {
return 'StabilityAIRequestError{message: $message}';
}
}

View File

@ -9,12 +9,13 @@ import 'package:appflowy/plugins/database_view/grid/application/grid_header_bloc
import 'package:appflowy/plugins/document/application/prelude.dart';
import 'package:appflowy/plugins/document/presentation/editor_plugins/copy_and_paste/clipboard_service.dart';
import 'package:appflowy/plugins/document/presentation/editor_plugins/openai/service/openai_client.dart';
import 'package:appflowy/plugins/document/presentation/editor_plugins/stability_ai/stability_ai_client.dart';
import 'package:appflowy/plugins/trash/application/prelude.dart';
import 'package:appflowy/startup/startup.dart';
import 'package:appflowy/user/application/auth/af_cloud_auth_service.dart';
import 'package:appflowy/user/application/auth/auth_service.dart';
import 'package:appflowy/user/application/auth/supabase_mock_auth_service.dart';
import 'package:appflowy/user/application/auth/supabase_auth_service.dart';
import 'package:appflowy/user/application/auth/supabase_mock_auth_service.dart';
import 'package:appflowy/user/application/prelude.dart';
import 'package:appflowy/user/application/reminder/reminder_bloc.dart';
import 'package:appflowy/user/application/user_listener.dart';
@ -85,6 +86,23 @@ void _resolveCommonService(
},
);
getIt.registerFactoryAsync<StabilityAIRepository>(
() async {
final result = await UserBackendService.getCurrentUserProfile();
return result.fold(
(l) {
throw Exception('Failed to get user profile: ${l.msg}');
},
(r) {
return HttpStabilityAIRepository(
client: http.Client(),
apiKey: r.stabilityAiKey,
);
},
);
},
);
getIt.registerFactory<ClipboardService>(
() => ClipboardService(),
);

View File

@ -1,10 +1,10 @@
import 'dart:async';
import 'package:appflowy_backend/protobuf/flowy-user/protobuf.dart';
import 'package:dartz/dartz.dart';
import 'package:appflowy_backend/dispatch/dispatch.dart';
import 'package:appflowy_backend/protobuf/flowy-error/errors.pb.dart';
import 'package:appflowy_backend/protobuf/flowy-folder2/workspace.pb.dart';
import 'package:appflowy_backend/protobuf/flowy-user/protobuf.dart';
import 'package:dartz/dartz.dart';
import 'package:fixnum/fixnum.dart';
class UserBackendService {
@ -26,6 +26,7 @@ class UserBackendService {
String? email,
String? iconUrl,
String? openAIKey,
String? stabilityAiKey,
}) {
final payload = UpdateUserProfilePayloadPB.create()..id = userId;
@ -49,6 +50,10 @@ class UserBackendService {
payload.openaiKey = openAIKey;
}
if (stabilityAiKey != null) {
payload.stabilityAiKey = stabilityAiKey;
}
return UserEventUpdateUserProfile(payload).send();
}

View File

@ -3,9 +3,9 @@ import 'package:appflowy/user/application/user_service.dart';
import 'package:appflowy_backend/log.dart';
import 'package:appflowy_backend/protobuf/flowy-error/errors.pb.dart';
import 'package:appflowy_backend/protobuf/flowy-user/user_profile.pb.dart';
import 'package:dartz/dartz.dart';
import 'package:flutter_bloc/flutter_bloc.dart';
import 'package:freezed_annotation/freezed_annotation.dart';
import 'package:dartz/dartz.dart';
part 'settings_user_bloc.freezed.dart';
@ -60,6 +60,16 @@ class SettingsUserViewBloc extends Bloc<SettingsUserEvent, SettingsUserState> {
);
});
},
updateUserStabilityAIKey: (stabilityAIKey) {
_userService
.updateUserProfile(stabilityAiKey: stabilityAIKey)
.then((result) {
result.fold(
(l) => null,
(err) => Log.error(err),
);
});
},
didLoadHistoricalUsers: (List<HistoricalUserPB> historicalUsers) {
emit(state.copyWith(historicalUsers: historicalUsers));
},
@ -119,6 +129,9 @@ class SettingsUserEvent with _$SettingsUserEvent {
const factory SettingsUserEvent.removeUserIcon() = _RemoveUserIcon;
const factory SettingsUserEvent.updateUserOpenAIKey(String openAIKey) =
_UpdateUserOpenaiKey;
const factory SettingsUserEvent.updateUserStabilityAIKey(
String stabilityAIKey,
) = _UpdateUserStabilityAIKey;
const factory SettingsUserEvent.didReceiveUserProfile(
UserProfilePB newUserProfile,
) = _DidReceiveUserProfile;

View File

@ -61,6 +61,8 @@ class SettingsUserView extends StatelessWidget {
const VSpace(12),
_renderCurrentOpenaiKey(context),
const VSpace(12),
_renderCurrentStabilityAIKey(context),
const VSpace(12),
_renderLoginOrLogoutButton(context, state),
const VSpace(12),
],
@ -207,9 +209,29 @@ class SettingsUserView extends StatelessWidget {
}
Widget _renderCurrentOpenaiKey(BuildContext context) {
final String openAIKey =
final String accessKey =
context.read<SettingsUserViewBloc>().state.userProfile.openaiKey;
return _OpenaiKeyInput(openAIKey);
return _AIAccessKeyInput(
accessKey: accessKey,
title: 'OpenAI Key',
hintText: LocaleKeys.settings_user_pleaseInputYourOpenAIKey.tr(),
callback: (key) => context
.read<SettingsUserViewBloc>()
.add(SettingsUserEvent.updateUserOpenAIKey(key)),
);
}
Widget _renderCurrentStabilityAIKey(BuildContext context) {
final String accessKey =
context.read<SettingsUserViewBloc>().state.userProfile.stabilityAiKey;
return _AIAccessKeyInput(
accessKey: accessKey,
title: 'Stability AI Key',
hintText: LocaleKeys.settings_user_pleaseInputYourStabilityAIKey.tr(),
callback: (key) => context
.read<SettingsUserViewBloc>()
.add(SettingsUserEvent.updateUserStabilityAIKey(key)),
);
}
Widget _avatarOverlay({
@ -379,18 +401,24 @@ class UserEmailInputState extends State<UserEmailInput> {
}
}
class _OpenaiKeyInput extends StatefulWidget {
final String openAIKey;
const _OpenaiKeyInput(
this.openAIKey, {
Key? key,
}) : super(key: key);
class _AIAccessKeyInput extends StatefulWidget {
const _AIAccessKeyInput({
required this.accessKey,
required this.title,
required this.hintText,
required this.callback,
});
final String accessKey;
final String title;
final String hintText;
final void Function(String key) callback;
@override
State<_OpenaiKeyInput> createState() => _OpenaiKeyInputState();
State<_AIAccessKeyInput> createState() => _AIAccessKeyInputState();
}
class _OpenaiKeyInputState extends State<_OpenaiKeyInput> {
class _AIAccessKeyInputState extends State<_AIAccessKeyInput> {
bool visible = false;
final textEditingController = TextEditingController();
final debounce = Debounce();
@ -399,7 +427,7 @@ class _OpenaiKeyInputState extends State<_OpenaiKeyInput> {
void initState() {
super.initState();
textEditingController.text = widget.openAIKey;
textEditingController.text = widget.accessKey;
}
@override
@ -415,12 +443,12 @@ class _OpenaiKeyInputState extends State<_OpenaiKeyInput> {
focusedBorder: UnderlineInputBorder(
borderSide: BorderSide(color: Theme.of(context).colorScheme.primary),
),
labelText: 'OpenAI Key',
labelText: widget.title,
labelStyle: Theme.of(context)
.textTheme
.titleMedium!
.copyWith(fontWeight: FontWeight.w500),
hintText: LocaleKeys.settings_user_pleaseInputYourOpenAIKey.tr(),
hintText: widget.hintText,
suffixIcon: FlowyIconButton(
width: 40,
height: 40,
@ -437,9 +465,7 @@ class _OpenaiKeyInputState extends State<_OpenaiKeyInput> {
),
onChanged: (value) {
debounce.call(() {
context
.read<SettingsUserViewBloc>()
.add(SettingsUserEvent.updateUserOpenAIKey(value));
widget.callback(value);
});
},
);

View File

@ -54,8 +54,8 @@ packages:
dependency: "direct main"
description:
path: "."
ref: af8d96b
resolved-ref: af8d96bc1aab07046f4febdd991e1787c75c6e38
ref: "0abcf7f"
resolved-ref: "0abcf7f6d273b838c895abdc17f6833540613729"
url: "https://github.com/AppFlowy-IO/appflowy-editor.git"
source: git
version: "1.4.3"

View File

@ -47,7 +47,7 @@ dependencies:
appflowy_editor:
git:
url: https://github.com/AppFlowy-IO/appflowy-editor.git
ref: 'af8d96b'
ref: "0abcf7f"
appflowy_popover:
path: packages/appflowy_popover
@ -134,7 +134,6 @@ dev_dependencies:
url_launcher_platform_interface: any
run_with_network_images: ^0.0.1
dependency_overrides:
http: ^1.0.0

View File

@ -353,6 +353,7 @@
"tooltipSelectIcon": "Select icon",
"selectAnIcon": "Select an icon",
"pleaseInputYourOpenAIKey": "please input your OpenAI key",
"pleaseInputYourStabilityAIKey": "please input your Stability AI key",
"clickToLogout": "Click to logout the current user"
},
"shortcuts": {
@ -652,6 +653,14 @@
"label": "Image URL",
"placeholder": "Enter image URL"
},
"ai": {
"label": "Generate image from OpenAI",
"placeholder": "Please input the prompt for OpenAI to generate image"
},
"stability_ai": {
"label": "Generate image from Stability AI",
"placeholder": "Please input the prompt for Stability AI to generate image"
},
"support": "Image size limit is 5MB. Supported formats: JPEG, PNG, GIF, SVG",
"error": {
"invalidImage": "Invalid image",
@ -663,7 +672,9 @@
"label": "Embed link",
"placeholder": "Paste or type an image link"
},
"searchForAnImage": "Search for an image"
"searchForAnImage": "Search for an image",
"pleaseInputYourOpenAIKey": "please input your OpenAI key in Settings page",
"pleaseInputYourStabilityAIKey": "please input your Stability AI key in Settings page"
},
"codeBlock": {
"language": {

View File

@ -116,6 +116,7 @@ where
token: token_from_client(client).await.unwrap_or("".to_string()),
icon_url: "".to_owned(),
openai_key: "".to_owned(),
stability_ai_key: "".to_owned(),
workspace_id: match profile.latest_workspace_id {
Some(w) => w.to_string(),
None => "".to_string(),

View File

@ -215,6 +215,7 @@ where
token: "".to_string(),
icon_url: "".to_string(),
openai_key: "".to_string(),
stability_ai_key: "".to_string(),
workspace_id: response.latest_workspace_id,
auth_type: AuthType::Supabase,
encryption_type: EncryptionType::from_sign(&response.encryption_sign),

View File

@ -64,6 +64,7 @@ async fn supabase_update_user_profile_test() {
password: None,
icon_url: None,
openai_key: None,
stability_ai_key: None,
encryption_sign: None,
},
)

View File

@ -0,0 +1,3 @@
-- This file should undo anything in `up.sql`
ALTER TABLE user_table
DROP COLUMN stability_ai_key;

View File

@ -0,0 +1,3 @@
-- Your SQL goes here
ALTER TABLE user_table
ADD COLUMN stability_ai_key TEXT NOT NULL DEFAULT "";

View File

@ -31,6 +31,7 @@ diesel::table! {
email -> Text,
auth_type -> Integer,
encryption_type -> Text,
stability_ai_key -> Text,
}
}

View File

@ -47,6 +47,29 @@ async fn user_update_with_name() {
assert_eq!(user_profile.name, new_name,);
}
#[tokio::test]
async fn user_update_with_ai_key() {
let sdk = FlowyCoreTest::new();
let user = sdk.init_user().await;
let openai_key = "openai_key".to_owned();
let stability_ai_key = "stability_ai_key".to_owned();
let request = UpdateUserProfilePayloadPB::new(user.id)
.openai_key(&openai_key)
.stability_ai_key(&stability_ai_key);
let _ = EventBuilder::new(sdk.clone())
.event(UpdateUserProfile)
.payload(request)
.sync_send();
let user_profile = EventBuilder::new(sdk.clone())
.event(GetUserProfile)
.sync_send()
.parse::<UserProfilePB>();
assert_eq!(user_profile.openai_key, openai_key,);
assert_eq!(user_profile.stability_ai_key, stability_ai_key,);
}
#[tokio::test]
async fn user_update_with_email() {
let sdk = FlowyCoreTest::new();

View File

@ -191,6 +191,7 @@ pub struct UserProfile {
pub token: String,
pub icon_url: String,
pub openai_key: String,
pub stability_ai_key: String,
pub workspace_id: String,
pub auth_type: AuthType,
// If the encryption_sign is not empty, which means the user has enabled the encryption.
@ -252,6 +253,7 @@ where
workspace_id: value.latest_workspace().id.to_owned(),
auth_type: auth_type.clone(),
encryption_type: value.encryption_type(),
stability_ai_key: "".to_owned(),
}
}
}
@ -264,6 +266,7 @@ pub struct UpdateUserProfileParams {
pub password: Option<String>,
pub icon_url: Option<String>,
pub openai_key: Option<String>,
pub stability_ai_key: Option<String>,
pub encryption_sign: Option<String>,
}
@ -300,6 +303,11 @@ impl UpdateUserProfileParams {
self
}
pub fn with_stability_ai_key(mut self, stability_ai_key: &str) -> Self {
self.stability_ai_key = Some(stability_ai_key.to_owned());
self
}
pub fn with_encryption_type(mut self, encryption_type: EncryptionType) -> Self {
let sign = match encryption_type {
EncryptionType::NoEncryption => "".to_string(),
@ -316,6 +324,7 @@ impl UpdateUserProfileParams {
&& self.icon_url.is_none()
&& self.openai_key.is_none()
&& self.encryption_sign.is_none()
&& self.stability_ai_key.is_none()
}
}

View File

@ -4,6 +4,7 @@ pub use user_id::*;
pub use user_name::*;
pub use user_openai_key::*;
pub use user_password::*;
pub use user_stability_ai_key::*;
// https://lexi-lambda.github.io/blog/2019/11/05/parse-don-t-validate/
mod user_email;
@ -12,3 +13,4 @@ mod user_id;
mod user_name;
mod user_openai_key;
mod user_password;
mod user_stability_ai_key;

View File

@ -0,0 +1,16 @@
use flowy_error::ErrorCode;
#[derive(Debug)]
pub struct UserStabilityAIKey(pub String);
impl UserStabilityAIKey {
pub fn parse(s: String) -> Result<UserStabilityAIKey, ErrorCode> {
Ok(Self(s))
}
}
impl AsRef<str> for UserStabilityAIKey {
fn as_ref(&self) -> &str {
&self.0
}
}

View File

@ -8,6 +8,8 @@ use crate::entities::AuthTypePB;
use crate::errors::ErrorCode;
use crate::services::entities::HistoricalUser;
use super::parser::UserStabilityAIKey;
#[derive(Default, ProtoBuf)]
pub struct UserTokenPB {
#[pb(index = 1)]
@ -51,6 +53,9 @@ pub struct UserProfilePB {
#[pb(index = 10)]
pub workspace_id: String,
#[pb(index = 11)]
pub stability_ai_key: String,
}
#[derive(ProtoBuf_Enum, Eq, PartialEq, Debug, Clone)]
@ -82,6 +87,7 @@ impl std::convert::From<UserProfile> for UserProfilePB {
encryption_sign,
encryption_type: encryption_ty,
workspace_id: user_profile.workspace_id,
stability_ai_key: user_profile.stability_ai_key,
}
}
}
@ -105,6 +111,9 @@ pub struct UpdateUserProfilePayloadPB {
#[pb(index = 6, one_of)]
pub openai_key: Option<String>,
#[pb(index = 7, one_of)]
pub stability_ai_key: Option<String>,
}
impl UpdateUserProfilePayloadPB {
@ -139,6 +148,11 @@ impl UpdateUserProfilePayloadPB {
self.openai_key = Some(openai_key.to_owned());
self
}
pub fn stability_ai_key(mut self, stability_ai_key: &str) -> Self {
self.stability_ai_key = Some(stability_ai_key.to_owned());
self
}
}
impl TryInto<UpdateUserProfileParams> for UpdateUserProfilePayloadPB {
@ -170,6 +184,11 @@ impl TryInto<UpdateUserProfileParams> for UpdateUserProfilePayloadPB {
Some(openai_key) => Some(UserOpenaiKey::parse(openai_key)?.0),
};
let stability_ai_key = match self.stability_ai_key {
None => None,
Some(stability_ai_key) => Some(UserStabilityAIKey::parse(stability_ai_key)?.0),
};
Ok(UpdateUserProfileParams {
uid: self.id,
name,
@ -178,6 +197,7 @@ impl TryInto<UpdateUserProfileParams> for UpdateUserProfilePayloadPB {
icon_url,
openai_key,
encryption_sign: None,
stability_ai_key,
})
}
}

View File

@ -18,6 +18,7 @@ pub struct UserTable {
pub(crate) email: String,
pub(crate) auth_type: i32,
pub(crate) encryption_type: String,
pub(crate) stability_ai_key: String,
}
impl UserTable {
@ -41,6 +42,7 @@ impl From<(UserProfile, AuthType)> for UserTable {
email: user_profile.email,
auth_type: auth_type as i32,
encryption_type,
stability_ai_key: user_profile.stability_ai_key,
}
}
}
@ -57,6 +59,7 @@ impl From<UserTable> for UserProfile {
workspace_id: table.workspace,
auth_type: AuthType::from(table.auth_type),
encryption_type: EncryptionType::from_str(&table.encryption_type).unwrap_or_default(),
stability_ai_key: table.stability_ai_key,
}
}
}
@ -71,6 +74,7 @@ pub struct UserTableChangeset {
pub icon_url: Option<String>,
pub openai_key: Option<String>,
pub encryption_type: Option<String>,
pub stability_ai_key: Option<String>,
}
impl UserTableChangeset {
@ -87,6 +91,7 @@ impl UserTableChangeset {
icon_url: params.icon_url,
openai_key: params.openai_key,
encryption_type,
stability_ai_key: params.stability_ai_key,
}
}
@ -100,6 +105,7 @@ impl UserTableChangeset {
icon_url: Some(user_profile.icon_url),
openai_key: Some(user_profile.openai_key),
encryption_type: Some(encryption_type),
stability_ai_key: Some(user_profile.stability_ai_key),
}
}
}