feat: more request parameters (#62)

* feat: more request parameters

Added top_p and frequency_penalty request parameters.

* fix: value range of top_p

The valid top_p range is 0 to 1.

* migration + change chats validation

* style chat title

---------

Co-authored-by: Jing Hua <tohjinghua123@gmail.com>
akira0245 authored on 2023-03-13 13:01:24 +08:00, committed by GitHub
parent 9f1529d07a
commit 4c80898078
10 changed files with 136 additions and 23 deletions
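
The request-sending code itself is untouched by this diff, so the following is only a sketch, assuming a hypothetical buildRequestBody helper and the standard OpenAI chat-completions body shape, of where the two new ConfigInterface fields would travel alongside the existing ones:

import { ConfigInterface } from '@type/chat';

// Hypothetical helper (not part of this commit): maps the per-chat config
// onto a chat-completions request body.
const buildRequestBody = (
  messages: { role: string; content: string }[],
  config: ConfigInterface
) => ({
  model: 'gpt-3.5-turbo',
  messages,
  temperature: config.temperature,
  top_p: config.top_p, // new in this commit; the UI slider constrains it to 0-1
  presence_penalty: config.presence_penalty,
  frequency_penalty: config.frequency_penalty, // new in this commit; -2 to 2
});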

View file

@@ -4,10 +4,18 @@
"default": "Default",
"temperature": {
"label": "Temperature",
"description": "What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic (Default: 1)"
"description": "What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top p but not both. (Default: 1)"
},
"presencePenalty": {
"label": "Presence Penalty",
"description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. (Default: 0)"
},
"topP": {
"label": "Top P",
"description": "Number between 0 and 1. An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both. (Default: 1)"
},
"frequencyPenalty": {
"label": "Frequency Penalty",
"description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. (Default: 0)"
}
}
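
The descriptions above encode the usual guidance: tune temperature or top P, not both. A minimal sketch of a config that opts into nucleus sampling while leaving everything else at its default (field names taken from ConfigInterface in this commit):

import { ConfigInterface } from '@type/chat';

// Nucleus sampling only: narrow top_p, leave temperature neutral.
const focusedConfig: ConfigInterface = {
  temperature: 1, // default; top_p does the narrowing here
  top_p: 0.1, // only tokens in the top 10% probability mass are considered
  presence_penalty: 0,
  frequency_penalty: 0,
};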

View file

@@ -4,10 +4,18 @@
"default": "默认",
"temperature": {
"label": "采样温度",
"description": "使用何种采样温度,值在 0 到 2 之间。较高的数值如 0.8 会使输出更加随机,而较低的数值如 0.2 会使输出更加集中和确定。(默认: 1)"
"description": "使用何种采样温度,值在 0 到 2 之间。较高的数值如 0.8 会使输出更加随机,而较低的数值如 0.2 会使输出更加集中和确定。我们通常建议修改此参数或概率质量,但不要同时修改两者。(默认: 1)"
},
"presencePenalty": {
"label": "存在惩罚",
"description": "数值在 -2.0 到 2.0 之间。正值会根据新 token 是否已经出现在文本中来惩罚它们,增加模型谈论新话题的可能性。 (默认: 0)"
},
"topP": {
"label": "概率质量",
"description": "数值在 0 到 1 之间。采用核采样nucleus sampling的一种采样温度的替代方法模型考虑具有最高概率质量的 token 的结果。因此0.1 表示仅考虑占前 10% 概率质量的 token。我们通常建议修改此参数或采样温度但不要同时修改两者。(默认: 1)"
},
"frequencyPenalty": {
"label": "频率惩罚",
"description": "数值在 -2.0 到 2.0 之间。正值会根据新 token 在文本中的现有频率来惩罚它们,降低模型直接重复相同语句的可能性。(默认: 0)"
}
}

View file

@@ -43,7 +43,7 @@ const ChatTitle = React.memo(() => {
return config ? (
<>
<div
className='flex gap-4 flex-wrap w-full items-center justify-center gap-1 border-b border-black/10 bg-gray-50 p-3 text-gray-500 dark:border-gray-900/50 dark:bg-gray-700 dark:text-gray-300 cursor-pointer'
className='flex gap-x-4 gap-y-1 flex-wrap w-full items-center justify-center border-b border-black/10 bg-gray-50 p-3 text-gray-500 dark:border-gray-900/50 dark:bg-gray-700 dark:text-gray-300 cursor-pointer'
onClick={() => {
setIsModalOpen(true);
}}
@@ -54,9 +54,15 @@ const ChatTitle = React.memo(() => {
<div className='text-center p-1 rounded-md bg-gray-900/10 hover:bg-gray-900/50'>
{t('temperature.label')}: {config.temperature}
</div>
<div className='text-center p-1 rounded-md bg-gray-900/10 hover:bg-gray-900/50'>
{t('topP.label')}: {config.top_p}
</div>
<div className='text-center p-1 rounded-md bg-gray-900/10 hover:bg-gray-900/50'>
{t('presencePenalty.label')}: {config.presence_penalty}
</div>
<div className='text-center p-1 rounded-md bg-gray-900/10 hover:bg-gray-900/50'>
{t('frequencyPenalty.label')}: {config.frequency_penalty}
</div>
</div>
{isModalOpen && (
<ConfigMenu

View file

@@ -13,15 +13,17 @@ const ConfigMenu = ({
setConfig: (config: ConfigInterface) => void;
}) => {
const [_temperature, _setTemperature] = useState<number>(config.temperature);
const [_presencePenalty, _setPresencePenalty] = useState<number>(
config.presence_penalty
);
const [_presencePenalty, _setPresencePenalty] = useState<number>(config.presence_penalty);
const [_topP, _setTopP] = useState<number>(config.top_p);
const [_frequencyPenalty, _setFrequencyPenalty] = useState<number>(config.frequency_penalty);
const { t } = useTranslation('model');
const handleConfirm = () => {
setConfig({
temperature: _temperature,
presence_penalty: _presencePenalty,
top_p: _topP,
frequency_penalty: _frequencyPenalty
});
setIsModalOpen(false);
};
@@ -53,6 +55,26 @@ const ConfigMenu = ({
{t('temperature.description')}
</div>
</div>
<div className='mt-5 pt-5 border-t border-gray-500'>
<label className='block text-sm font-medium text-gray-900 dark:text-white'>
{t('topP.label')}: {_topP}
</label>
<input
id='default-range'
type='range'
value={_topP}
onChange={(e) => {
_setTopP(Number(e.target.value));
}}
min={0}
max={1}
step={0.05}
className='w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer'
/>
<div className='min-w-fit text-gray-500 dark:text-gray-300 text-sm mt-2'>
{t('topP.description')}
</div>
</div>
<div className='mt-5 pt-5 border-t border-gray-500'>
<label className='block text-sm font-medium text-gray-900 dark:text-white'>
{t('presencePenalty.label')}: {_presencePenalty}
@@ -73,6 +95,26 @@ const ConfigMenu = ({
{t('presencePenalty.description')}
</div>
</div>
<div className='mt-5 pt-5 border-t border-gray-500'>
<label className='block text-sm font-medium text-gray-900 dark:text-white'>
{t('frequencyPenalty.label')}: {_frequencyPenalty}
</label>
<input
id='default-range'
type='range'
value={_frequencyPenalty}
onChange={(e) => {
_setFrequencyPenalty(Number(e.target.value));
}}
min={-2}
max={2}
step={0.1}
className='w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer'
/>
<div className='min-w-fit text-gray-500 dark:text-gray-300 text-sm mt-2'>
{t('frequencyPenalty.description')}
</div>
</div>
</div>
</PopupModal>
);

View file

@@ -6,7 +6,7 @@ import ExportIcon from '@icon/ExportIcon';
import downloadFile from '@utils/downloadFile';
import { getToday } from '@utils/date';
import PopupModal from '@components/PopupModal';
import { isChats } from '@utils/chat';
import { validateAndFixChats } from '@utils/chat';
const ImportExportChat = () => {
const { t } = useTranslation();
@@ -60,7 +60,7 @@ const ImportChat = () => {
try {
const parsedData = JSON.parse(data);
if (isChats(parsedData)) {
if (validateAndFixChats(parsedData)) {
setChats(parsedData);
setAlert({ message: 'Successfully imported!', success: true });
} else {

View file

@@ -16,6 +16,8 @@ Current date: ${dateString}`;
export const defaultChatConfig: ConfigInterface = {
temperature: 1,
presence_penalty: 0,
top_p: 1,
frequency_penalty: 0
};
export const generateDefaultChat = (title?: string): ChatInterface => ({
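
A quick sketch of how these defaults can be reused when building a per-chat config that overrides a single knob (object spread; defaultChatConfig and ConfigInterface are the names used in this commit):

import { defaultChatConfig } from '@constants/chat';
import { ConfigInterface } from '@type/chat';

// Start from the defaults, override only temperature.
const creativeConfig: ConfigInterface = {
  ...defaultChatConfig,
  temperature: 1.2,
};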

View file

@@ -19,3 +19,10 @@ export const migrateV1 = (persistedState: LocalStorageInterfaceV1ToV2) => {
persistedState.apiEndpoint = officialAPIEndpoint;
}
};
export const migrateV2 = (persistedState: LocalStorageInterfaceV2ToV3) => {
persistedState.chats.forEach((chat) => {
chat.config.top_p = defaultChatConfig.top_p;
chat.config.frequency_penalty = defaultChatConfig.frequency_penalty;
});
};
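
migrateV2 relies on defaultChatConfig, which is assumed to be imported at the top of migrate.ts (outside this hunk). A self-contained sketch of the backfill it performs, with the persisted state trimmed to the relevant field:

import { defaultChatConfig } from '@constants/chat';
import { ChatInterface } from '@type/chat';

// A chat persisted before this commit: its config lacks the two new fields.
const chat = {
  config: { temperature: 0.7, presence_penalty: 0 },
} as unknown as ChatInterface;

// What migrateV2 does for every entry in persistedState.chats:
chat.config.top_p = defaultChatConfig.top_p; // 1
chat.config.frequency_penalty = defaultChatConfig.frequency_penalty; // 0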

View file

@@ -7,8 +7,9 @@ import { ConfigSlice, createConfigSlice } from './config-slice';
import {
LocalStorageInterfaceV0ToV1,
LocalStorageInterfaceV1ToV2,
LocalStorageInterfaceV2ToV3,
} from '@type/chat';
import { migrateV0, migrateV1 } from './migrate';
import { migrateV0, migrateV1, migrateV2 } from './migrate';
export type StoreState = ChatSlice & InputSlice & AuthSlice & ConfigSlice;
@@ -35,13 +36,15 @@ const useStore = create<StoreState>()(
apiEndpoint: state.apiEndpoint,
theme: state.theme,
}),
version: 2,
version: 3,
migrate: (persistedState, version) => {
switch (version) {
case 0:
migrateV0(persistedState as LocalStorageInterfaceV0ToV1);
case 1:
migrateV1(persistedState as LocalStorageInterfaceV1ToV2);
case 2:
migrateV2(persistedState as LocalStorageInterfaceV2ToV3);
break;
}
return persistedState as StoreState;
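
The absence of break between the cases above is deliberate: a state persisted at an older version falls through every later case, so the migrations run in order. A toy sketch with hypothetical fields to make the fall-through explicit:

// Illustration only; the state shape and fields here are made up.
type ToyState = { greeting?: string; top_p?: number };

const migrateToy = (state: ToyState, version: number): ToyState => {
  switch (version) {
    case 0:
      state.greeting = state.greeting ?? 'hello'; // v0 -> v1
    // falls through
    case 1:
      // v1 -> v2 changes would go here
    // falls through
    case 2:
      state.top_p = state.top_p ?? 1; // v2 -> v3, mirroring migrateV2
      break;
  }
  return state;
};

console.log(migrateToy({}, 0)); // { greeting: 'hello', top_p: 1 }
console.log(migrateToy({}, 2)); // { top_p: 1 }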

View file

@@ -18,6 +18,8 @@ export interface ChatInterface {
export interface ConfigInterface {
temperature: number;
presence_penalty: number;
top_p: number;
frequency_penalty: number;
}
export interface LocalStorageInterfaceV0ToV1 {
@@ -38,3 +40,13 @@ export interface LocalStorageInterfaceV1ToV2 {
apiEndpoint?: string;
theme: Theme;
}
export interface LocalStorageInterfaceV2ToV3 {
chats: ChatInterface[];
currentChatIndex: number;
apiKey: string;
apiFree: boolean;
apiFreeEndpoint: string;
apiEndpoint?: string;
theme: Theme;
}

View file

@@ -1,28 +1,53 @@
import html2canvas from 'html2canvas';
import useStore from '@store/store';
import jsPDF from 'jspdf';
import { ChatInterface } from '@type/chat';
import { ChatInterface, ConfigInterface, MessageInterface } from '@type/chat';
import { roles } from '@type/chat';
import { Theme } from '@type/theme';
import { defaultChatConfig } from '@constants/chat';
export const isChats = (chats: any): chats is ChatInterface[] => {
export const validateAndFixChats = (chats: any): chats is ChatInterface[] => {
if (!Array.isArray(chats)) return false;
for (const chat of chats) {
if (!(typeof chat.title === 'string') || chat.title === '') return false;
if (chat.titleSet === undefined) chat.titleSet = false;
if (!(typeof chat.titleSet === 'boolean')) return false;
if (!Array.isArray(chat.messages)) return false;
for (const message of chat.messages) {
if (!validateMessage(chat.messages)) return false;
if (!validateAndFixChatConfig(chat.config)) return false;
}
return true;
};
const validateMessage = (messages: MessageInterface[]) => {
if (!Array.isArray(messages)) return false;
for (const message of messages) {
if (!(typeof message.content === 'string')) return false;
if (!(typeof message.role === 'string')) return false;
if (!roles.includes(message.role)) return false;
}
return true;
};
if (!(typeof chat.config === 'object')) return false;
if (!(typeof chat.config.temperature === 'number')) return false;
if (!(typeof chat.config.presence_penalty === 'number')) return false;
}
const validateAndFixChatConfig = (config: ConfigInterface) => {
if (config === undefined) config = defaultChatConfig;
if (!(typeof config === 'object')) return false;
if (!config.temperature) config.temperature = defaultChatConfig.temperature;
if (!(typeof config.temperature === 'number')) return false;
if (!config.presence_penalty)
config.presence_penalty = defaultChatConfig.presence_penalty;
if (!(typeof config.presence_penalty === 'number')) return false;
if (!config.top_p) config.top_p = defaultChatConfig.top_p;
if (!(typeof config.top_p === 'number')) return false;
if (!config.frequency_penalty)
config.frequency_penalty = defaultChatConfig.frequency_penalty;
if (!(typeof config.frequency_penalty === 'number')) return false;
return true;
};
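
Paired with the ImportExportChat change above, the validator is now also a fixer: it back-fills titleSet and any missing config fields before imported chats are accepted. A short usage sketch (the JSON literal is illustrative):

import { validateAndFixChats } from '@utils/chat';

// A pre-top_p export: its config has no top_p or frequency_penalty.
const parsed = JSON.parse(
  '[{"title":"Trip notes","messages":[{"role":"user","content":"hi"}],"config":{"temperature":1,"presence_penalty":0}}]'
);

if (validateAndFixChats(parsed)) {
  // parsed is now typed as ChatInterface[], with top_p and frequency_penalty
  // filled from defaultChatConfig and titleSet defaulted to false.
  console.log(parsed[0].config.top_p); // 1
}

Two caveats are visible in the code above: the truthiness checks treat an explicit 0 (for example temperature: 0) as missing and replace it with the default, and assigning defaultChatConfig to the local config parameter when it is undefined does not attach the default config back onto the chat object.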