nova-api/api/responder.py

292 lines
10 KiB
Python
Raw Normal View History

2023-08-12 17:49:31 +02:00
"""This module contains the streaming logic for the API."""
2023-08-04 03:30:56 +02:00
import os
2023-08-04 17:29:49 +02:00
import json
2023-10-16 23:34:54 +02:00
import yaml
2023-10-08 00:28:13 +02:00
import ujson
2023-08-04 03:30:56 +02:00
import aiohttp
2023-10-04 23:24:55 +02:00
import asyncio
2023-08-04 03:30:56 +02:00
import starlette
2023-10-06 09:45:50 +02:00
from typing import Any, Coroutine, Set
2023-08-04 17:29:49 +02:00
from rich import print
2023-08-04 03:30:56 +02:00
from dotenv import load_dotenv
import proxies
import after_request
2023-08-04 03:30:56 +02:00
import load_balancing
2023-10-02 21:09:39 +02:00
from helpers import errors
2023-10-04 23:24:55 +02:00
from db import providerkeys
2023-10-16 23:34:54 +02:00
from helpers.tokens import count_tokens_for_messages
2023-08-04 03:30:56 +02:00
# Load variables from a local .env file into the process environment; they are
# read further down through os.getenv().
load_dotenv()

# Maximum number of provider attempts made by respond() before it gives up.
RETRIES: int = 10

# Substrings that, when found anywhere in a provider's JSON response, cause the
# provider key in use to be deactivated.
CRITICAL_API_ERRORS: list = ['invalid_api_key', 'account_deactivated']

# Shared provider-key manager (rate limiting / deactivating keys).
keymanager = providerkeys.manager

# Strong references to in-flight background tasks, see create_background_task().
background_tasks: Set[asyncio.Task[Any]] = set()

# Parsed contents of config/config.yml (pricing and role settings are read from it).
with open(os.path.join('config', 'config.yml'), encoding='utf8') as f:
    config = yaml.safe_load(f)
2023-10-06 09:45:50 +02:00
def create_background_task(coro: Coroutine[Any, Any, Any]) -> None:
    """Schedule `coro` as an asyncio task and keep it referenced until it is done.

    The event loop only holds weak references to tasks, so the task is stored
    in the module-level `background_tasks` set to stop it from being garbage
    collected mid-execution. It is removed from the set again as soon as it
    completes.

    https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task

    Args:
        coro: The coroutine to run in the background.
    """
    scheduled = asyncio.create_task(coro)

    # Drop the strong reference once the task has finished.
    scheduled.add_done_callback(background_tasks.discard)
    background_tasks.add(scheduled)
