2023-08-12 17:49:31 +02:00
|
|
|
"""This module contains the streaming logic for the API."""
|
|
|
|
|
2023-08-04 03:30:56 +02:00
|
|
|
import os
|
2023-08-04 17:29:49 +02:00
|
|
|
import json
|
2023-10-08 00:28:13 +02:00
|
|
|
import ujson
|
2023-08-04 03:30:56 +02:00
|
|
|
import aiohttp
|
2023-10-04 23:24:55 +02:00
|
|
|
import asyncio
|
2023-08-04 03:30:56 +02:00
|
|
|
import starlette
|
|
|
|
|
2023-10-06 09:45:50 +02:00
|
|
|
from typing import Any, Coroutine, Set
|
2023-08-04 17:29:49 +02:00
|
|
|
from rich import print
|
2023-08-04 03:30:56 +02:00
|
|
|
from dotenv import load_dotenv
|
|
|
|
|
|
|
|
import proxies
|
2023-08-28 00:58:32 +02:00
|
|
|
import after_request
|
2023-08-04 03:30:56 +02:00
|
|
|
import load_balancing
|
|
|
|
|
2023-10-02 21:09:39 +02:00
|
|
|
from helpers import errors
|
2023-10-04 23:24:55 +02:00
|
|
|
from db import providerkeys
|
2023-08-04 03:30:56 +02:00
|
|
|
|
|
|
|
# Load environment variables through python-dotenv (respond() later reads
# TRANSFER_TIMEOUT with os.getenv).
load_dotenv()

# Provider error identifiers that mark an API key as unusable; respond()
# deactivates the key it used when one of these shows up in a provider reply.
CRITICAL_API_ERRORS = [
    'invalid_api_key',
    'account_deactivated',
]

# Shared provider-key manager from db.providerkeys; respond() uses it to
# rate-limit or deactivate keys.
keymanager = providerkeys.manager

# Strong references to fire-and-forget tasks spawned by create_background_task,
# held so the tasks are not garbage-collected before they finish.
background_tasks: Set[asyncio.Task[Any]] = set()
|
|
|
|
|
|
|
|
|
|
|
|
def create_background_task(coro: Coroutine[Any, Any, Any]) -> None:
    """Run *coro* as a fire-and-forget task that stays referenced until it finishes.

    asyncio keeps only weak references to tasks, so a task that nobody holds on
    to may be garbage-collected mid-flight. The new task is therefore parked in
    the module-level ``background_tasks`` set. A done-callback removes it again
    once it completes.

    https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
    """
    pending_task = asyncio.create_task(coro)
    keep_alive = background_tasks
    keep_alive.add(pending_task)
    # Drop the strong reference as soon as the task is done.
    pending_task.add_done_callback(keep_alive.discard)
|
|
|
|
|
|
|
|
|
2023-09-11 02:47:21 +02:00
|
|
|
def _as_sse_event(text: str) -> str:
    """Return *text* as one server-sent-events frame.

    A ``data: `` prefix is added when missing, and the frame is terminated by
    the blank line that separates SSE events.
    """
    if not text.startswith('data: '):
        text = 'data: ' + text
    return text + '\n\n'


