From b0bfe56fd33ffa897d94d55a6a57e8771e4546d9 Mon Sep 17 00:00:00 2001 From: Jing Hua Date: Mon, 20 Mar 2023 16:06:46 +0800 Subject: [PATCH] feat: customise default model parameters and system message Fixes #97, Fixes #89, Fixes #35 --- public/locales/en/model.json | 5 +- public/locales/zh-CN/model.json | 5 +- public/locales/zh-HK/model.json | 5 +- src/components/Chat/ChatContent/ChatTitle.tsx | 4 +- .../ChatConfigMenu/ChatConfigMenu.tsx | 168 +++++++++++++ src/components/ChatConfigMenu/index.ts | 1 + src/components/ConfigMenu/ConfigMenu.tsx | 229 +++++++++++------- src/components/SettingsMenu/SettingsMenu.tsx | 2 + src/constants/chat.ts | 11 +- src/hooks/useSubmit.ts | 6 +- src/store/config-slice.ts | 20 ++ src/store/migrate.ts | 8 +- src/store/store.ts | 2 + src/utils/chat.ts | 12 +- 14 files changed, 374 insertions(+), 104 deletions(-) create mode 100644 src/components/ChatConfigMenu/ChatConfigMenu.tsx create mode 100644 src/components/ChatConfigMenu/index.ts diff --git a/public/locales/en/model.json b/public/locales/en/model.json index 606b645..4e8e9ab 100644 --- a/public/locales/en/model.json +++ b/public/locales/en/model.json @@ -21,5 +21,8 @@ "frequencyPenalty": { "label": "Frequency Penalty", "description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. (Default: 0)" - } + }, + "defaultChatConfig": "Default Chat Config", + "defaultSystemMessage": "Default System Message", + "resetToDefault": "Reset To Default" } diff --git a/public/locales/zh-CN/model.json b/public/locales/zh-CN/model.json index c502f6f..a18d089 100644 --- a/public/locales/zh-CN/model.json +++ b/public/locales/zh-CN/model.json @@ -21,5 +21,8 @@ "frequencyPenalty": { "label": "频率惩罚", "description": "数值在 -2.0 到 2.0 之间。正值会根据新 token 在文本中的现有频率来惩罚它们,降低模型直接重复相同语句的可能性。(默认: 0)" - } + }, + "defaultChatConfig": "默认聊天配置", + "defaultSystemMessage": "默认系统消息", + "resetToDefault": "重置为默认" } diff --git a/public/locales/zh-HK/model.json b/public/locales/zh-HK/model.json index d676960..e4922c4 100644 --- a/public/locales/zh-HK/model.json +++ b/public/locales/zh-HK/model.json @@ -21,5 +21,8 @@ "frequencyPenalty": { "label": "頻率懲罰", "description": "係一個 -2.0 到 2.0 之間嘅數值。正嘅數值表示,如果 token 喺之前嘅文字中出現頻率越高,輸出嗰陣就會越大力噉懲罰佢,令到佢被揀中嘅機率降低,即係可以降低模型重複同一句説話嘅機會。(預設: 0)" - } + }, + "defaultChatConfig": "預設聊天配置", + "defaultSystemMessage": "預設系統消息", + "resetToDefault": "重置為預設" } diff --git a/src/components/Chat/ChatContent/ChatTitle.tsx b/src/components/Chat/ChatContent/ChatTitle.tsx index 3c972c8..4130b0e 100644 --- a/src/components/Chat/ChatContent/ChatTitle.tsx +++ b/src/components/Chat/ChatContent/ChatTitle.tsx @@ -4,7 +4,7 @@ import { shallow } from 'zustand/shallow'; import useStore from '@store/store'; import ConfigMenu from '@components/ConfigMenu'; import { ChatInterface, ConfigInterface } from '@type/chat'; -import { defaultChatConfig } from '@constants/chat'; +import { _defaultChatConfig } from '@constants/chat'; const ChatTitle = React.memo(() => { const { t } = useTranslation('model'); @@ -35,7 +35,7 @@ const ChatTitle = React.memo(() => { const chats = useStore.getState().chats; if (chats && chats.length > 0 && currentChatIndex !== -1 && !config) { const updatedChats: ChatInterface[] = JSON.parse(JSON.stringify(chats)); - updatedChats[currentChatIndex].config = { ...defaultChatConfig }; + updatedChats[currentChatIndex].config = { ..._defaultChatConfig }; setChats(updatedChats); } }, [currentChatIndex]); diff --git a/src/components/ChatConfigMenu/ChatConfigMenu.tsx b/src/components/ChatConfigMenu/ChatConfigMenu.tsx new file mode 100644 index 0000000..8bef767 --- /dev/null +++ b/src/components/ChatConfigMenu/ChatConfigMenu.tsx @@ -0,0 +1,168 @@ +import React, { useState } from 'react'; +import useStore from '@store/store'; +import { useTranslation } from 'react-i18next'; + +import PopupModal from '@components/PopupModal'; +import { + FrequencyPenaltySlider, + MaxTokenSlider, + ModelSelector, + PresencePenaltySlider, + TemperatureSlider, + TopPSlider, +} from '@components/ConfigMenu/ConfigMenu'; + +import { ModelOptions } from '@type/chat'; +import { _defaultChatConfig, _defaultSystemMessage } from '@constants/chat'; + +const ChatConfigMenu = () => { + const { t } = useTranslation('model'); + const [isModalOpen, setIsModalOpen] = useState(false); + return ( +
+ + {isModalOpen && } +
+ ); +}; + +const ChatConfigPopup = ({ + setIsModalOpen, +}: { + setIsModalOpen: React.Dispatch>; +}) => { + const config = useStore.getState().defaultChatConfig; + const setDefaultChatConfig = useStore((state) => state.setDefaultChatConfig); + const setDefaultSystemMessage = useStore( + (state) => state.setDefaultSystemMessage + ); + + const [_systemMessage, _setSystemMessage] = useState( + useStore.getState().defaultSystemMessage + ); + const [_model, _setModel] = useState(config.model); + const [_maxToken, _setMaxToken] = useState(config.max_tokens); + const [_temperature, _setTemperature] = useState(config.temperature); + const [_topP, _setTopP] = useState(config.top_p); + const [_presencePenalty, _setPresencePenalty] = useState( + config.presence_penalty + ); + const [_frequencyPenalty, _setFrequencyPenalty] = useState( + config.frequency_penalty + ); + + const { t } = useTranslation('model'); + + const handleSave = () => { + setDefaultChatConfig({ + model: _model, + max_tokens: _maxToken, + temperature: _temperature, + top_p: _topP, + presence_penalty: _presencePenalty, + frequency_penalty: _frequencyPenalty, + }); + setDefaultSystemMessage(_systemMessage); + setIsModalOpen(false); + }; + + const handleReset = () => { + _setModel(_defaultChatConfig.model); + _setMaxToken(_defaultChatConfig.max_tokens); + _setTemperature(_defaultChatConfig.temperature); + _setTopP(_defaultChatConfig.top_p); + _setPresencePenalty(_defaultChatConfig.presence_penalty); + _setFrequencyPenalty(_defaultChatConfig.frequency_penalty); + _setSystemMessage(_defaultSystemMessage); + }; + + return ( + +
+ + + + + + + +
+ {t('resetToDefault')} +
+
+
+ ); +}; + +const DefaultSystemChat = ({ + _systemMessage, + _setSystemMessage, +}: { + _systemMessage: string; + _setSystemMessage: React.Dispatch>; +}) => { + const { t } = useTranslation('model'); + + const handleInput = (e: React.ChangeEvent) => { + e.target.style.height = 'auto'; + e.target.style.height = `${e.target.scrollHeight}px`; + e.target.style.maxHeight = `${e.target.scrollHeight}px`; + }; + + const handleOnFocus = (e: React.FocusEvent) => { + e.target.style.height = 'auto'; + e.target.style.height = `${e.target.scrollHeight}px`; + e.target.style.maxHeight = `${e.target.scrollHeight}px`; + }; + + const handleOnBlur = (e: React.FocusEvent) => { + e.target.style.height = 'auto'; + e.target.style.maxHeight = '2.5rem'; + }; + + return ( +
+
+ {t('defaultSystemMessage')} +
+ +
+ ); +}; + +export default ChatConfigMenu; diff --git a/src/components/ChatConfigMenu/index.ts b/src/components/ChatConfigMenu/index.ts new file mode 100644 index 0000000..3c336f6 --- /dev/null +++ b/src/components/ChatConfigMenu/index.ts @@ -0,0 +1 @@ +export { default } from './ChatConfigMenu'; diff --git a/src/components/ConfigMenu/ConfigMenu.tsx b/src/components/ConfigMenu/ConfigMenu.tsx index e77ca75..9ecb9ef 100644 --- a/src/components/ConfigMenu/ConfigMenu.tsx +++ b/src/components/ConfigMenu/ConfigMenu.tsx @@ -52,92 +52,25 @@ const ConfigMenu = ({ _setMaxToken={_setMaxToken} _model={_model} /> -
- - { - _setTemperature(Number(e.target.value)); - }} - min={0} - max={2} - step={0.1} - className='w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer' - /> -
- {t('temperature.description')} -
-
-
- - { - _setTopP(Number(e.target.value)); - }} - min={0} - max={1} - step={0.05} - className='w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer' - /> -
- {t('topP.description')} -
-
-
- - { - _setPresencePenalty(Number(e.target.value)); - }} - min={-2} - max={2} - step={0.1} - className='w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer' - /> -
- {t('presencePenalty.description')} -
-
-
- - { - _setFrequencyPenalty(Number(e.target.value)); - }} - min={-2} - max={2} - step={0.1} - className='w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer' - /> -
- {t('frequencyPenalty.description')} -
-
+ + + + ); }; -const ModelSelector = ({ +export const ModelSelector = ({ _model, _setModel, }: { @@ -184,7 +117,7 @@ const ModelSelector = ({ ); }; -const MaxTokenSlider = ({ +export const MaxTokenSlider = ({ _maxToken, _setMaxToken, _model, @@ -226,4 +159,136 @@ const MaxTokenSlider = ({ ); }; +export const TemperatureSlider = ({ + _temperature, + _setTemperature, +}: { + _temperature: number; + _setTemperature: React.Dispatch>; +}) => { + const { t } = useTranslation('model'); + + return ( +
+ + { + _setTemperature(Number(e.target.value)); + }} + min={0} + max={2} + step={0.1} + className='w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer' + /> +
+ {t('temperature.description')} +
+
+ ); +}; + +export const TopPSlider = ({ + _topP, + _setTopP, +}: { + _topP: number; + _setTopP: React.Dispatch>; +}) => { + const { t } = useTranslation('model'); + + return ( +
+ + { + _setTopP(Number(e.target.value)); + }} + min={0} + max={1} + step={0.05} + className='w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer' + /> +
+ {t('topP.description')} +
+
+ ); +}; + +export const PresencePenaltySlider = ({ + _presencePenalty, + _setPresencePenalty, +}: { + _presencePenalty: number; + _setPresencePenalty: React.Dispatch>; +}) => { + const { t } = useTranslation('model'); + + return ( +
+ + { + _setPresencePenalty(Number(e.target.value)); + }} + min={-2} + max={2} + step={0.1} + className='w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer' + /> +
+ {t('presencePenalty.description')} +
+
+ ); +}; + +export const FrequencyPenaltySlider = ({ + _frequencyPenalty, + _setFrequencyPenalty, +}: { + _frequencyPenalty: number; + _setFrequencyPenalty: React.Dispatch>; +}) => { + const { t } = useTranslation('model'); + + return ( +
+ + { + _setFrequencyPenalty(Number(e.target.value)); + }} + min={-2} + max={2} + step={0.1} + className='w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer' + /> +
+ {t('frequencyPenalty.description')} +
+
+ ); +}; + export default ConfigMenu; diff --git a/src/components/SettingsMenu/SettingsMenu.tsx b/src/components/SettingsMenu/SettingsMenu.tsx index 188e4cb..f66bd86 100644 --- a/src/components/SettingsMenu/SettingsMenu.tsx +++ b/src/components/SettingsMenu/SettingsMenu.tsx @@ -8,6 +8,7 @@ import ThemeSwitcher from '@components/Menu/MenuOptions/ThemeSwitcher'; import LanguageSelector from '@components/LanguageSelector'; import AutoTitleToggle from './AutoTitleToggle'; import PromptLibraryMenu from '@components/PromptLibraryMenu'; +import ChatConfigMenu from '@components/ChatConfigMenu'; const SettingsMenu = () => { const { t } = useTranslation(); @@ -39,6 +40,7 @@ const SettingsMenu = () => { + )} diff --git a/src/constants/chat.ts b/src/constants/chat.ts index 35f2e5f..9d9b1c4 100644 --- a/src/constants/chat.ts +++ b/src/constants/chat.ts @@ -1,4 +1,5 @@ import { ChatInterface, ConfigInterface, ModelOptions } from '@type/chat'; +import useStore from '@store/store'; const date = new Date(); const dateString = @@ -9,7 +10,7 @@ const dateString = ('0' + date.getDate()).slice(-2); // default system message obtained using the following method: https://twitter.com/DeminDimin/status/1619935545144279040 -export const defaultSystemMessage = `You are ChatGPT, a large language model trained by OpenAI. +export const _defaultSystemMessage = `You are ChatGPT, a large language model trained by OpenAI. Knowledge cutoff: 2021-09 Current date: ${dateString}`; @@ -35,7 +36,7 @@ export const modelMaxToken = { export const defaultUserMaxToken = 4000; -export const defaultChatConfig: ConfigInterface = { +export const _defaultChatConfig: ConfigInterface = { model: defaultModel, max_tokens: defaultUserMaxToken, temperature: 1, @@ -46,8 +47,10 @@ export const defaultChatConfig: ConfigInterface = { export const generateDefaultChat = (title?: string): ChatInterface => ({ title: title ? title : 'New Chat', - messages: [{ role: 'system', content: defaultSystemMessage }], - config: { ...defaultChatConfig }, + messages: [ + { role: 'system', content: useStore.getState().defaultSystemMessage }, + ], + config: { ...useStore.getState().defaultChatConfig }, titleSet: false, }); diff --git a/src/hooks/useSubmit.ts b/src/hooks/useSubmit.ts index 1056615..879c270 100644 --- a/src/hooks/useSubmit.ts +++ b/src/hooks/useSubmit.ts @@ -4,7 +4,7 @@ import { ChatInterface, MessageInterface } from '@type/chat'; import { getChatCompletion, getChatCompletionStream } from '@api/api'; import { parseEventSource } from '@api/helper'; import { limitMessageTokens } from '@utils/messageUtils'; -import { defaultChatConfig } from '@constants/chat'; +import { _defaultChatConfig } from '@constants/chat'; const useSubmit = () => { const error = useStore((state) => state.error); @@ -24,13 +24,13 @@ const useSubmit = () => { data = await getChatCompletion( useStore.getState().apiEndpoint, message, - defaultChatConfig + _defaultChatConfig ); } else if (apiKey) { data = await getChatCompletion( useStore.getState().apiEndpoint, message, - defaultChatConfig, + _defaultChatConfig, apiKey ); } diff --git a/src/store/config-slice.ts b/src/store/config-slice.ts index 22985a7..30418a8 100644 --- a/src/store/config-slice.ts +++ b/src/store/config-slice.ts @@ -1,19 +1,27 @@ import { StoreSlice } from './store'; import { Theme } from '@type/theme'; +import { ConfigInterface } from '@type/chat'; +import { _defaultChatConfig, _defaultSystemMessage } from '@constants/chat'; export interface ConfigSlice { openConfig: boolean; theme: Theme; autoTitle: boolean; + defaultChatConfig: ConfigInterface; + defaultSystemMessage: string; setOpenConfig: (openConfig: boolean) => void; setTheme: (theme: Theme) => void; setAutoTitle: (autoTitle: boolean) => void; + setDefaultChatConfig: (defaultChatConfig: ConfigInterface) => void; + setDefaultSystemMessage: (defaultSystemMessage: string) => void; } export const createConfigSlice: StoreSlice = (set, get) => ({ openConfig: false, theme: 'dark', autoTitle: false, + defaultChatConfig: _defaultChatConfig, + defaultSystemMessage: _defaultSystemMessage, setOpenConfig: (openConfig: boolean) => { set((prev: ConfigSlice) => ({ ...prev, @@ -32,4 +40,16 @@ export const createConfigSlice: StoreSlice = (set, get) => ({ autoTitle: autoTitle, })); }, + setDefaultChatConfig: (defaultChatConfig: ConfigInterface) => { + set((prev: ConfigSlice) => ({ + ...prev, + defaultChatConfig: defaultChatConfig, + })); + }, + setDefaultSystemMessage: (defaultSystemMessage: string) => { + set((prev: ConfigSlice) => ({ + ...prev, + defaultSystemMessage: defaultSystemMessage, + })); + }, }); diff --git a/src/store/migrate.ts b/src/store/migrate.ts index e2cb9f2..c26b6c3 100644 --- a/src/store/migrate.ts +++ b/src/store/migrate.ts @@ -7,7 +7,7 @@ import { LocalStorageInterfaceV5ToV6, } from '@type/chat'; import { - defaultChatConfig, + _defaultChatConfig, defaultModel, defaultUserMaxToken, } from '@constants/chat'; @@ -17,7 +17,7 @@ import defaultPrompts from '@constants/prompt'; export const migrateV0 = (persistedState: LocalStorageInterfaceV0ToV1) => { persistedState.chats.forEach((chat) => { chat.titleSet = false; - if (!chat.config) chat.config = { ...defaultChatConfig }; + if (!chat.config) chat.config = { ..._defaultChatConfig }; }); }; @@ -33,8 +33,8 @@ export const migrateV2 = (persistedState: LocalStorageInterfaceV2ToV3) => { persistedState.chats.forEach((chat) => { chat.config = { ...chat.config, - top_p: defaultChatConfig.top_p, - frequency_penalty: defaultChatConfig.frequency_penalty, + top_p: _defaultChatConfig.top_p, + frequency_penalty: _defaultChatConfig.frequency_penalty, }; }); persistedState.autoTitle = false; diff --git a/src/store/store.ts b/src/store/store.ts index e7e90c6..2decba6 100644 --- a/src/store/store.ts +++ b/src/store/store.ts @@ -53,6 +53,8 @@ const useStore = create()( theme: state.theme, autoTitle: state.autoTitle, prompts: state.prompts, + defaultChatConfig: state.defaultChatConfig, + defaultSystemMessage: state.defaultSystemMessage, }), version: 6, migrate: (persistedState, version) => { diff --git a/src/utils/chat.ts b/src/utils/chat.ts index a0efeb5..5cbcf7d 100644 --- a/src/utils/chat.ts +++ b/src/utils/chat.ts @@ -3,7 +3,7 @@ import jsPDF from 'jspdf'; import { ChatInterface, ConfigInterface, MessageInterface } from '@type/chat'; import { roles } from '@type/chat'; import { Theme } from '@type/theme'; -import { defaultChatConfig } from '@constants/chat'; +import { _defaultChatConfig } from '@constants/chat'; export const validateAndFixChats = (chats: any): chats is ChatInterface[] => { if (!Array.isArray(chats)) return false; @@ -32,21 +32,21 @@ const validateMessage = (messages: MessageInterface[]) => { }; const validateAndFixChatConfig = (config: ConfigInterface) => { - if (config === undefined) config = defaultChatConfig; + if (config === undefined) config = _defaultChatConfig; if (!(typeof config === 'object')) return false; - if (!config.temperature) config.temperature = defaultChatConfig.temperature; + if (!config.temperature) config.temperature = _defaultChatConfig.temperature; if (!(typeof config.temperature === 'number')) return false; if (!config.presence_penalty) - config.presence_penalty = defaultChatConfig.presence_penalty; + config.presence_penalty = _defaultChatConfig.presence_penalty; if (!(typeof config.presence_penalty === 'number')) return false; - if (!config.top_p) config.top_p = defaultChatConfig.top_p; + if (!config.top_p) config.top_p = _defaultChatConfig.top_p; if (!(typeof config.top_p === 'number')) return false; if (!config.frequency_penalty) - config.frequency_penalty = defaultChatConfig.frequency_penalty; + config.frequency_penalty = _defaultChatConfig.frequency_penalty; if (!(typeof config.frequency_penalty === 'number')) return false; return true;