2023-09-11 02:47:21 +02:00
async def respond(
    path: str='/v1/chat/completions',
    user: dict=None,
    payload: dict=None,
    incoming_request=None,
    overwrite_method=None
):
    """Forward a request to an upstream provider and yield its response.

    This is an async generator. A target request is obtained from
    `load_balancing` and sent with aiohttp; on recoverable provider errors
    (missing key, quota, billing, critical key errors, plain-text errors,
    timeouts) another target is tried, up to `RETRIES` attempts in total.

    * `text/event-stream` responses are re-emitted as `data: ...\\n\\n` chunks.
    * Otherwise, a successful JSON response is yielded once as a JSON string.
    * When no usable provider key exists, the upstream rejects the method, or
      all attempts are used up, the value produced by `errors.yield_error`
      is yielded and the generator ends.

    If `incoming_request` is given, the credit cost is computed from the token
    counts and `config`, and `after_request.after_request` is scheduled as a
    background task.

    Args:
        path: API path of the request; paths containing `chat/completions`
            are treated as chat requests.
        user: The requesting user's record; only read (via `.get('role')`)
            when `incoming_request` is given.
        payload: Request body. Required for chat requests (`payload['model']`).
        incoming_request: The original HTTP request, or None when called
            from other code. Supplies cookies and the HTTP method.
        overwrite_method: HTTP method to use instead of the incoming one.

    Yields:
        str chunks of the upstream response, or an error payload.
    """

    is_chat = False
    model = None

    if 'chat/completions' in path:
        is_chat = True
        model = payload['model']

    server_json_response = {}

    headers = {
        'Content-Type': 'application/json'
    }

    # Counters (and raw messages) for every attempt that had to be skipped;
    # reported to the client if all attempts fail.
    skipped_errors = {
        'no_provider_key': 0,
        'insufficient_quota': 0,
        'billing_not_active': 0,
        'critical_provider_error': 0,
        'timeout': 0,
        'other_errors': []
    }

    input_tokens = 0
    output_tokens = 0

    cookies = incoming_request.cookies if incoming_request else {}

    if overwrite_method:
        method = overwrite_method
    elif incoming_request:
        method = incoming_request.method
    else:
        # Neither an override nor a request to read the method from: use the
        # same default the outgoing session.request() call below falls back to.
        method = 'POST'

    # Pre-bound so that error reporting and billing never hit an unbound name
    # when the load balancer returns a target without 'provider_auth'.
    provider_name = None
    provider_key = None
    is_stream = False

    for _ in range(RETRIES):
        try:
            if is_chat:
                target_request = await load_balancing.balance_chat_request(payload)
            else:
                target_request = await load_balancing.balance_organic_request({
                    'method': method,
                    'path': path,
                    'payload': payload,
                    'headers': headers,
                    'cookies': cookies
                })

        except ValueError:
            yield await errors.yield_error(500, f'Sorry, the API has no active API keys for {model}.', 'Please use a different model.')
            return

        provider_auth = target_request.get('provider_auth')

        if provider_auth:
            # Format: '<provider name>><provider key>'
            auth_parts = provider_auth.split('>')
            provider_name = auth_parts[0]
            provider_key = auth_parts[1]

            if provider_key == '--NO_KEY--':
                skipped_errors['no_provider_key'] += 1
                continue

        # The original line updated the headers dict with itself (a no-op that
        # raised KeyError when 'headers' was missing). Only guarantee the key
        # exists here.
        # NOTE(review): if the intent was to merge the default `headers` above
        # into chat requests, that needs to be confirmed separately.
        target_request.setdefault('headers', {})

        if target_request.get('method') == 'GET' and not payload:
            target_request['payload'] = None

        connector = None

        if os.getenv('PROXY_HOST') or os.getenv('USE_PROXY_LIST', 'False').lower() == 'true':
            connector = proxies.get_proxy().connector

        async with aiohttp.ClientSession(connector=connector) as session:
            try:
                async with session.request(
                    method=target_request.get('method', 'POST'),
                    url=target_request['url'],
                    data=target_request.get('data'),
                    json=target_request.get('payload'),
                    headers=target_request.get('headers', {}),
                    cookies=target_request.get('cookies'),
                    # NOTE(review): ssl=False disables TLS certificate
                    # verification for the upstream connection. Left unchanged
                    # on purpose, but worth revisiting.
                    ssl=False,
                    timeout=aiohttp.ClientTimeout(
                        connect=0.75,
                        total=float(os.getenv('TRANSFER_TIMEOUT', '500'))
                    )
                ) as response:
                    is_stream = response.content_type == 'text/event-stream'

                    if response.content_type == 'application/json':
                        client_json_response = await response.json()

                        try:
                            error_code = client_json_response['error']['code']
                        except (KeyError, TypeError):
                            # No error object, or 'error' is not a mapping.
                            error_code = ''

                        if error_code == 'method_not_supported':
                            yield await errors.yield_error(400, 'Sorry, this endpoint does not support this method.', 'Please use a different method.')
                            # The error has been delivered; do not fall through
                            # to the response/billing code below.
                            return

                        if error_code == 'insufficient_quota':
                            print('[!] insufficient quota')
                            await keymanager.rate_limit_key(provider_name, provider_key, 86400)
                            skipped_errors['insufficient_quota'] += 1
                            continue

                        if error_code == 'billing_not_active':
                            print('[!] billing not active')
                            await keymanager.deactivate_key(provider_name, provider_key, 'billing_not_active')
                            skipped_errors['billing_not_active'] += 1
                            continue

                        critical_error = False
                        for error in CRITICAL_API_ERRORS:
                            if error in str(client_json_response):
                                await keymanager.deactivate_key(provider_name, provider_key, error)
                                critical_error = True

                        if critical_error:
                            print('[!] critical provider error')
                            skipped_errors['critical_provider_error'] += 1
                            continue

                        if response.ok:
                            if is_chat and not is_stream:
                                input_tokens = client_json_response['usage']['prompt_tokens']
                                output_tokens = client_json_response['usage']['completion_tokens']

                            server_json_response = client_json_response

                    elif response.content_type == 'text/plain':
                        data = (await response.read()).decode("utf-8")
                        print(f'[!] {data}')
                        # list.append() returns None, so its result must not be
                        # assigned back (that wiped the list and crashed the
                        # next append).
                        skipped_errors['other_errors'].append(data)
                        continue

                    if is_stream:
                        input_tokens = await count_tokens_for_messages(payload['messages'], model=model)

                        chunk_no = 0
                        buffer = ''

                        async for chunk in response.content.iter_any():
                            chunk_no += 1
                            chunk = chunk.decode('utf8')

                            # NOTE(review): the nesting of the skip check under
                            # the azure branch is reconstructed; confirm
                            # against the original indentation.
                            if 'azure' in (provider_name or ''):
                                chunk = chunk.replace('data: ', '', 1)

                                if not chunk.strip() or chunk_no == 1:
                                    continue

                            # Network chunks do not align with SSE events, so
                            # buffer and emit one complete event at a time.
                            buffer += chunk
                            while '\n\n' in buffer:
                                subchunk, buffer = buffer.split('\n\n', 1)

                                if not subchunk.strip():
                                    continue

                                if not subchunk.startswith('data: '):
                                    subchunk = 'data: ' + subchunk

                                subchunk = subchunk.rsplit('[DONE]', 1)[0]
                                subchunk += '\n\n'

                                yield subchunk

                        # The number of received network chunks is used as the
                        # output token count for streamed responses.
                        output_tokens = chunk_no

                    break

            except aiohttp.client_exceptions.ServerTimeoutError:
                skipped_errors['timeout'] += 1
                continue

    else:
        # Every attempt was skipped: report only the non-empty counters.
        skipped_errors = {k: v for k, v in skipped_errors.items() if ((isinstance(v, int) and v > 0) or (isinstance(v, list) and len(v) > 0))}
        skipped_errors['model'] = model
        skipped_errors['provider'] = provider_name
        print(f'[!] Skipped {RETRIES} errors:\n{skipped_errors}')

        skipped_errors = ujson.dumps(skipped_errors, indent=4)
        yield await errors.yield_error(500,
            f'Sorry, our API seems to have issues connecting to "{model}".',
            f'Please send this info to support: {skipped_errors}'
        )
        return

    if (not is_stream) and server_json_response:
        server_json_response['system_fingerprint'] = 'fp_' + os.urandom(5).hex()
        yield json.dumps(server_json_response)

    if incoming_request: # not called by other code, but actually a request
        role = user.get('role', 'default')

        model_multipliers = config['costs']
        model_multiplier = model_multipliers['other']

        if is_chat:
            model_multiplier = model_multipliers['chat-models'].get(payload.get('model'), model_multiplier)
            total_tokens = input_tokens + output_tokens
            credits_cost = total_tokens / 60
            credits_cost = round(credits_cost * model_multiplier)

            if credits_cost < 1:
                credits_cost = 1

            tokens = {'input': input_tokens, 'output': output_tokens, 'total': total_tokens}

        else:
            # `model` is only set for chat requests, so the image model has to
            # be read from the payload; comparing against `model` made these
            # branches unreachable and left `credits_cost` unbound.
            requested_model = (payload or {}).get('model')

            if requested_model == 'dall-e-2':
                credits_cost = 50
            elif requested_model == 'dall-e-3':
                credits_cost = 100
            else:
                # NOTE(review): fallback for all other endpoints; assumes
                # config['costs']['other'] is usable as a flat per-request
                # cost. Confirm against config.yml.
                credits_cost = max(1, round(model_multiplier))

            tokens = {'input': 0, 'output': 0, 'total': credits_cost}

        try:
            role_cost_multiplier = config['roles'][role]['bonus']
        except KeyError:
            role_cost_multiplier = 1

        credits_cost = round(credits_cost * role_cost_multiplier)

        create_background_task(
            after_request.after_request(
                provider=provider_name,
                incoming_request=incoming_request,
                target_request=target_request,
                user=user,
                credits_cost=credits_cost,
                tokens=tokens,
                path=path,
                is_chat=is_chat,
                model=model,
            )
        )