feat: refactor the gpt3 api and support multi line completion

This commit is contained in:
Lucas.Xu 2023-01-09 12:31:26 +08:00
parent 310236dca0
commit fa0a334d6c
8 changed files with 230 additions and 151 deletions

View File

@ -4,7 +4,7 @@ import 'package:appflowy_editor/appflowy_editor.dart';
import 'package:appflowy_editor_plugins/appflowy_editor_plugins.dart';
import 'package:example/plugin/AI/continue_to_write.dart';
import 'package:example/plugin/AI/auto_completion.dart';
import 'package:example/plugin/AI/getgpt3completions.dart';
import 'package:example/plugin/AI/gpt3.dart';
import 'package:example/plugin/AI/smart_edit.dart';
import 'package:flutter/material.dart';

View File

@ -1,5 +1,5 @@
import 'package:appflowy_editor/appflowy_editor.dart';
import 'package:example/plugin/AI/getgpt3completions.dart';
import 'package:example/plugin/AI/gpt3.dart';
import 'package:example/plugin/AI/text_robot.dart';
import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
@ -37,12 +37,18 @@ SelectionMenuItem autoCompletionMenuItem = SelectionMenuItem(
Navigator.of(context).pop();
// fetch the result and insert it
final textRobot = TextRobot(editorState: editorState);
getGPT3Completion(apiKey, controller.text, '', (result) async {
await textRobot.insertText(
result,
inputType: TextRobotInputType.character,
);
});
const gpt3 = GPT3APIClient(apiKey: apiKey);
gpt3.getGPT3Completion(
controller.text,
'',
onResult: (result) async {
await textRobot.insertText(
result,
inputType: TextRobotInputType.character,
);
},
onError: () async {},
);
} else if (key.logicalKey == LogicalKeyboardKey.escape) {
Navigator.of(context).pop();
}

View File

