From b080c185f43e118f84f6da13f2e39fd5f94a7a18 Mon Sep 17 00:00:00 2001 From: Jing Hua Date: Thu, 9 Mar 2023 21:00:29 +0800 Subject: [PATCH] feat: auto generate title --- .../ChatContent/Message/NewMessageButton.tsx | 13 +---- src/constants/chat.ts | 9 ++- src/hooks/useAddChat.ts | 8 +-- src/hooks/useInitialiseNewChat.ts | 14 +---- src/hooks/useSubmit.ts | 57 ++++++++++++++++++- src/store/migrate.ts | 5 ++ src/store/store.ts | 11 ++++ src/types/chat.ts | 12 ++++ 8 files changed, 96 insertions(+), 33 deletions(-) create mode 100644 src/store/migrate.ts diff --git a/src/components/Chat/ChatContent/Message/NewMessageButton.tsx b/src/components/Chat/ChatContent/Message/NewMessageButton.tsx index 9df6d28..4c9fd83 100644 --- a/src/components/Chat/ChatContent/Message/NewMessageButton.tsx +++ b/src/components/Chat/ChatContent/Message/NewMessageButton.tsx @@ -4,7 +4,7 @@ import useStore from '@store/store'; import PlusIcon from '@icon/PlusIcon'; import { ChatInterface } from '@type/chat'; -import { defaultChatConfig, defaultSystemMessage } from '@constants/chat'; +import { generateDefaultChat } from '@constants/chat'; const NewMessageButton = React.memo( ({ messageIndex }: { messageIndex: number }) => { @@ -24,16 +24,7 @@ const NewMessageButton = React.memo( title = `New Chat ${titleIndex}`; } - updatedChats.unshift({ - title, - messages: [ - { - role: 'system', - content: defaultSystemMessage, - }, - ], - config: { ...defaultChatConfig }, - }); + updatedChats.unshift(generateDefaultChat(title)); setChats(updatedChats); setCurrentChatIndex(0); } diff --git a/src/constants/chat.ts b/src/constants/chat.ts index 2ea2df8..e7a6779 100644 --- a/src/constants/chat.ts +++ b/src/constants/chat.ts @@ -1,4 +1,4 @@ -import { ConfigInterface } from '@type/chat'; +import { ChatInterface, ConfigInterface } from '@type/chat'; const date = new Date(); const dateString = @@ -17,3 +17,10 @@ export const defaultChatConfig: ConfigInterface = { temperature: 1, presence_penalty: 0, }; + +export const generateDefaultChat = (title?: string): ChatInterface => ({ + title: title ? title : 'New Chat', + messages: [{ role: 'system', content: defaultSystemMessage }], + config: { ...defaultChatConfig }, + titleSet: false, +}); diff --git a/src/hooks/useAddChat.ts b/src/hooks/useAddChat.ts index 95045e7..5382d75 100644 --- a/src/hooks/useAddChat.ts +++ b/src/hooks/useAddChat.ts @@ -1,6 +1,6 @@ import React from 'react'; import useStore from '@store/store'; -import { defaultChatConfig, defaultSystemMessage } from '@constants/chat'; +import { generateDefaultChat } from '@constants/chat'; import { ChatInterface } from '@type/chat'; const useAddChat = () => { @@ -19,11 +19,7 @@ const useAddChat = () => { title = `New Chat ${titleIndex}`; } - updatedChats.unshift({ - title, - messages: [{ role: 'system', content: defaultSystemMessage }], - config: { ...defaultChatConfig }, - }); + updatedChats.unshift(generateDefaultChat(title)); setChats(updatedChats); setCurrentChatIndex(0); } diff --git a/src/hooks/useInitialiseNewChat.ts b/src/hooks/useInitialiseNewChat.ts index 2f1a678..9997544 100644 --- a/src/hooks/useInitialiseNewChat.ts +++ b/src/hooks/useInitialiseNewChat.ts @@ -1,24 +1,14 @@ import React from 'react'; import useStore from '@store/store'; import { MessageInterface } from '@type/chat'; -import { defaultChatConfig, defaultSystemMessage } from '@constants/chat'; +import { generateDefaultChat } from '@constants/chat'; const useInitialiseNewChat = () => { const setChats = useStore((state) => state.setChats); const setCurrentChatIndex = useStore((state) => state.setCurrentChatIndex); const initialiseNewChat = () => { - const message: MessageInterface = { - role: 'system', - content: defaultSystemMessage, - }; - setChats([ - { - title: 'New Chat', - messages: [message], - config: { ...defaultChatConfig }, - }, - ]); + setChats([generateDefaultChat()]); setCurrentChatIndex(0); }; diff --git a/src/hooks/useSubmit.ts b/src/hooks/useSubmit.ts index fc75a9e..c810624 100644 --- a/src/hooks/useSubmit.ts +++ b/src/hooks/useSubmit.ts @@ -1,10 +1,17 @@ import React from 'react'; import useStore from '@store/store'; -import { ChatInterface } from '@type/chat'; -import { getChatCompletionStream as getChatCompletionStreamFree } from '@api/freeApi'; -import { getChatCompletionStream as getChatCompletionStreamCustom } from '@api/customApi'; +import { ChatInterface, MessageInterface } from '@type/chat'; +import { + getChatCompletionStream as getChatCompletionStreamFree, + getChatCompletion as getChatCompletionFree, +} from '@api/freeApi'; +import { + getChatCompletionStream as getChatCompletionStreamCustom, + getChatCompletion as getChatCompletionCustom, +} from '@api/customApi'; import { parseEventSource } from '@api/helper'; import { limitMessageTokens } from '@utils/messageUtils'; +import { defaultChatConfig } from '@constants/chat'; const useSubmit = () => { const error = useStore((state) => state.error); @@ -16,6 +23,22 @@ const useSubmit = () => { const currentChatIndex = useStore((state) => state.currentChatIndex); const setChats = useStore((state) => state.setChats); + const generateTitle = async ( + message: MessageInterface[] + ): Promise => { + let data; + if (apiFree) { + data = await getChatCompletionFree( + useStore.getState().apiFreeEndpoint, + message, + defaultChatConfig + ); + } else if (apiKey) { + data = await getChatCompletionCustom(apiKey, message, defaultChatConfig); + } + return data.choices[0].message.content; + }; + const handleSubmit = async () => { const chats = useStore.getState().chats; if (generating || !chats) return; @@ -94,6 +117,34 @@ const useSubmit = () => { reader.releaseLock(); stream.cancel(); } + + // generate title for new chats + const currChats = useStore.getState().chats; + if (currChats && !currChats[currentChatIndex]?.titleSet) { + const messages_length = currChats[currentChatIndex].messages.length; + const assistant_message = + currChats[currentChatIndex].messages[messages_length - 1].content; + const user_message = + currChats[currentChatIndex].messages[messages_length - 2].content; + + const message: MessageInterface = { + role: 'user', + content: `Generate a title in less than 6 words for the following message:\nUser: ${user_message}\nAssistant: ${assistant_message}`, + }; + + let title = await generateTitle([message]); + if (title.startsWith('"') && title.endsWith('"')) { + title = title.slice(1, -1); + } + const updatedChats: ChatInterface[] = JSON.parse( + JSON.stringify(useStore.getState().chats) + ); + updatedChats[currentChatIndex].title = title; + updatedChats[currentChatIndex].titleSet = true; + setChats(updatedChats); + console.log(message); + console.log(title); + } } catch (e: unknown) { const err = (e as Error).message; console.log(err); diff --git a/src/store/migrate.ts b/src/store/migrate.ts new file mode 100644 index 0000000..1b777a2 --- /dev/null +++ b/src/store/migrate.ts @@ -0,0 +1,5 @@ +import { LocalStorageInterface } from '@type/chat'; + +export const migrateV0 = (persistedState: LocalStorageInterface) => { + persistedState.chats.forEach((chat) => (chat.titleSet = false)); +}; diff --git a/src/store/store.ts b/src/store/store.ts index f41c3fa..405f3c2 100644 --- a/src/store/store.ts +++ b/src/store/store.ts @@ -4,6 +4,8 @@ import { ChatSlice, createChatSlice } from './chat-slice'; import { InputSlice, createInputSlice } from './input-slice'; import { AuthSlice, createAuthSlice } from './auth-slice'; import { ConfigSlice, createConfigSlice } from './config-slice'; +import { LocalStorageInterface } from '@type/chat'; +import { migrateV0 } from './migrate'; export type StoreState = ChatSlice & InputSlice & AuthSlice & ConfigSlice; @@ -30,6 +32,15 @@ const useStore = create()( apiFreeEndpoint: state.apiFreeEndpoint, theme: state.theme, }), + version: 1, + migrate: (persistedState, version) => { + switch (version) { + case 0: + migrateV0(persistedState as LocalStorageInterface); + break; + } + return persistedState as StoreState; + }, } ) ); diff --git a/src/types/chat.ts b/src/types/chat.ts index ee2b8b2..118406c 100644 --- a/src/types/chat.ts +++ b/src/types/chat.ts @@ -1,3 +1,5 @@ +import { Theme } from './theme'; + export type Role = 'user' | 'assistant' | 'system'; export const roles: Role[] = ['user', 'assistant', 'system']; @@ -10,9 +12,19 @@ export interface ChatInterface { title: string; messages: MessageInterface[]; config: ConfigInterface; + titleSet: boolean; } export interface ConfigInterface { temperature: number; presence_penalty: number; } + +export interface LocalStorageInterface { + chats: ChatInterface[]; + currentChatIndex: number; + apiKey: string; + apiFree: boolean; + apiFreeEndpoint: string; + theme: Theme; +}