"""This module contains the streaming logic for the API.""" import os import json import logging import aiohttp import asyncio import starlette from typing import Set from rich import print from dotenv import load_dotenv import proxies import after_request import load_balancing from helpers import errors from db import providerkeys load_dotenv() CRITICAL_API_ERRORS = ['invalid_api_key', 'account_deactivated'] keymanager = providerkeys.manager background_tasks: Set[asyncio.Task] = set() async def respond( path: str='/v1/chat/completions', user: dict=None, payload: dict=None, credits_cost: int=0, input_tokens: int=0, incoming_request: starlette.requests.Request=None, ): """Stream the completions request. Sends data in chunks If not streaming, it sends the result in its entirety. """ is_chat = False model = None if 'chat/completions' in path: is_chat = True model = payload['model'] server_json_response = {} headers = { 'Content-Type': 'application/json' } for _ in range(20): # Load balancing: randomly selecting a suitable provider try: if is_chat: target_request = await load_balancing.balance_chat_request(payload) else: target_request = await load_balancing.balance_organic_request({ 'method': incoming_request.method, 'path': path, 'payload': payload, 'headers': headers, 'cookies': incoming_request.cookies }) except ValueError: yield await errors.yield_error(500, f'Sorry, the API has no active API keys for {model}.', 'Please use a different model.') return provider_auth = target_request.get('provider_auth') if provider_auth: provider_name = provider_auth.split('>')[0] provider_key = provider_auth.split('>')[1] if provider_key == '--NO_KEY--': yield await errors.yield_error(500, 'Sorry, our API seems to have issues connecting to our provider(s).', 'This most likely isn\'t your fault. Please try again later.' ) return target_request['headers'].update(target_request.get('headers', {})) if target_request['method'] == 'GET' and not payload: target_request['payload'] = None async with aiohttp.ClientSession(connector=proxies.get_proxy().connector) as session: try: async with session.request( method=target_request.get('method', 'POST'), url=target_request['url'], data=target_request.get('data'), json=target_request.get('payload'), headers=target_request.get('headers', {}), cookies=target_request.get('cookies'), ssl=False, timeout=aiohttp.ClientTimeout( connect=1.0, total=float(os.getenv('TRANSFER_TIMEOUT', '500')) ) ) as response: is_stream = response.content_type == 'text/event-stream' if response.status == 429: print('[!] rate limit') # await keymanager.rate_limit_key(provider_name, provider_key) continue if response.content_type == 'application/json': client_json_response = await response.json() if 'method_not_supported' in str(client_json_response): await errors.error(500, 'Sorry, this endpoint does not support this method.', data['error']['message']) critical_error = False for error in CRITICAL_API_ERRORS: if error in str(client_json_response): await keymanager.deactivate_key(provider_name, provider_key, error) critical_error = True if critical_error: print('[!] critical error') continue if response.ok: server_json_response = client_json_response else: print('[!] non-ok response', client_json_response) continue if is_stream: try: response.raise_for_status() except Exception as exc: if 'Too Many Requests' in str(exc): print('[!] too many requests') continue async for chunk in response.content.iter_any(): chunk = chunk.decode('utf8').strip() yield chunk + '\n\n' break except Exception as exc: print('[!] exception', exc) continue else: yield await errors.yield_error(500, 'Sorry, our API seems to have issues connecting to our provider(s).', 'This most likely isn\'t your fault. Please try again later.') return if (not is_stream) and server_json_response: yield json.dumps(server_json_response) task = asyncio.create_task( after_request.after_request( incoming_request=incoming_request, target_request=target_request, user=user, credits_cost=credits_cost, input_tokens=input_tokens, path=path, is_chat=is_chat, model=model, ) ) background_tasks.add(task) task.add_done_callback(background_tasks.discard)