Feature/smart edit v2 (#1880)

* feat: add edit api to openai client

* feat: add translation

* chore: format code

* feat: add smart edit plugin

* fix: close http.client when dispose

* fix: insert openai result to wrong position

* feat: optimize the replace text logic

* test: add test for normalize and getTextInSelection function

* chore: update error message
This commit is contained in:
Lucas.Xu
2023-02-28 14:34:13 +08:00
committed by GitHub
parent 1945b0fe05
commit 085ef8f668
22 changed files with 701 additions and 44 deletions

View File

@ -1,5 +1,7 @@
import 'dart:convert';
import 'package:appflowy/plugins/document/presentation/plugins/openai/service/text_edit.dart';
import 'text_completion.dart';
import 'package:dartz/dartz.dart';
import 'dart:async';
@ -38,6 +40,18 @@ abstract class OpenAIRepository {
int maxTokens = 50,
double temperature = .3,
});
/// Get edits from GPT-3
///
/// [input] is the input text
/// [instruction] is the instruction text
/// [temperature] is the temperature of the model
///
Future<Either<OpenAIError, TextEditResponse>> getEdits({
required String input,
required String instruction,
double temperature = 0.3,
});
}
class HttpOpenAIRepository implements OpenAIRepository {
@ -70,7 +84,7 @@ class HttpOpenAIRepository implements OpenAIRepository {
'stream': false,
};
final response = await http.post(
final response = await client.post(
OpenAIRequestType.textCompletion.uri,
headers: headers,
body: json.encode(parameters),
@ -82,4 +96,30 @@ class HttpOpenAIRepository implements OpenAIRepository {
return Left(OpenAIError.fromJson(json.decode(response.body)['error']));
}
}
@override
Future<Either<OpenAIError, TextEditResponse>> getEdits({
required String input,
required String instruction,
double temperature = 0.3,
}) async {
final parameters = {
'model': 'text-davinci-edit-001',
'input': input,
'instruction': instruction,
'temperature': temperature,
};
final response = await client.post(
OpenAIRequestType.textEdit.uri,
headers: headers,
body: json.encode(parameters),
);
if (response.statusCode == 200) {
return Right(TextEditResponse.fromJson(json.decode(response.body)));
} else {
return Left(OpenAIError.fromJson(json.decode(response.body)['error']));
}
}
}

View File

@ -0,0 +1,24 @@
import 'package:freezed_annotation/freezed_annotation.dart';
part 'text_edit.freezed.dart';
part 'text_edit.g.dart';
@freezed
class TextEditChoice with _$TextEditChoice {
factory TextEditChoice({
required String text,
required int index,
}) = _TextEditChoice;
factory TextEditChoice.fromJson(Map<String, Object?> json) =>
_$TextEditChoiceFromJson(json);
}
@freezed
class TextEditResponse with _$TextEditResponse {
const factory TextEditResponse({
required List<TextEditChoice> choices,
}) = _TextEditResponse;
factory TextEditResponse.fromJson(Map<String, Object?> json) =>
_$TextEditResponseFromJson(json);
}

View File

@ -167,7 +167,7 @@ class _AutoCompletionInputState extends State<_AutoCompletionInput> {
text: '',
style: Theme.of(context).textTheme.bodyMedium?.copyWith(
color: Colors.grey,
), // FIXME: color
),
),
],
),
@ -185,7 +185,7 @@ class _AutoCompletionInputState extends State<_AutoCompletionInput> {
text: LocaleKeys.button_esc.tr(),
style: Theme.of(context).textTheme.bodyMedium?.copyWith(
color: Colors.grey,
), // FIXME: color
),
),
],
),
@ -198,7 +198,6 @@ class _AutoCompletionInputState extends State<_AutoCompletionInput> {
Widget _buildFooterWidget(BuildContext context) {
return Row(
children: [
// FIXME: l10n
FlowyRichTextButton(
TextSpan(
children: [

View File

@ -0,0 +1,36 @@
import 'package:appflowy/workspace/presentation/widgets/pop_up_action.dart';
import 'package:flutter/material.dart';
import 'package:appflowy/generated/locale_keys.g.dart';
import 'package:easy_localization/easy_localization.dart';
enum SmartEditAction {
summarize,
fixSpelling;
String get toInstruction {
switch (this) {
case SmartEditAction.summarize:
return 'Summarize';
case SmartEditAction.fixSpelling:
return 'Fix the spelling mistakes';
}
}
}
class SmartEditActionWrapper extends ActionCell {
final SmartEditAction inner;
SmartEditActionWrapper(this.inner);
Widget? icon(Color iconColor) => null;
@override
String get name {
switch (inner) {
case SmartEditAction.summarize:
return LocaleKeys.document_plugins_smartEditSummarize.tr();
case SmartEditAction.fixSpelling:
return LocaleKeys.document_plugins_smartEditFixSpelling.tr();
}
}
}

View File

@ -0,0 +1,277 @@
import 'package:appflowy/plugins/document/presentation/plugins/openai/service/error.dart';
import 'package:appflowy/plugins/document/presentation/plugins/openai/service/openai_client.dart';
import 'package:appflowy/plugins/document/presentation/plugins/openai/service/text_edit.dart';
import 'package:appflowy/plugins/document/presentation/plugins/openai/widgets/smart_edit_action.dart';
import 'package:appflowy/user/application/user_service.dart';
import 'package:appflowy_editor/appflowy_editor.dart';
import 'package:flowy_infra_ui/style_widget/button.dart';
import 'package:flowy_infra_ui/style_widget/text.dart';
import 'package:flowy_infra_ui/widget/spacing.dart';
import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
import 'package:appflowy/generated/locale_keys.g.dart';
import 'package:easy_localization/easy_localization.dart';
import 'package:http/http.dart' as http;
import 'package:dartz/dartz.dart' as dartz;
import 'package:appflowy/util/either_extension.dart';
const String kSmartEditType = 'smart_edit_input';
const String kSmartEditInstructionType = 'smart_edit_instruction';
const String kSmartEditInputType = 'smart_edit_input';
class SmartEditInputBuilder extends NodeWidgetBuilder<Node> {
@override
NodeValidator<Node> get nodeValidator => (node) {
return SmartEditAction.values.map((e) => e.toInstruction).contains(
node.attributes[kSmartEditInstructionType],
) &&
node.attributes[kSmartEditInputType] is String;
};
@override
Widget build(NodeWidgetContext<Node> context) {
return _SmartEditInput(
key: context.node.key,
node: context.node,
editorState: context.editorState,
);
}
}
class _SmartEditInput extends StatefulWidget {
final Node node;
final EditorState editorState;
const _SmartEditInput({
Key? key,
required this.node,
required this.editorState,
});
@override
State<_SmartEditInput> createState() => _SmartEditInputState();
}
class _SmartEditInputState extends State<_SmartEditInput> {
String get instruction => widget.node.attributes[kSmartEditInstructionType];
String get input => widget.node.attributes[kSmartEditInputType];
final focusNode = FocusNode();
final client = http.Client();
dartz.Either<OpenAIError, TextEditResponse>? result;
bool loading = true;
@override
void initState() {
super.initState();
widget.editorState.service.keyboardService?.disable(showCursor: true);
focusNode.requestFocus();
focusNode.addListener(() {
if (!focusNode.hasFocus) {
widget.editorState.service.keyboardService?.enable();
}
});
_requestEdits().then(
(value) => setState(() {
result = value;
loading = false;
}),
);
}
@override
void dispose() {
client.close();
super.dispose();
}
@override
Widget build(BuildContext context) {
return Card(
elevation: 5,
color: Theme.of(context).colorScheme.surface,
child: Container(
margin: const EdgeInsets.all(10),
child: _buildSmartEditPanel(context),
),
);
}
Widget _buildSmartEditPanel(BuildContext context) {
return RawKeyboardListener(
focusNode: focusNode,
onKey: (RawKeyEvent event) async {
if (event is! RawKeyDownEvent) return;
if (event.logicalKey == LogicalKeyboardKey.enter) {
await _onReplace();
await _onExit();
} else if (event.logicalKey == LogicalKeyboardKey.escape) {
await _onExit();
}
},
child: Column(
mainAxisSize: MainAxisSize.min,
crossAxisAlignment: CrossAxisAlignment.start,
children: [
_buildHeaderWidget(context),
const Space(0, 10),
_buildResultWidget(context),
const Space(0, 10),
_buildInputFooterWidget(context),
],
),
);
}
Widget _buildHeaderWidget(BuildContext context) {
return Row(
children: [
FlowyText.medium(
LocaleKeys.document_plugins_smartEditTitleName.tr(),
fontSize: 14,
),
const Spacer(),
FlowyText.regular(
LocaleKeys.document_plugins_autoGeneratorLearnMore.tr(),
),
],
);
}
Widget _buildResultWidget(BuildContext context) {
final loading = SizedBox.fromSize(
size: const Size.square(14),
child: const CircularProgressIndicator(),
);
if (result == null) {
return loading;
}
return result!.fold((error) {
return Flexible(
child: Text(
error.message,
style: Theme.of(context).textTheme.bodyMedium?.copyWith(
color: Colors.red,
),
),
);
}, (response) {
return Flexible(
child: Text(
response.choices.map((e) => e.text).join('\n'),
),
);
});
}
Widget _buildInputFooterWidget(BuildContext context) {
return Row(
children: [
FlowyRichTextButton(
TextSpan(
children: [
TextSpan(
text: '${LocaleKeys.button_replace.tr()} ',
style: Theme.of(context).textTheme.bodyMedium,
),
TextSpan(
text: '',
style: Theme.of(context).textTheme.bodyMedium?.copyWith(
color: Colors.grey,
),
),
],
),
onPressed: () {
_onReplace();
_onExit();
},
),
const Space(10, 0),
FlowyRichTextButton(
TextSpan(
children: [
TextSpan(
text: '${LocaleKeys.button_Cancel.tr()} ',
style: Theme.of(context).textTheme.bodyMedium,
),
TextSpan(
text: LocaleKeys.button_esc.tr(),
style: Theme.of(context).textTheme.bodyMedium?.copyWith(
color: Colors.grey,
),
),
],
),
onPressed: () async => await _onExit(),
),
],
);
}
Future<void> _onReplace() async {
final selection = widget.editorState.service.selectionService
.currentSelection.value?.normalized;
final selectedNodes = widget
.editorState.service.selectionService.currentSelectedNodes.normalized
.whereType<TextNode>();
if (selection == null || result == null || result!.isLeft()) {
return;
}
final texts = result!.asRight().choices.first.text.split('\n')
..removeWhere((element) => element.isEmpty);
assert(texts.length == selectedNodes.length);
final transaction = widget.editorState.transaction;
transaction.replaceTexts(
selectedNodes.toList(growable: false),
selection,
texts,
);
return widget.editorState.apply(transaction);
}
Future<void> _onExit() async {
final transaction = widget.editorState.transaction;
transaction.deleteNode(widget.node);
return widget.editorState.apply(
transaction,
options: const ApplyOptions(
recordRedo: false,
recordUndo: false,
),
);
}
Future<dartz.Either<OpenAIError, TextEditResponse>> _requestEdits() async {
final result = await UserBackendService.getCurrentUserProfile();
return result.fold((userProfile) async {
final openAIRepository = HttpOpenAIRepository(
client: client,
apiKey: userProfile.openaiKey,
);
final edits = await openAIRepository.getEdits(
input: input,
instruction: instruction,
);
return edits.fold((error) async {
return dartz.Left(
OpenAIError(
message:
LocaleKeys.document_plugins_smartEditCouldNotFetchResult.tr(),
),
);
}, (textEdit) async {
return dartz.Right(textEdit);
});
}, (error) async {
// error
return dartz.Left(
OpenAIError(
message: LocaleKeys.document_plugins_smartEditCouldNotFetchKey.tr(),
),
);
});
}
}

View File

@ -0,0 +1,93 @@
import 'package:appflowy/plugins/document/presentation/plugins/openai/widgets/smart_edit_action.dart';
import 'package:appflowy/plugins/document/presentation/plugins/openai/widgets/smart_edit_node_widget.dart';
import 'package:appflowy/workspace/presentation/widgets/pop_up_action.dart';
import 'package:appflowy_editor/appflowy_editor.dart';
import 'package:appflowy_popover/appflowy_popover.dart';
import 'package:flowy_infra_ui/style_widget/icon_button.dart';
import 'package:flutter/material.dart';
ToolbarItem smartEditItem = ToolbarItem(
id: 'appflowy.toolbar.smart_edit',
type: 0, // headmost
validator: (editorState) {
// All selected nodes must be text.
final nodes = editorState.service.selectionService.currentSelectedNodes;
return nodes.whereType<TextNode>().length == nodes.length;
},
itemBuilder: (context, editorState) {
return _SmartEditWidget(
editorState: editorState,
);
},
);
class _SmartEditWidget extends StatefulWidget {
const _SmartEditWidget({
required this.editorState,
});
final EditorState editorState;
@override
State<_SmartEditWidget> createState() => _SmartEditWidgetState();
}
class _SmartEditWidgetState extends State<_SmartEditWidget> {
@override
Widget build(BuildContext context) {
return PopoverActionList<SmartEditActionWrapper>(
direction: PopoverDirection.bottomWithLeftAligned,
actions: SmartEditAction.values
.map((action) => SmartEditActionWrapper(action))
.toList(),
buildChild: (controller) {
return FlowyIconButton(
tooltipText: 'Smart Edit',
preferBelow: false,
icon: const Icon(
Icons.edit,
size: 14,
),
onPressed: () {
controller.show();
},
);
},
onSelected: (action, controller) {
controller.close();
final selection =
widget.editorState.service.selectionService.currentSelection.value;
if (selection == null) {
return;
}
final textNodes = widget
.editorState.service.selectionService.currentSelectedNodes
.whereType<TextNode>()
.toList(growable: false);
final input = widget.editorState.getTextInSelection(
textNodes.normalized,
selection.normalized,
);
final transaction = widget.editorState.transaction;
transaction.insertNode(
selection.normalized.end.path.next,
Node(
type: kSmartEditType,
attributes: {
kSmartEditInstructionType: action.inner.toInstruction,
kSmartEditInputType: input,
},
),
);
widget.editorState.apply(
transaction,
options: const ApplyOptions(
recordUndo: false,
recordRedo: false,
),
withUpdateCursor: false,
);
},
);
}
}