async def respond(
    path: str='/v1/chat/completions',
    user: dict=None,
    payload: dict=None,
    credits_cost: int=0,
    input_tokens: int=0,
    incoming_request: starlette.requests.Request=None,
):
    """Proxy a request to an upstream provider and yield the result to the client.

    Up to five attempts are made. Each attempt asks ``load_balancing`` for a
    target request (URL, method, headers, payload, provider credentials) and
    sends it through a proxied aiohttp session. The following cause the attempt
    to be skipped and retried:

    * insufficient quota: the key is rate-limited for 86400 s;
    * inactive billing: the key is deactivated;
    * an entry of ``CRITICAL_API_ERRORS`` in an error body: the key is deactivated;
    * a connection timeout.

    Streaming (``text/event-stream``) responses are re-framed and yielded one
    SSE event at a time. A successful JSON response is yielded once, serialised,
    after the retry loop. ``after_request.after_request`` is then scheduled as
    a background task.

    Args:
        path: API path of the incoming request. Paths containing
            ``chat/completions`` are routed as chat requests.
        user: the requesting user's record, passed through to ``after_request``.
        payload: request body. For chat requests ``payload['model']`` is read.
        credits_cost: passed through to ``after_request``.
        input_tokens: passed through to ``after_request``.
        incoming_request: the original Starlette request. Its method and
            cookies are forwarded for non-chat requests.

    Yields:
        str chunks for the client: SSE frames, a JSON document, or the output
        of ``errors.yield_error``.
    """
    is_chat = False
    model = None

    if 'chat/completions' in path:
        is_chat = True
        model = payload['model']

    server_json_response = {}

    headers = {
        'Content-Type': 'application/json'
    }

    # How many attempts were skipped per failure reason; reported to the
    # client if every attempt fails.
    skipped_errors = {
        'insufficient_quota': 0,
        'billing_not_active': 0,
        'critical_provider_error': 0,
        'timeout': 0
    }

    for _ in range(5):
        try:
            if is_chat:
                target_request = await load_balancing.balance_chat_request(payload)
            else:
                target_request = await load_balancing.balance_organic_request({
                    'method': incoming_request.method,
                    'path': path,
                    'payload': payload,
                    'headers': headers,
                    'cookies': incoming_request.cookies
                })

        except ValueError:
            yield await errors.yield_error(500, f'Sorry, the API has no active API keys for {model}.', 'Please use a different model.')
            return

        # Reset on every attempt, so that a retry never reuses the key chosen by
        # a previous attempt. Reusing it could rate-limit or deactivate the
        # wrong key, and on the first attempt it would raise NameError when no
        # credentials are attached.
        provider_name = ''
        provider_key = None

        provider_auth = target_request.get('provider_auth')

        if provider_auth:
            # provider_auth is split on '>' into provider name and key.
            provider_name = provider_auth.split('>')[0]
            provider_key = provider_auth.split('>')[1]

            if provider_key == '--NO_KEY--':
                print(f'No key for {provider_name}')
                yield await errors.yield_error(500,
                    'Sorry, our API seems to have issues connecting to our provider(s).',
                    'This most likely isn\'t your fault. Please try again later.'
                )
                return

        # The previous code merged target_request['headers'] into itself, which
        # was a no-op that raised KeyError when the key was absent. Only make
        # sure the key exists.
        target_request.setdefault('headers', {})

        if target_request.get('method', 'POST') == 'GET' and not payload:
            target_request['payload'] = None

        async with aiohttp.ClientSession(connector=proxies.get_proxy().connector) as session:
            try:
                async with session.request(
                    method=target_request.get('method', 'POST'),
                    url=target_request['url'],
                    data=target_request.get('data'),
                    json=target_request.get('payload'),
                    headers=target_request.get('headers', {}),
                    cookies=target_request.get('cookies'),
                    ssl=False,
                    timeout=aiohttp.ClientTimeout(
                        connect=0.75,
                        total=float(os.getenv('TRANSFER_TIMEOUT', '500'))
                    )
                ) as response:
                    is_stream = response.content_type == 'text/event-stream'

                    if response.content_type == 'application/json':
                        client_json_response = await response.json()

                        try:
                            error_code = client_json_response['error']['code']
                        except (KeyError, TypeError):
                            # No error object, or a body/error that is not a
                            # mapping (a JSON list, or {"error": "message"}).
                            error_code = ''

                        if error_code == 'method_not_supported':
                            yield await errors.yield_error(400, 'Sorry, this endpoint does not support this method.', 'Please use a different method.')

                        if error_code == 'insufficient_quota':
                            print('[!] insufficient quota')
                            if provider_key:
                                await keymanager.rate_limit_key(provider_name, provider_key, 86400)
                            skipped_errors['insufficient_quota'] += 1
                            continue

                        if error_code == 'billing_not_active':
                            print('[!] billing not active')
                            if provider_key:
                                await keymanager.deactivate_key(provider_name, provider_key, 'billing_not_active')
                            skipped_errors['billing_not_active'] += 1
                            continue

                        # Only scan bodies that actually report an error. A
                        # successful completion whose text merely mentions e.g.
                        # 'invalid_api_key' must not get a working key deactivated.
                        reports_error = not response.ok or (
                            isinstance(client_json_response, dict)
                            and 'error' in client_json_response
                        )

                        critical_error = False

                        if reports_error:
                            for error in CRITICAL_API_ERRORS:
                                if error in str(client_json_response):
                                    if provider_key:
                                        await keymanager.deactivate_key(provider_name, provider_key, error)
                                    critical_error = True

                        if critical_error:
                            print('[!] critical provider error')
                            skipped_errors['critical_provider_error'] += 1
                            continue

                        if response.ok:
                            server_json_response = client_json_response

                    if is_stream:
                        chunk_no = 0
                        # Bytes of the current SSE event that have not yet been
                        # terminated by a blank line. Buffering bytes (not str)
                        # means a multibyte UTF-8 character split across two
                        # reads is only decoded once it is complete.
                        pending = b''

                        async for chunk in response.content.iter_chunked(1024):
                            chunk_no += 1

                            if 'azure' in provider_name:
                                chunk = chunk.replace(b'data: ', b'', 1)

                            # NOTE(review): the whole first read is discarded,
                            # presumably to drop a leading preamble event.
                            # TODO confirm; if that read holds more than one
                            # event, real data is lost.
                            if not chunk or chunk_no == 1:
                                continue

                            # Every piece before the last is a complete event.
                            # The last piece is incomplete ('' if the data ended
                            # exactly on a separator). Joining first also catches
                            # a separator split across two reads.
                            *complete_events, pending = (pending + chunk).split(b'\n\n')

                            for raw_event in complete_events:
                                if raw_event:
                                    yield _as_sse_event(raw_event.decode('utf8'))

                        # Flush a final event the provider did not terminate
                        # with a blank line.
                        if pending.strip():
                            yield _as_sse_event(pending.decode('utf8'))

                    break

            except aiohttp.client_exceptions.ServerTimeoutError:
                skipped_errors['timeout'] += 1
                continue

    else:
        # Every attempt was skipped; report which failure reasons occurred.
        skipped_summary = ujson.dumps(
            {k: v for k, v in skipped_errors.items() if v > 0},
            indent=4
        )

        yield await errors.yield_error(500,
            'Sorry, our API seems to have issues connecting to our provider(s).',
            f'Please send this info to support: {skipped_summary}'
        )
        return

    if (not is_stream) and server_json_response:
        yield json.dumps(server_json_response)

    create_background_task(
        after_request.after_request(
            incoming_request=incoming_request,
            target_request=target_request,
            user=user,
            credits_cost=credits_cost,
            input_tokens=input_tokens,
            path=path,
            is_chat=is_chat,
            model=model,
        )
    )
|