@ -1,5 +1,5 @@
import 'package:appflowy_editor/appflowy_editor.dart';
import 'package:example/plugin/AI/getgpt3completions.dart';
import 'package:example/plugin/AI/gpt3.dart';
import 'package:example/plugin/AI/text_robot.dart';
import 'package:flutter/material.dart';
@ -14,35 +14,96 @@ SelectionMenuItem continueToWriteMenuItem = SelectionMenuItem(
),
keywords: ['continue to write'],
handler: ((editorState, menuService, context) async {
// get the current text
// Two cases
// 1. if there is content in the text node where the cursor is located,
// then we use the current text content as data.
// 2. if there is no content in the text node where the cursor is located,
// then we use the previous / next text node's content as data.
final selection =
editorState.service.selectionService.currentSelection.value;
final textNodes = editorState.service.selectionService.currentSelectedNodes;
if (selection == null || !selection.isCollapsed || textNodes.length != 1) {
if (selection == null || !selection.isCollapsed) {
return;
}
final textNode = textNodes.first as TextNode;
final prompt = textNode.delta.slice(0, selection.startIndex).toPlainText();
final suffix = textNode.delta
.slice(
selection.endIndex,
textNode.toPlainText().length,
)
.toPlainText();
final textNodes = editorState.service.selectionService.currentSelectedNodes
.whereType<TextNode>();
if (textNodes.isEmpty) {
return;
}
final textRobot = TextRobot(editorState: editorState);
getGPT3Completion(
apiKey,
const gpt3 = GPT3APIClient(apiKey: apiKey);
final textNode = textNodes.first;
var prompt = '';
var suffix = '';
void continueToWriteInSingleLine() {
prompt = textNode.delta.slice(0, selection.startIndex).toPlainText();
suffix = textNode.delta
.slice(
selection.endIndex,
textNode.toPlainText().length,
)
.toPlainText();
}
void continueToWriteInMulitLines() {
final parent = textNode.parent;
if (parent != null) {
for (final node in parent.children) {
if (node is! TextNode || node.toPlainText().isEmpty) continue;
if (node.path < textNode.path) {
prompt += '${node.toPlainText()}\n';
} else if (node.path > textNode.path) {
suffix += '${node.toPlainText()}\n';
}
}
}
}
if (textNodes.first.toPlainText().isNotEmpty) {
continueToWriteInSingleLine();
} else {
continueToWriteInMulitLines();
}
if (prompt.isEmpty && suffix.isEmpty) {
return;
}
late final BuildContext diglogContext;
showDialog(
context: context,
builder: (context) {
diglogContext = context;
return AlertDialog(
content: Column(
mainAxisSize: MainAxisSize.min,
children: const [
CircularProgressIndicator(),
SizedBox(height: 10),
Text('Loading'),
],
),
);
},
);
gpt3.getGPT3Completion(
prompt,
suffix,
(result) async {
if (result == '\\n') {
await editorState.insertNewLineAtCurrentSelection();
} else {
await textRobot.insertText(
result,
inputType: TextRobotInputType.word,
);
}
onResult: (result) async {
Navigator.of(diglogContext).pop(true);
await textRobot.insertText(
result,
inputType: TextRobotInputType.word,
);
},
onError: () async {
Navigator.of(diglogContext).pop(true);
},
);
}),

View File

@ -1,111 +0,0 @@
import 'package:http/http.dart' as http;
import 'dart:async';
import 'dart:convert';
// Please fill in your own API key
const apiKey = '';
Future<void> getGPT3Completion(
String apiKey,
String prompt,
String suffix,
Future<void> Function(String)
onData, // callback function to handle streaming data
{
int maxTokens = 200,
double temperature = .3,
bool stream = true,
}) async {
final data = {
'prompt': prompt,
'suffix': suffix,
'max_tokens': maxTokens,
'temperature': temperature,
'stream': stream, // set stream parameter to true
};
final headers = {
'Authorization': apiKey,
'Content-Type': 'application/json',
};
final request = http.Request(
'POST',
Uri.parse('https://api.openai.com/v1/engines/text-davinci-003/completions'),
);
request.body = json.encode(data);
request.headers.addAll(headers);
final httpResponse = await request.send();
if (httpResponse.statusCode == 200) {
await for (final chunk in httpResponse.stream) {
var result = utf8.decode(chunk).split('text": "');
var text = '';
if (result.length > 1) {
result = result[1].split('",');
if (result.isNotEmpty) {
text = result.first;
}
}
final processedText = text
.replaceAll('\\r', '\r')
.replaceAll('\\t', '\t')
.replaceAll('\\b', '\b')
.replaceAll('\\f', '\f')
.replaceAll('\\v', '\v')
.replaceAll('\\\'', '\'')
.replaceAll('"', '"')
.replaceAll('\\0', '0')
.replaceAll('\\1', '1')
.replaceAll('\\2', '2')
.replaceAll('\\3', '3')
.replaceAll('\\4', '4')
.replaceAll('\\5', '5')
.replaceAll('\\6', '6')
.replaceAll('\\7', '7')
.replaceAll('\\8', '8')
.replaceAll('\\9', '9');
await onData(processedText);
}
}
}
Future<void> getGPT3Edit(
String apiKey,
String input,
String instruction, {
required Future<void> Function(List<String> result) onResult,
required Future<void> Function() onError,
int n = 1,
double temperature = .3,
}) async {
final data = {
'model': 'text-davinci-edit-001',
'input': input,
'instruction': instruction,
'temperature': temperature,
'n': n,
};
final headers = {
'Authorization': apiKey,
'Content-Type': 'application/json',
};
var response = await http.post(
Uri.parse('https://api.openai.com/v1/edits'),
headers: headers,
body: json.encode(data),
);
if (response.statusCode == 200) {
final result = json.decode(response.body);
final choices = result['choices'];
if (choices != null && choices is List) {
onResult(choices.map((e) => e['text'] as String).toList());
}
} else {
onError();
}
}

View File

@ -0,0 +1,119 @@
import 'package:http/http.dart' as http;
import 'dart:async';
import 'dart:convert';
// Please fill in your own API key
const apiKey = '';
enum GPT3API {
completion,
edit,
}
extension on GPT3API {
Uri get uri {
switch (this) {
case GPT3API.completion:
return Uri.parse('https://api.openai.com/v1/completions');
case GPT3API.edit:
return Uri.parse('https://api.openai.com/v1/edits');
}
}
}
class GPT3APIClient {
const GPT3APIClient({
required this.apiKey,
});
final String apiKey;
/// Get completions from GPT-3
///
/// [prompt] is the prompt text
/// [suffix] is the suffix text
/// [onResult] is the callback function to handle the result
/// [maxTokens] is the maximum number of tokens to generate
/// [temperature] is the temperature of the model
///
/// See https://beta.openai.com/docs/api-reference/completions/create
Future<void> getGPT3Completion(
String prompt,
String suffix, {
required Future<void> Function(String result) onResult,
required Future<void> Function() onError,
int maxTokens = 200,
double temperature = .3,
}) async {
final data = {
'model': 'text-davinci-003',
'prompt': prompt,
'suffix': suffix,
'max_tokens': maxTokens,
'temperature': temperature,
'stream': false,
};
final headers = {
'Authorization': apiKey,
'Content-Type': 'application/json',
};
final response = await http.post(
GPT3API.completion.uri,
headers: headers,
body: json.encode(data),
);
if (response.statusCode == 200) {
final result = json.decode(response.body);
final choices = result['choices'];
if (choices != null && choices is List) {
for (final choice in choices) {
final text = choice['text'];
await onResult(text);
}
}
} else {
await onError();
}
}
Future<void> getGPT3Edit(
String apiKey,
String input,
String instruction, {
required Future<void> Function(List<String> result) onResult,
required Future<void> Function() onError,
int n = 1,
double temperature = .3,
}) async {
final data = {
'model': 'text-davinci-edit-001',
'input': input,
'instruction': instruction,
'temperature': temperature,
'n': n,
};
final headers = {
'Authorization': apiKey,
'Content-Type': 'application/json',
};
final response = await http.post(
Uri.parse('https://api.openai.com/v1/edits'),
headers: headers,
body: json.encode(data),
);
if (response.statusCode == 200) {
final result = json.decode(response.body);
final choices = result['choices'];
if (choices != null && choices is List) {
await onResult(choices.map((e) => e['text'] as String).toList());
}
} else {
await onError();
}
}
}

View File

@ -1,5 +1,5 @@
import 'package:appflowy_editor/appflowy_editor.dart';
import 'package:example/plugin/AI/getgpt3completions.dart';
import 'package:example/plugin/AI/gpt3.dart';
import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
@ -52,6 +52,8 @@ class _SmartEditWidgetState extends State<SmartEditWidget> {
var result = '';
final gpt3 = const GPT3APIClient(apiKey: apiKey);
Iterable<TextNode> get currentSelectedTextNodes =>
widget.editorState.service.selectionService.currentSelectedNodes
.whereType<TextNode>();
@ -180,7 +182,7 @@ class _SmartEditWidgetState extends State<SmartEditWidget> {
},
);
getGPT3Edit(
gpt3.getGPT3Edit(
apiKey,
text,
inputEventController.text,

View File

@ -18,7 +18,7 @@ class TextRobot {
String text, {
TextRobotInputType inputType = TextRobotInputType.character,
}) async {
final lines = text.split('\\n');
final lines = text.split('\n');
for (final line in lines) {
if (line.isEmpty) continue;
switch (inputType) {
@ -32,10 +32,13 @@ class TextRobot {
}
break;
case TextRobotInputType.word:
await editorState.insertTextAtCurrentSelection(
line,
);
await Future.delayed(delay, () {});
final words = line.split(' ').map((e) => '$e ');
for (final word in words) {
await editorState.insertTextAtCurrentSelection(
word,
);
await Future.delayed(delay, () {});
}
break;
}

View File

@ -171,10 +171,9 @@ extension TextTransaction on Transaction {
void splitText(TextNode textNode, int offset) {
final delta = textNode.delta;
final first = delta.slice(0, offset);
final second = delta.slice(offset, delta.length);
final path = textNode.path.next;
updateText(textNode, first);
deleteText(textNode, offset, delta.length);
insertNode(
path,
TextNode(