From a009be99869a471fb49c1d924195529ad938e07b Mon Sep 17 00:00:00 2001 From: Jing Hua <59118459+ztjhz@users.noreply.github.com> Date: Mon, 6 Mar 2023 22:22:05 +0800 Subject: [PATCH] feat: model parameters customisation (#31) issue #14 --- src/api/customApi.ts | 10 ++- src/api/freeApi.ts | 14 +++- src/components/Chat/ChatContent/ChatTitle.tsx | 76 +++++++++++++++-- .../ChatContent/Message/NewMessageButton.tsx | 10 ++- src/components/ConfigMenu/ConfigMenu.tsx | 83 +++++++++++++++++++ src/components/ConfigMenu/index.ts | 1 + src/components/PopupModal/PopupModal.tsx | 2 +- src/constants/chat.ts | 7 ++ src/hooks/useAddChat.ts | 3 +- src/hooks/useInitialiseNewChat.ts | 3 +- src/hooks/useSubmit.ts | 6 +- src/main.css | 8 ++ src/types/chat.ts | 6 ++ 13 files changed, 209 insertions(+), 20 deletions(-) create mode 100644 src/components/ConfigMenu/ConfigMenu.tsx create mode 100644 src/components/ConfigMenu/index.ts diff --git a/src/api/customApi.ts b/src/api/customApi.ts index 98bb65f..96a6cde 100644 --- a/src/api/customApi.ts +++ b/src/api/customApi.ts @@ -1,4 +1,4 @@ -import { MessageInterface } from '@type/chat'; +import { ConfigInterface, MessageInterface } from '@type/chat'; export const endpoint = 'https://api.openai.com/v1/chat/completions'; @@ -23,7 +23,8 @@ export const validateApiKey = async (apiKey: string) => { export const getChatCompletion = async ( apiKey: string, - messages: MessageInterface[] + messages: MessageInterface[], + config: ConfigInterface ) => { const response = await fetch(endpoint, { method: 'POST', @@ -34,6 +35,7 @@ export const getChatCompletion = async ( body: JSON.stringify({ model: 'gpt-3.5-turbo', messages, + ...config, }), }); if (!response.ok) throw new Error(await response.text()); @@ -44,7 +46,8 @@ export const getChatCompletion = async ( export const getChatCompletionStream = async ( apiKey: string, - messages: MessageInterface[] + messages: MessageInterface[], + config: ConfigInterface ) => { const response = await fetch(endpoint, { method: 'POST', @@ -55,6 +58,7 @@ export const getChatCompletionStream = async ( body: JSON.stringify({ model: 'gpt-3.5-turbo', messages, + ...config, stream: true, }), }); diff --git a/src/api/freeApi.ts b/src/api/freeApi.ts index 2835f79..67dbf15 100644 --- a/src/api/freeApi.ts +++ b/src/api/freeApi.ts @@ -1,8 +1,11 @@ -import { MessageInterface } from '@type/chat'; +import { ConfigInterface, MessageInterface } from '@type/chat'; export const endpoint = 'https://chatgpt-api.shn.hk/v1/'; -export const getChatCompletion = async (messages: MessageInterface[]) => { +export const getChatCompletion = async ( + messages: MessageInterface[], + config: ConfigInterface +) => { const response = await fetch(endpoint, { method: 'POST', headers: { @@ -11,6 +14,7 @@ export const getChatCompletion = async (messages: MessageInterface[]) => { body: JSON.stringify({ model: 'gpt-3.5-turbo', messages, + ...config, }), }); if (!response.ok) throw new Error(await response.text()); @@ -19,7 +23,10 @@ export const getChatCompletion = async (messages: MessageInterface[]) => { return data; }; -export const getChatCompletionStream = async (messages: MessageInterface[]) => { +export const getChatCompletionStream = async ( + messages: MessageInterface[], + config: ConfigInterface +) => { const response = await fetch(endpoint, { method: 'POST', headers: { @@ -28,6 +35,7 @@ export const getChatCompletionStream = async (messages: MessageInterface[]) => { body: JSON.stringify({ model: 'gpt-3.5-turbo', messages, + ...config, stream: true, }), }); diff --git a/src/components/Chat/ChatContent/ChatTitle.tsx b/src/components/Chat/ChatContent/ChatTitle.tsx index 7acab42..b8f3691 100644 --- a/src/components/Chat/ChatContent/ChatTitle.tsx +++ b/src/components/Chat/ChatContent/ChatTitle.tsx @@ -1,11 +1,73 @@ -import React from 'react'; +import React, { useEffect, useState } from 'react'; +import { shallow } from 'zustand/shallow'; +import useStore from '@store/store'; +import ConfigMenu from '@components/ConfigMenu'; +import { ChatInterface, ConfigInterface } from '@type/chat'; +import { defaultChatConfig } from '@constants/chat'; -const ChatTitle = () => { - return ( -
- Model: Default -
+const ChatTitle = React.memo(() => { + const config = useStore( + (state) => + state.chats && + state.chats.length > 0 && + state.currentChatIndex >= 0 && + state.currentChatIndex < state.chats.length + ? state.chats[state.currentChatIndex].config + : undefined, + shallow ); -}; + const setChats = useStore((state) => state.setChats); + const currentChatIndex = useStore((state) => state.currentChatIndex); + const [isModalOpen, setIsModalOpen] = useState(false); + + const setConfig = (config: ConfigInterface) => { + const updatedChats: ChatInterface[] = JSON.parse( + JSON.stringify(useStore.getState().chats) + ); + updatedChats[currentChatIndex].config = config; + setChats(updatedChats); + }; + + // for migrating from old ChatInterface to new ChatInterface (with config) + useEffect(() => { + if (!config) { + const updatedChats: ChatInterface[] = JSON.parse( + JSON.stringify(useStore.getState().chats) + ); + updatedChats[currentChatIndex].config = { ...defaultChatConfig }; + setChats(updatedChats); + } + }, [currentChatIndex]); + + return config ? ( + <> +
{ + setIsModalOpen(true); + }} + > +
+ Model: Default +
+
+ Temperature: {config.temperature} +
+
+ PresencePenalty: {config.presence_penalty} +
+
+ {isModalOpen && ( + + )} + + ) : ( + <> + ); +}); export default ChatTitle; diff --git a/src/components/Chat/ChatContent/Message/NewMessageButton.tsx b/src/components/Chat/ChatContent/Message/NewMessageButton.tsx index ef36c5a..9df6d28 100644 --- a/src/components/Chat/ChatContent/Message/NewMessageButton.tsx +++ b/src/components/Chat/ChatContent/Message/NewMessageButton.tsx @@ -4,7 +4,7 @@ import useStore from '@store/store'; import PlusIcon from '@icon/PlusIcon'; import { ChatInterface } from '@type/chat'; -import { defaultSystemMessage } from '@constants/chat'; +import { defaultChatConfig, defaultSystemMessage } from '@constants/chat'; const NewMessageButton = React.memo( ({ messageIndex }: { messageIndex: number }) => { @@ -26,7 +26,13 @@ const NewMessageButton = React.memo( updatedChats.unshift({ title, - messages: [{ role: 'system', content: defaultSystemMessage }], + messages: [ + { + role: 'system', + content: defaultSystemMessage, + }, + ], + config: { ...defaultChatConfig }, }); setChats(updatedChats); setCurrentChatIndex(0); diff --git a/src/components/ConfigMenu/ConfigMenu.tsx b/src/components/ConfigMenu/ConfigMenu.tsx new file mode 100644 index 0000000..f1e2628 --- /dev/null +++ b/src/components/ConfigMenu/ConfigMenu.tsx @@ -0,0 +1,83 @@ +import React, { useState } from 'react'; +import PopupModal from '@components/PopupModal'; +import { ConfigInterface } from '@type/chat'; + +const ConfigMenu = ({ + setIsModalOpen, + config, + setConfig, +}: { + setIsModalOpen: React.Dispatch>; + config: ConfigInterface; + setConfig: (config: ConfigInterface) => void; +}) => { + const [_temperature, _setTemperature] = useState(config.temperature); + const [_presencePenalty, _setPresencePenalty] = useState( + config.presence_penalty + ); + + const handleConfirm = () => { + setConfig({ + temperature: _temperature, + presence_penalty: _presencePenalty, + }); + setIsModalOpen(false); + }; + + return ( + +
+
+ + { + _setTemperature(Number(e.target.value)); + }} + min={0} + max={2} + step={0.1} + className='w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer' + /> +
+ What sampling temperature to use, between 0 and 2. Higher values + like 0.8 will make the output more random, while lower values like + 0.2 will make it more focused and deterministic. +
+
+
+ + { + _setPresencePenalty(Number(e.target.value)); + }} + min={-2} + max={2} + step={0.1} + className='w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer' + /> +
+ Number between -2.0 and 2.0. Positive values penalize new tokens + based on whether they appear in the text so far, increasing the + model's likelihood to talk about new topics. +
+
+
+
+ ); +}; + +export default ConfigMenu; diff --git a/src/components/ConfigMenu/index.ts b/src/components/ConfigMenu/index.ts new file mode 100644 index 0000000..6497639 --- /dev/null +++ b/src/components/ConfigMenu/index.ts @@ -0,0 +1 @@ +export { default } from './ConfigMenu'; diff --git a/src/components/PopupModal/PopupModal.tsx b/src/components/PopupModal/PopupModal.tsx index b806444..2d412da 100644 --- a/src/components/PopupModal/PopupModal.tsx +++ b/src/components/PopupModal/PopupModal.tsx @@ -31,7 +31,7 @@ const PopupModal = ({ return ReactDOM.createPortal(
-
+

{title} diff --git a/src/constants/chat.ts b/src/constants/chat.ts index 82f6570..2ea2df8 100644 --- a/src/constants/chat.ts +++ b/src/constants/chat.ts @@ -1,3 +1,5 @@ +import { ConfigInterface } from '@type/chat'; + const date = new Date(); const dateString = date.getFullYear() + @@ -10,3 +12,8 @@ const dateString = export const defaultSystemMessage = `You are ChatGPT, a large language model trained by OpenAI. Knowledge cutoff: 2021-09 Current date: ${dateString}`; + +export const defaultChatConfig: ConfigInterface = { + temperature: 1, + presence_penalty: 0, +}; diff --git a/src/hooks/useAddChat.ts b/src/hooks/useAddChat.ts index f1fe6bf..95045e7 100644 --- a/src/hooks/useAddChat.ts +++ b/src/hooks/useAddChat.ts @@ -1,6 +1,6 @@ import React from 'react'; import useStore from '@store/store'; -import { defaultSystemMessage } from '@constants/chat'; +import { defaultChatConfig, defaultSystemMessage } from '@constants/chat'; import { ChatInterface } from '@type/chat'; const useAddChat = () => { @@ -22,6 +22,7 @@ const useAddChat = () => { updatedChats.unshift({ title, messages: [{ role: 'system', content: defaultSystemMessage }], + config: { ...defaultChatConfig }, }); setChats(updatedChats); setCurrentChatIndex(0); diff --git a/src/hooks/useInitialiseNewChat.ts b/src/hooks/useInitialiseNewChat.ts index 25fe721..2f1a678 100644 --- a/src/hooks/useInitialiseNewChat.ts +++ b/src/hooks/useInitialiseNewChat.ts @@ -1,7 +1,7 @@ import React from 'react'; import useStore from '@store/store'; import { MessageInterface } from '@type/chat'; -import { defaultSystemMessage } from '@constants/chat'; +import { defaultChatConfig, defaultSystemMessage } from '@constants/chat'; const useInitialiseNewChat = () => { const setChats = useStore((state) => state.setChats); @@ -16,6 +16,7 @@ const useInitialiseNewChat = () => { { title: 'New Chat', messages: [message], + config: { ...defaultChatConfig }, }, ]); setCurrentChatIndex(0); diff --git a/src/hooks/useSubmit.ts b/src/hooks/useSubmit.ts index e776748..52e84d1 100644 --- a/src/hooks/useSubmit.ts +++ b/src/hooks/useSubmit.ts @@ -33,12 +33,14 @@ const useSubmit = () => { try { if (apiFree) { stream = await getChatCompletionStreamFree( - chats[currentChatIndex].messages + chats[currentChatIndex].messages, + chats[currentChatIndex].config ); } else if (apiKey) { stream = await getChatCompletionStreamCustom( apiKey, - chats[currentChatIndex].messages + chats[currentChatIndex].messages, + chats[currentChatIndex].config ); } diff --git a/src/main.css b/src/main.css index a987cea..80a9b38 100644 --- a/src/main.css +++ b/src/main.css @@ -27,6 +27,14 @@ height: 100%; } + input[type='range']::-webkit-slider-thumb { + -webkit-appearance: none; + @apply w-4; + @apply h-4; + @apply rounded-full; + background: rgba(16, 163, 127); + } + ::-webkit-scrollbar { height: 1rem; width: 0.5rem; diff --git a/src/types/chat.ts b/src/types/chat.ts index 651020e..ee2b8b2 100644 --- a/src/types/chat.ts +++ b/src/types/chat.ts @@ -9,4 +9,10 @@ export interface MessageInterface { export interface ChatInterface { title: string; messages: MessageInterface[]; + config: ConfigInterface; +} + +export interface ConfigInterface { + temperature: number; + presence_penalty: number; }