From a009be99869a471fb49c1d924195529ad938e07b Mon Sep 17 00:00:00 2001
From: Jing Hua <59118459+ztjhz@users.noreply.github.com>
Date: Mon, 6 Mar 2023 22:22:05 +0800
Subject: [PATCH] feat: model parameters customisation (#31)
issue #14
---
src/api/customApi.ts | 10 ++-
src/api/freeApi.ts | 14 +++-
src/components/Chat/ChatContent/ChatTitle.tsx | 76 +++++++++++++++--
.../ChatContent/Message/NewMessageButton.tsx | 10 ++-
src/components/ConfigMenu/ConfigMenu.tsx | 83 +++++++++++++++++++
src/components/ConfigMenu/index.ts | 1 +
src/components/PopupModal/PopupModal.tsx | 2 +-
src/constants/chat.ts | 7 ++
src/hooks/useAddChat.ts | 3 +-
src/hooks/useInitialiseNewChat.ts | 3 +-
src/hooks/useSubmit.ts | 6 +-
src/main.css | 8 ++
src/types/chat.ts | 6 ++
13 files changed, 209 insertions(+), 20 deletions(-)
create mode 100644 src/components/ConfigMenu/ConfigMenu.tsx
create mode 100644 src/components/ConfigMenu/index.ts
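
Both customApi.ts and freeApi.ts now spread the chat's ConfigInterface straight
into the JSON body sent to the completion endpoint. A minimal sketch of the
resulting payload shape (buildPayload is just an illustrative name for this
note, not code added by the patch):

    import { ConfigInterface, MessageInterface } from '@type/chat';

    // Illustrative helper: shows how the spread merges temperature and
    // presence_penalty into the completion request body.
    const buildPayload = (
      messages: MessageInterface[],
      config: ConfigInterface,
      stream = false
    ) => ({
      model: 'gpt-3.5-turbo',
      messages,
      ...config,
      ...(stream ? { stream: true } : {}),
    });

    // buildPayload(msgs, { temperature: 1, presence_penalty: 0 }) yields
    // { model: 'gpt-3.5-turbo', messages: msgs, temperature: 1, presence_penalty: 0 }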
diff --git a/src/api/customApi.ts b/src/api/customApi.ts
index 98bb65f..96a6cde 100644
--- a/src/api/customApi.ts
+++ b/src/api/customApi.ts
@@ -1,4 +1,4 @@
-import { MessageInterface } from '@type/chat';
+import { ConfigInterface, MessageInterface } from '@type/chat';
export const endpoint = 'https://api.openai.com/v1/chat/completions';
@@ -23,7 +23,8 @@ export const validateApiKey = async (apiKey: string) => {
export const getChatCompletion = async (
apiKey: string,
- messages: MessageInterface[]
+ messages: MessageInterface[],
+ config: ConfigInterface
) => {
const response = await fetch(endpoint, {
method: 'POST',
@@ -34,6 +35,7 @@ export const getChatCompletion = async (
body: JSON.stringify({
model: 'gpt-3.5-turbo',
messages,
+ ...config,
}),
});
if (!response.ok) throw new Error(await response.text());
@@ -44,7 +46,8 @@ export const getChatCompletion = async (
export const getChatCompletionStream = async (
apiKey: string,
- messages: MessageInterface[]
+ messages: MessageInterface[],
+ config: ConfigInterface
) => {
const response = await fetch(endpoint, {
method: 'POST',
@@ -55,6 +58,7 @@ export const getChatCompletionStream = async (
body: JSON.stringify({
model: 'gpt-3.5-turbo',
messages,
+ ...config,
stream: true,
}),
});
diff --git a/src/api/freeApi.ts b/src/api/freeApi.ts
index 2835f79..67dbf15 100644
--- a/src/api/freeApi.ts
+++ b/src/api/freeApi.ts
@@ -1,8 +1,11 @@
-import { MessageInterface } from '@type/chat';
+import { ConfigInterface, MessageInterface } from '@type/chat';
export const endpoint = 'https://chatgpt-api.shn.hk/v1/';
-export const getChatCompletion = async (messages: MessageInterface[]) => {
+export const getChatCompletion = async (
+ messages: MessageInterface[],
+ config: ConfigInterface
+) => {
const response = await fetch(endpoint, {
method: 'POST',
headers: {
@@ -11,6 +14,7 @@ export const getChatCompletion = async (messages: MessageInterface[]) => {
body: JSON.stringify({
model: 'gpt-3.5-turbo',
messages,
+ ...config,
}),
});
if (!response.ok) throw new Error(await response.text());
@@ -19,7 +23,10 @@ export const getChatCompletion = async (messages: MessageInterface[]) => {
return data;
};
-export const getChatCompletionStream = async (messages: MessageInterface[]) => {
+export const getChatCompletionStream = async (
+ messages: MessageInterface[],
+ config: ConfigInterface
+) => {
const response = await fetch(endpoint, {
method: 'POST',
headers: {
@@ -28,6 +35,7 @@ export const getChatCompletionStream = async (messages: MessageInterface[]) => {
body: JSON.stringify({
model: 'gpt-3.5-turbo',
messages,
+ ...config,
stream: true,
}),
});
diff --git a/src/components/Chat/ChatContent/ChatTitle.tsx b/src/components/Chat/ChatContent/ChatTitle.tsx
index 7acab42..b8f3691 100644
--- a/src/components/Chat/ChatContent/ChatTitle.tsx
+++ b/src/components/Chat/ChatContent/ChatTitle.tsx
@@ -1,11 +1,73 @@
-import React from 'react';
+import React, { useEffect, useState } from 'react';
+import { shallow } from 'zustand/shallow';
+import useStore from '@store/store';
+import ConfigMenu from '@components/ConfigMenu';
+import { ChatInterface, ConfigInterface } from '@type/chat';
+import { defaultChatConfig } from '@constants/chat';
-const ChatTitle = () => {
-  return (
-    <div>
-      Model: Default
-    </div>
+const ChatTitle = React.memo(() => {
+ const config = useStore(
+ (state) =>
+ state.chats &&
+ state.chats.length > 0 &&
+ state.currentChatIndex >= 0 &&
+ state.currentChatIndex < state.chats.length
+ ? state.chats[state.currentChatIndex].config
+ : undefined,
+ shallow
);
-};
+ const setChats = useStore((state) => state.setChats);
+ const currentChatIndex = useStore((state) => state.currentChatIndex);
+ const [isModalOpen, setIsModalOpen] = useState(false);
+
+ const setConfig = (config: ConfigInterface) => {
+ const updatedChats: ChatInterface[] = JSON.parse(
+ JSON.stringify(useStore.getState().chats)
+ );
+ updatedChats[currentChatIndex].config = config;
+ setChats(updatedChats);
+ };
+
+ // for migrating from old ChatInterface to new ChatInterface (with config)
+ useEffect(() => {
+ if (!config) {
+ const updatedChats: ChatInterface[] = JSON.parse(
+ JSON.stringify(useStore.getState().chats)
+ );
+ updatedChats[currentChatIndex].config = { ...defaultChatConfig };
+ setChats(updatedChats);
+ }
+ }, [currentChatIndex]);
+
+ return config ? (
+ <>
+      <div
+        onClick={() => {
+          setIsModalOpen(true);
+        }}
+      >
+        <div>
+          Model: Default
+        </div>
+        <div>
+          Temperature: {config.temperature}
+        </div>
+        <div>
+          PresencePenalty: {config.presence_penalty}
+        </div>
+      </div>
+      {isModalOpen && (
+        <ConfigMenu
+          setIsModalOpen={setIsModalOpen}
+          config={config}
+          setConfig={setConfig}
+        />
+      )}
+    </>
+  ) : (
+    <></>
+ );
+});
export default ChatTitle;
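
ChatTitle's setConfig (and the migration effect) deep-clone the chats array
with JSON.parse(JSON.stringify(...)) before mutating, so the zustand store
always receives a fresh reference. A condensed sketch of that update pattern
(updateCurrentChatConfig is an illustrative name, not code in this patch):

    import useStore from '@store/store';
    import { ChatInterface, ConfigInterface } from '@type/chat';

    // Illustrative sketch of the clone-then-replace update used by setConfig.
    const updateCurrentChatConfig = (config: ConfigInterface) => {
      const { chats, currentChatIndex, setChats } = useStore.getState();
      if (!chats) return;
      const updatedChats: ChatInterface[] = JSON.parse(JSON.stringify(chats));
      updatedChats[currentChatIndex].config = config;
      setChats(updatedChats);
    };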
diff --git a/src/components/Chat/ChatContent/Message/NewMessageButton.tsx b/src/components/Chat/ChatContent/Message/NewMessageButton.tsx
index ef36c5a..9df6d28 100644
--- a/src/components/Chat/ChatContent/Message/NewMessageButton.tsx
+++ b/src/components/Chat/ChatContent/Message/NewMessageButton.tsx
@@ -4,7 +4,7 @@ import useStore from '@store/store';
import PlusIcon from '@icon/PlusIcon';
import { ChatInterface } from '@type/chat';
-import { defaultSystemMessage } from '@constants/chat';
+import { defaultChatConfig, defaultSystemMessage } from '@constants/chat';
const NewMessageButton = React.memo(
({ messageIndex }: { messageIndex: number }) => {
@@ -26,7 +26,13 @@ const NewMessageButton = React.memo(
updatedChats.unshift({
title,
- messages: [{ role: 'system', content: defaultSystemMessage }],
+ messages: [
+ {
+ role: 'system',
+ content: defaultSystemMessage,
+ },
+ ],
+ config: { ...defaultChatConfig },
});
setChats(updatedChats);
setCurrentChatIndex(0);
diff --git a/src/components/ConfigMenu/ConfigMenu.tsx b/src/components/ConfigMenu/ConfigMenu.tsx
new file mode 100644
index 0000000..f1e2628
--- /dev/null
+++ b/src/components/ConfigMenu/ConfigMenu.tsx
@@ -0,0 +1,83 @@
+import React, { useState } from 'react';
+import PopupModal from '@components/PopupModal';
+import { ConfigInterface } from '@type/chat';
+
+const ConfigMenu = ({
+ setIsModalOpen,
+ config,
+ setConfig,
+}: {
+  setIsModalOpen: React.Dispatch<React.SetStateAction<boolean>>;
+ config: ConfigInterface;
+ setConfig: (config: ConfigInterface) => void;
+}) => {
+ const [_temperature, _setTemperature] = useState(config.temperature);
+ const [_presencePenalty, _setPresencePenalty] = useState(
+ config.presence_penalty
+ );
+
+ const handleConfirm = () => {
+ setConfig({
+ temperature: _temperature,
+ presence_penalty: _presencePenalty,
+ });
+ setIsModalOpen(false);
+ };
+
+ return (
+    <PopupModal
+      title='Configuration'
+      setIsModalOpen={setIsModalOpen}
+      handleConfirm={handleConfirm}
+    >
+      <div>
+        <div>
+          <label>
+            Temperature: {_temperature}
+          </label>
+          <input
+            type='range'
+            value={_temperature}
+            onChange={(e) => {
+              _setTemperature(Number(e.target.value));
+            }}
+            min={0}
+            max={2}
+            step={0.1}
+            className='w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer'
+          />
+          <div>
+            What sampling temperature to use, between 0 and 2. Higher values
+            like 0.8 will make the output more random, while lower values like
+            0.2 will make it more focused and deterministic.
+          </div>
+        </div>
+        <div>
+          <label>
+            Presence Penalty: {_presencePenalty}
+          </label>
+          <input
+            type='range'
+            value={_presencePenalty}
+            onChange={(e) => {
+              _setPresencePenalty(Number(e.target.value));
+            }}
+            min={-2}
+            max={2}
+            step={0.1}
+            className='w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer'
+          />
+          <div>
+            Number between -2.0 and 2.0. Positive values penalize new tokens
+            based on whether they appear in the text so far, increasing the
+            model's likelihood to talk about new topics.
+          </div>
+        </div>
+      </div>
+    </PopupModal>
+ );
+};
+
+export default ConfigMenu;
diff --git a/src/components/ConfigMenu/index.ts b/src/components/ConfigMenu/index.ts
new file mode 100644
index 0000000..6497639
--- /dev/null
+++ b/src/components/ConfigMenu/index.ts
@@ -0,0 +1 @@
+export { default } from './ConfigMenu';
diff --git a/src/components/PopupModal/PopupModal.tsx b/src/components/PopupModal/PopupModal.tsx
index b806444..2d412da 100644
--- a/src/components/PopupModal/PopupModal.tsx
+++ b/src/components/PopupModal/PopupModal.tsx
@@ -31,7 +31,7 @@ const PopupModal = ({
return ReactDOM.createPortal(
-
+
{title}
diff --git a/src/constants/chat.ts b/src/constants/chat.ts
index 82f6570..2ea2df8 100644
--- a/src/constants/chat.ts
+++ b/src/constants/chat.ts
@@ -1,3 +1,5 @@
+import { ConfigInterface } from '@type/chat';
+
const date = new Date();
const dateString =
date.getFullYear() +
@@ -10,3 +12,8 @@ const dateString =
export const defaultSystemMessage = `You are ChatGPT, a large language model trained by OpenAI.
Knowledge cutoff: 2021-09
Current date: ${dateString}`;
+
+export const defaultChatConfig: ConfigInterface = {
+ temperature: 1,
+ presence_penalty: 0,
+};
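
defaultChatConfig is cloned into every new chat (useAddChat,
useInitialiseNewChat and NewMessageButton below) and is also what ChatTitle
backfills into chats persisted before this change. A rough sketch of that
backfill (LegacyChat and migrateChat are hypothetical names, not code in this
patch):

    import { ChatInterface, ConfigInterface } from '@type/chat';
    import { defaultChatConfig } from '@constants/chat';

    // Hypothetical helper: give an older persisted chat (no config field yet)
    // the default config, mirroring the effect added in ChatTitle.
    type LegacyChat = Omit<ChatInterface, 'config'> & { config?: ConfigInterface };

    const migrateChat = (chat: LegacyChat): ChatInterface => ({
      ...chat,
      config: chat.config ?? { ...defaultChatConfig },
    });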
diff --git a/src/hooks/useAddChat.ts b/src/hooks/useAddChat.ts
index f1fe6bf..95045e7 100644
--- a/src/hooks/useAddChat.ts
+++ b/src/hooks/useAddChat.ts
@@ -1,6 +1,6 @@
import React from 'react';
import useStore from '@store/store';
-import { defaultSystemMessage } from '@constants/chat';
+import { defaultChatConfig, defaultSystemMessage } from '@constants/chat';
import { ChatInterface } from '@type/chat';
const useAddChat = () => {
@@ -22,6 +22,7 @@ const useAddChat = () => {
updatedChats.unshift({
title,
messages: [{ role: 'system', content: defaultSystemMessage }],
+ config: { ...defaultChatConfig },
});
setChats(updatedChats);
setCurrentChatIndex(0);
diff --git a/src/hooks/useInitialiseNewChat.ts b/src/hooks/useInitialiseNewChat.ts
index 25fe721..2f1a678 100644
--- a/src/hooks/useInitialiseNewChat.ts
+++ b/src/hooks/useInitialiseNewChat.ts
@@ -1,7 +1,7 @@
import React from 'react';
import useStore from '@store/store';
import { MessageInterface } from '@type/chat';
-import { defaultSystemMessage } from '@constants/chat';
+import { defaultChatConfig, defaultSystemMessage } from '@constants/chat';
const useInitialiseNewChat = () => {
const setChats = useStore((state) => state.setChats);
@@ -16,6 +16,7 @@ const useInitialiseNewChat = () => {
{
title: 'New Chat',
messages: [message],
+ config: { ...defaultChatConfig },
},
]);
setCurrentChatIndex(0);
diff --git a/src/hooks/useSubmit.ts b/src/hooks/useSubmit.ts
index e776748..52e84d1 100644
--- a/src/hooks/useSubmit.ts
+++ b/src/hooks/useSubmit.ts
@@ -33,12 +33,14 @@ const useSubmit = () => {
try {
if (apiFree) {
stream = await getChatCompletionStreamFree(
- chats[currentChatIndex].messages
+ chats[currentChatIndex].messages,
+ chats[currentChatIndex].config
);
} else if (apiKey) {
stream = await getChatCompletionStreamCustom(
apiKey,
- chats[currentChatIndex].messages
+ chats[currentChatIndex].messages,
+ chats[currentChatIndex].config
);
}
diff --git a/src/main.css b/src/main.css
index a987cea..80a9b38 100644
--- a/src/main.css
+++ b/src/main.css
@@ -27,6 +27,14 @@
height: 100%;
}
+ input[type='range']::-webkit-slider-thumb {
+ -webkit-appearance: none;
+ @apply w-4;
+ @apply h-4;
+ @apply rounded-full;
+ background: rgba(16, 163, 127);
+ }
+
::-webkit-scrollbar {
height: 1rem;
width: 0.5rem;
diff --git a/src/types/chat.ts b/src/types/chat.ts
index 651020e..ee2b8b2 100644
--- a/src/types/chat.ts
+++ b/src/types/chat.ts
@@ -9,4 +9,10 @@ export interface MessageInterface {
export interface ChatInterface {
title: string;
messages: MessageInterface[];
+ config: ConfigInterface;
+}
+
+export interface ConfigInterface {
+ temperature: number;
+ presence_penalty: number;
}
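
With config now required on ChatInterface, a freshly created chat looks
roughly like this (illustrative values only, not code added by the patch):

    import { ChatInterface } from '@type/chat';
    import { defaultChatConfig, defaultSystemMessage } from '@constants/chat';

    const chat: ChatInterface = {
      title: 'New Chat',
      messages: [{ role: 'system', content: defaultSystemMessage }],
      config: { ...defaultChatConfig }, // { temperature: 1, presence_penalty: 0 }
    };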