feat: model parameters customisation (#31)

issue #14
Jing Hua, 2023-03-06 22:22:05 +08:00 (committed by GitHub)
parent 670c26774a
commit a009be9986
GPG key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
13 changed files with 209 additions and 20 deletions

View file

@@ -1,4 +1,4 @@
-import { MessageInterface } from '@type/chat';
+import { ConfigInterface, MessageInterface } from '@type/chat';

 export const endpoint = 'https://api.openai.com/v1/chat/completions';

@@ -23,7 +23,8 @@ export const validateApiKey = async (apiKey: string) => {
 export const getChatCompletion = async (
   apiKey: string,
-  messages: MessageInterface[]
+  messages: MessageInterface[],
+  config: ConfigInterface
 ) => {
   const response = await fetch(endpoint, {
     method: 'POST',
@@ -34,6 +35,7 @@ export const getChatCompletion = async (
     body: JSON.stringify({
       model: 'gpt-3.5-turbo',
       messages,
+      ...config,
     }),
   });
   if (!response.ok) throw new Error(await response.text());
@@ -44,7 +46,8 @@ export const getChatCompletion = async (
 export const getChatCompletionStream = async (
   apiKey: string,
-  messages: MessageInterface[]
+  messages: MessageInterface[],
+  config: ConfigInterface
 ) => {
   const response = await fetch(endpoint, {
     method: 'POST',
@@ -55,6 +58,7 @@ export const getChatCompletionStream = async (
     body: JSON.stringify({
       model: 'gpt-3.5-turbo',
       messages,
+      ...config,
       stream: true,
     }),
   });
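
Both helpers now spread the per-chat `config` into the request body, so the chosen sampling parameters travel with every completion request. A minimal sketch of the serialised body, assuming a hypothetical config of `{ temperature: 0.7, presence_penalty: 0.5 }`:

const config = { temperature: 0.7, presence_penalty: 0.5 };
const body = JSON.stringify({
  model: 'gpt-3.5-turbo',
  messages: [{ role: 'user', content: 'Hello!' }],
  ...config,
});
// {"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"Hello!"}],
//  "temperature":0.7,"presence_penalty":0.5}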

View file

@@ -1,8 +1,11 @@
-import { MessageInterface } from '@type/chat';
+import { ConfigInterface, MessageInterface } from '@type/chat';

 export const endpoint = 'https://chatgpt-api.shn.hk/v1/';

-export const getChatCompletion = async (messages: MessageInterface[]) => {
+export const getChatCompletion = async (
+  messages: MessageInterface[],
+  config: ConfigInterface
+) => {
   const response = await fetch(endpoint, {
     method: 'POST',
     headers: {
@@ -11,6 +14,7 @@ export const getChatCompletion = async (messages: MessageInterface[]) => {
     body: JSON.stringify({
       model: 'gpt-3.5-turbo',
       messages,
+      ...config,
     }),
   });
   if (!response.ok) throw new Error(await response.text());
@@ -19,7 +23,10 @@ export const getChatCompletion = async (messages: MessageInterface[]) => {
   return data;
 };

-export const getChatCompletionStream = async (messages: MessageInterface[]) => {
+export const getChatCompletionStream = async (
+  messages: MessageInterface[],
+  config: ConfigInterface
+) => {
   const response = await fetch(endpoint, {
     method: 'POST',
     headers: {
@@ -28,6 +35,7 @@ export const getChatCompletionStream = async (messages: MessageInterface[]) => {
     body: JSON.stringify({
       model: 'gpt-3.5-turbo',
       messages,
+      ...config,
       stream: true,
     }),
   });
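
One ordering detail: in both streaming helpers, `...config` is spread before `stream: true`, and later keys in an object literal override earlier spreads, so a stray `stream` key inside a saved config could never switch streaming off. A tiny sketch of that rule (values hypothetical):

const config = { temperature: 0.7, stream: false };
const payload = { model: 'gpt-3.5-turbo', ...config, stream: true };
console.log(payload.stream); // true: the literal's own later key wins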

View file

@@ -1,11 +1,73 @@
-import React from 'react';
-
-const ChatTitle = () => {
-  return (
-    <div className='flex w-full items-center justify-center gap-1 border-b border-black/10 bg-gray-50 p-3 text-gray-500 dark:border-gray-900/50 dark:bg-gray-700 dark:text-gray-300'>
-      Model: Default
-    </div>
-  );
-};
+import React, { useEffect, useState } from 'react';
+import { shallow } from 'zustand/shallow';
+import useStore from '@store/store';
+import ConfigMenu from '@components/ConfigMenu';
+import { ChatInterface, ConfigInterface } from '@type/chat';
+import { defaultChatConfig } from '@constants/chat';
+
+const ChatTitle = React.memo(() => {
+  const config = useStore(
+    (state) =>
+      state.chats &&
+      state.chats.length > 0 &&
+      state.currentChatIndex >= 0 &&
+      state.currentChatIndex < state.chats.length
+        ? state.chats[state.currentChatIndex].config
+        : undefined,
+    shallow
+  );
+  const setChats = useStore((state) => state.setChats);
+  const currentChatIndex = useStore((state) => state.currentChatIndex);
+  const [isModalOpen, setIsModalOpen] = useState<boolean>(false);
+
+  const setConfig = (config: ConfigInterface) => {
+    const updatedChats: ChatInterface[] = JSON.parse(
+      JSON.stringify(useStore.getState().chats)
+    );
+    updatedChats[currentChatIndex].config = config;
+    setChats(updatedChats);
+  };
+
+  // for migrating from old ChatInterface to new ChatInterface (with config)
+  useEffect(() => {
+    if (!config) {
+      const updatedChats: ChatInterface[] = JSON.parse(
+        JSON.stringify(useStore.getState().chats)
+      );
+      updatedChats[currentChatIndex].config = { ...defaultChatConfig };
+      setChats(updatedChats);
+    }
+  }, [currentChatIndex]);
+
+  return config ? (
+    <>
+      <div
+        className='flex gap-4 flex-wrap w-full items-center justify-center gap-1 border-b border-black/10 bg-gray-50 p-3 text-gray-500 dark:border-gray-900/50 dark:bg-gray-700 dark:text-gray-300 cursor-pointer'
+        onClick={() => {
+          setIsModalOpen(true);
+        }}
+      >
+        <div className='text-center p-1 rounded-md bg-gray-900/10 hover:bg-gray-900/50'>
+          Model: Default
+        </div>
+        <div className='text-center p-1 rounded-md bg-gray-900/10 hover:bg-gray-900/50'>
+          Temperature: {config.temperature}
+        </div>
+        <div className='text-center p-1 rounded-md bg-gray-900/10 hover:bg-gray-900/50'>
+          PresencePenalty: {config.presence_penalty}
+        </div>
+      </div>
+      {isModalOpen && (
+        <ConfigMenu
+          setIsModalOpen={setIsModalOpen}
+          config={config}
+          setConfig={setConfig}
+        />
+      )}
+    </>
+  ) : (
+    <></>
+  );
+});

 export default ChatTitle;
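
The `useEffect` above is the migration path for chats persisted before this commit: if the selected chat has no `config`, it clones the chat list with `JSON.parse(JSON.stringify(...))` and backfills the defaults. The same step can be pictured as a standalone function; `migrateChat` and `LegacyChat` are hypothetical names for illustration, not part of the commit:

import { ChatInterface, ConfigInterface } from '@type/chat';
import { defaultChatConfig } from '@constants/chat';

// Hypothetical sketch: a chat saved before this commit lacks `config`.
type LegacyChat = Omit<ChatInterface, 'config'> & { config?: ConfigInterface };

const migrateChat = (chat: LegacyChat): ChatInterface => ({
  ...chat,
  config: chat.config ?? { ...defaultChatConfig },
});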

View file

@@ -4,7 +4,7 @@ import useStore from '@store/store';
 import PlusIcon from '@icon/PlusIcon';
 import { ChatInterface } from '@type/chat';
-import { defaultSystemMessage } from '@constants/chat';
+import { defaultChatConfig, defaultSystemMessage } from '@constants/chat';

 const NewMessageButton = React.memo(
   ({ messageIndex }: { messageIndex: number }) => {
@@ -26,7 +26,13 @@ const NewMessageButton = React.memo(
       updatedChats.unshift({
         title,
-        messages: [{ role: 'system', content: defaultSystemMessage }],
+        messages: [
+          {
+            role: 'system',
+            content: defaultSystemMessage,
+          },
+        ],
+        config: { ...defaultChatConfig },
       });
       setChats(updatedChats);
       setCurrentChatIndex(0);

View file

@@ -0,0 +1,83 @@
+import React, { useState } from 'react';
+import PopupModal from '@components/PopupModal';
+import { ConfigInterface } from '@type/chat';
+
+const ConfigMenu = ({
+  setIsModalOpen,
+  config,
+  setConfig,
+}: {
+  setIsModalOpen: React.Dispatch<React.SetStateAction<boolean>>;
+  config: ConfigInterface;
+  setConfig: (config: ConfigInterface) => void;
+}) => {
+  const [_temperature, _setTemperature] = useState<number>(config.temperature);
+  const [_presencePenalty, _setPresencePenalty] = useState<number>(
+    config.presence_penalty
+  );
+
+  const handleConfirm = () => {
+    setConfig({
+      temperature: _temperature,
+      presence_penalty: _presencePenalty,
+    });
+    setIsModalOpen(false);
+  };
+
+  return (
+    <PopupModal
+      title='Configuration'
+      setIsModalOpen={setIsModalOpen}
+      handleConfirm={handleConfirm}
+    >
+      <div className='p-6 border-b border-gray-200 dark:border-gray-600'>
+        <div>
+          <label className='block text-sm font-medium text-gray-900 dark:text-white'>
+            Temperature: {_temperature}
+          </label>
+          <input
+            id='default-range'
+            type='range'
+            value={_temperature}
+            onChange={(e) => {
+              _setTemperature(Number(e.target.value));
+            }}
+            min={0}
+            max={2}
+            step={0.1}
+            className='w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer'
+          />
+          <div className='min-w-fit text-gray-500 dark:text-gray-300 text-sm mt-2'>
+            What sampling temperature to use, between 0 and 2. Higher values
+            like 0.8 will make the output more random, while lower values like
+            0.2 will make it more focused and deterministic.
+          </div>
+        </div>
+        <div className='mt-5 pt-5 border-t border-gray-500'>
+          <label className='block text-sm font-medium text-gray-900 dark:text-white'>
+            Presence Penalty: {_presencePenalty}
+          </label>
+          <input
+            id='default-range'
+            type='range'
+            value={_presencePenalty}
+            onChange={(e) => {
+              _setPresencePenalty(Number(e.target.value));
+            }}
+            min={-2}
+            max={2}
+            step={0.1}
+            className='w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer'
+          />
+          <div className='min-w-fit text-gray-500 dark:text-gray-300 text-sm mt-2'>
+            Number between -2.0 and 2.0. Positive values penalize new tokens
+            based on whether they appear in the text so far, increasing the
+            model's likelihood to talk about new topics.
+          </div>
+        </div>
+      </div>
+    </PopupModal>
+  );
+};
+
+export default ConfigMenu;
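
The menu edits draft copies (`_temperature`, `_presencePenalty`) held in local component state and only writes them back through `setConfig` inside `handleConfirm`, so dismissing the modal without confirming discards the changes. The `Number(e.target.value)` calls matter because a range input always reports its value as a string. A reduced sketch of that commit-on-confirm flow (plain TypeScript, names illustrative):

// Draft values live apart from the saved config until confirmed.
let saved = { temperature: 1, presence_penalty: 0 };
let draft = { ...saved };
draft.temperature = Number('0.3'); // range inputs yield strings such as '0.3'
// Cancel: discard `draft`. Confirm (handleConfirm): write it back.
saved = { ...draft };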

View file

@@ -0,0 +1 @@
+export { default } from './ConfigMenu';
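
This one-line barrel file is what lets other components import by folder alias, as ChatTitle does above:

import ConfigMenu from '@components/ConfigMenu'; // resolves through index.ts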

View file

@@ -31,7 +31,7 @@ const PopupModal = ({
   return ReactDOM.createPortal(
     <div className='fixed top-0 left-0 z-[999] w-full p-4 overflow-x-hidden overflow-y-auto h-full flex justify-center items-center'>
       <div className='relative z-2 max-w-2xl md:h-auto flex justify-center items-center'>
-        <div className='relative bg-white rounded-lg shadow dark:bg-gray-700'>
+        <div className='relative bg-gray-50 rounded-lg shadow dark:bg-gray-700'>
           <div className='flex items-center justify-between p-4 border-b rounded-t dark:border-gray-600'>
             <h3 className='ml-2 text-lg font-semibold text-gray-900 dark:text-white'>
               {title}

View file

@@ -1,3 +1,5 @@
+import { ConfigInterface } from '@type/chat';
+
 const date = new Date();
 const dateString =
   date.getFullYear() +
@@ -10,3 +12,8 @@ const dateString =
 export const defaultSystemMessage = `You are ChatGPT, a large language model trained by OpenAI.
 Knowledge cutoff: 2021-09
 Current date: ${dateString}`;
+
+export const defaultChatConfig: ConfigInterface = {
+  temperature: 1,
+  presence_penalty: 0,
+};
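
`temperature: 1` and `presence_penalty: 0` match the OpenAI API's documented defaults, so a fresh chat behaves exactly as it did before this commit until the sliders are moved. Call sites also spread the object (`{ ...defaultChatConfig }`) instead of sharing it, which keeps each chat's config independent; a quick illustration:

import { defaultChatConfig } from '@constants/chat';

const a = { ...defaultChatConfig };
const b = { ...defaultChatConfig };
a.temperature = 0.2;
console.log(b.temperature); // still 1: each chat owns its own copy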

View file

@@ -1,6 +1,6 @@
 import React from 'react';
 import useStore from '@store/store';
-import { defaultSystemMessage } from '@constants/chat';
+import { defaultChatConfig, defaultSystemMessage } from '@constants/chat';
 import { ChatInterface } from '@type/chat';

 const useAddChat = () => {
@@ -22,6 +22,7 @@ const useAddChat = () => {
       updatedChats.unshift({
         title,
         messages: [{ role: 'system', content: defaultSystemMessage }],
+        config: { ...defaultChatConfig },
       });
       setChats(updatedChats);
       setCurrentChatIndex(0);

View file

@@ -1,7 +1,7 @@
 import React from 'react';
 import useStore from '@store/store';
 import { MessageInterface } from '@type/chat';
-import { defaultSystemMessage } from '@constants/chat';
+import { defaultChatConfig, defaultSystemMessage } from '@constants/chat';

 const useInitialiseNewChat = () => {
   const setChats = useStore((state) => state.setChats);
@@ -16,6 +16,7 @@ const useInitialiseNewChat = () => {
       {
         title: 'New Chat',
         messages: [message],
+        config: { ...defaultChatConfig },
       },
     ]);
     setCurrentChatIndex(0);

View file

@@ -33,12 +33,14 @@ const useSubmit = () => {
     try {
       if (apiFree) {
         stream = await getChatCompletionStreamFree(
-          chats[currentChatIndex].messages
+          chats[currentChatIndex].messages,
+          chats[currentChatIndex].config
         );
       } else if (apiKey) {
         stream = await getChatCompletionStreamCustom(
           apiKey,
-          chats[currentChatIndex].messages
+          chats[currentChatIndex].messages,
+          chats[currentChatIndex].config
         );
       }
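
With both API layers updated, the active chat's parameters now travel with its messages in either mode. The dispatch is equivalent to this destructured paraphrase (not a copy of the hook's code, and assuming one of the two modes is active):

const { messages, config } = chats[currentChatIndex];
stream = apiFree
  ? await getChatCompletionStreamFree(messages, config)
  : await getChatCompletionStreamCustom(apiKey, messages, config);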

View file

@@ -27,6 +27,14 @@
   height: 100%;
 }

+input[type='range']::-webkit-slider-thumb {
+  -webkit-appearance: none;
+  @apply w-4;
+  @apply h-4;
+  @apply rounded-full;
+  background: rgba(16, 163, 127);
+}
+
 ::-webkit-scrollbar {
   height: 1rem;
   width: 0.5rem;

View file

@@ -9,4 +9,10 @@ export interface MessageInterface {
 export interface ChatInterface {
   title: string;
   messages: MessageInterface[];
+  config: ConfigInterface;
+}
+
+export interface ConfigInterface {
+  temperature: number;
+  presence_penalty: number;
 }
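
Taken together, the type changes mean every chat record now carries its own sampling parameters. An example object conforming to the updated interfaces (values illustrative):

import { ChatInterface } from '@type/chat';

const chat: ChatInterface = {
  title: 'New Chat',
  messages: [{ role: 'system', content: 'You are ChatGPT...' }],
  config: { temperature: 1, presence_penalty: 0 },
};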