"""This module contains the streaming logic for the API.""" import os import json import ujson import aiohttp import asyncio import starlette from rich import print from dotenv import load_dotenv import proxies import after_request import load_balancing from helpers import errors from db import providerkeys load_dotenv() CRITICAL_API_ERRORS = ['invalid_api_key', 'account_deactivated'] keymanager = providerkeys.manager async def respond( path: str='/v1/chat/completions', user: dict=None, payload: dict=None, credits_cost: int=0, input_tokens: int=0, incoming_request: starlette.requests.Request=None, ): """Stream the completions request. Sends data in chunks If not streaming, it sends the result in its entirety. """ is_chat = False model = None if 'chat/completions' in path: is_chat = True model = payload['model'] server_json_response = {} headers = { 'Content-Type': 'application/json' } skipped_errors = { 'insufficient_quota': 0, 'billing_not_active': 0, 'critical_provider_error': 0, 'timeout': 0 } for _ in range(5): try: if is_chat: target_request = await load_balancing.balance_chat_request(payload) else: target_request = await load_balancing.balance_organic_request({ 'method': incoming_request.method, 'path': path, 'payload': payload, 'headers': headers, 'cookies': incoming_request.cookies }) except ValueError: yield await errors.yield_error(500, f'Sorry, the API has no active API keys for {model}.', 'Please use a different model.') return provider_auth = target_request.get('provider_auth') if provider_auth: provider_name = provider_auth.split('>')[0] provider_key = provider_auth.split('>')[1] if provider_key == '--NO_KEY--': print(f'No key for {provider_name}') yield await errors.yield_error(500, 'Sorry, our API seems to have issues connecting to our provider(s).', 'This most likely isn\'t your fault. Please try again later.' ) return target_request['headers'].update(target_request.get('headers', {})) if target_request['method'] == 'GET' and not payload: target_request['payload'] = None async with aiohttp.ClientSession(connector=proxies.get_proxy().connector) as session: try: async with session.request( method=target_request.get('method', 'POST'), url=target_request['url'], data=target_request.get('data'), json=target_request.get('payload'), headers=target_request.get('headers', {}), cookies=target_request.get('cookies'), ssl=False, timeout=aiohttp.ClientTimeout( connect=0.75, total=float(os.getenv('TRANSFER_TIMEOUT', '500')) ) ) as response: is_stream = response.content_type == 'text/event-stream' if response.content_type == 'application/json': client_json_response = await response.json() try: error_code = client_json_response['error']['code'] except KeyError: error_code = '' if error_code == 'method_not_supported': yield await errors.yield_error(400, 'Sorry, this endpoint does not support this method.', 'Please use a different method.') if error_code == 'insufficient_quota': print('[!] insufficient quota') await keymanager.rate_limit_key(provider_name, provider_key, 86400) skipped_errors['insufficient_quota'] += 1 continue if error_code == 'billing_not_active': print('[!] billing not active') await keymanager.deactivate_key(provider_name, provider_key, 'billing_not_active') skipped_errors['billing_not_active'] += 1 continue critical_error = False for error in CRITICAL_API_ERRORS: if error in str(client_json_response): await keymanager.deactivate_key(provider_name, provider_key, error) critical_error = True if critical_error: print('[!] critical provider error') skipped_errors['critical_provider_error'] += 1 continue if response.ok: server_json_response = client_json_response if is_stream: chunk_no = 0 buffer = '' async for chunk in response.content.iter_chunked(1024): chunk_no += 1 chunk = chunk.decode('utf8') if 'azure' in provider_name: chunk = chunk.replace('data: ', '', 1) if not chunk or chunk_no == 1: continue subchunks = chunk.split('\n\n') buffer += subchunks[0] for subchunk in [buffer] + subchunks[1:-1]: if not subchunk.startswith('data: '): subchunk = 'data: ' + subchunk yield subchunk + '\n\n' buffer = subchunks[-1] break except aiohttp.client_exceptions.ServerTimeoutError: skipped_errors['timeout'] += 1 continue else: skipped_errors = {k: v for k, v in skipped_errors.items() if v > 0} skipped_errors = ujson.dumps(skipped_errors, indent=4) yield await errors.yield_error(500, 'Sorry, our API seems to have issues connecting to our provider(s).', f'Please send this info to support: {skipped_errors}' ) return if (not is_stream) and server_json_response: yield json.dumps(server_json_response) asyncio.create_task( after_request.after_request( incoming_request=incoming_request, target_request=target_request, user=user, credits_cost=credits_cost, input_tokens=input_tokens, path=path, is_chat=is_chat, model=model, ) )