diff --git a/src/api/api.ts b/src/api/api.ts index 5948796..153de0f 100644 --- a/src/api/api.ts +++ b/src/api/api.ts @@ -1,16 +1,20 @@ import { ShareGPTSubmitBodyInterface } from '@type/api'; import { ConfigInterface, MessageInterface } from '@type/chat'; +import { isAzureEndpoint } from '@utils/api'; export const getChatCompletion = async ( endpoint: string, messages: MessageInterface[], config: ConfigInterface, - apiKey?: string + apiKey?: string, + customHeaders?: Record ) => { const headers: HeadersInit = { 'Content-Type': 'application/json', + ...customHeaders, }; if (apiKey) headers.Authorization = `Bearer ${apiKey}`; + if (isAzureEndpoint(endpoint) && apiKey) headers['api-key'] = apiKey; const response = await fetch(endpoint, { method: 'POST', @@ -31,12 +35,15 @@ export const getChatCompletionStream = async ( endpoint: string, messages: MessageInterface[], config: ConfigInterface, - apiKey?: string + apiKey?: string, + customHeaders?: Record ) => { const headers: HeadersInit = { 'Content-Type': 'application/json', + ...customHeaders, }; if (apiKey) headers.Authorization = `Bearer ${apiKey}`; + if (isAzureEndpoint(endpoint) && apiKey) headers['api-key'] = apiKey; const response = await fetch(endpoint, { method: 'POST', diff --git a/src/utils/api.ts b/src/utils/api.ts new file mode 100644 index 0000000..d66c22a --- /dev/null +++ b/src/utils/api.ts @@ -0,0 +1,3 @@ +export const isAzureEndpoint = (endpoint: string) => { + return endpoint.includes('openai.azure.com'); +};