2023-08-12 17:49:31 +02:00
""" This module contains the streaming logic for the API. """
2023-08-04 03:30:56 +02:00
import os
2023-08-04 17:29:49 +02:00
import json
2023-10-04 23:24:55 +02:00
import logging
2023-08-04 03:30:56 +02:00
import aiohttp
2023-10-04 23:24:55 +02:00
import asyncio
2023-08-04 03:30:56 +02:00
import starlette
2023-10-06 09:45:50 +02:00
from typing import Any , Coroutine , Set
2023-08-04 17:29:49 +02:00
from rich import print
2023-08-04 03:30:56 +02:00
from dotenv import load_dotenv
import proxies
2023-08-28 00:58:32 +02:00
import after_request
2023-08-04 03:30:56 +02:00
import load_balancing
2023-10-02 21:09:39 +02:00
from helpers import errors
2023-10-04 23:24:55 +02:00
from db import providerkeys
2023-08-04 03:30:56 +02:00
load_dotenv ( )
2023-10-04 23:24:55 +02:00
# Provider error codes that make a key unusable. `respond` looks for each of
# these as a substring of the upstream JSON response and deactivates the key
# on a match, so the entries must be the bare codes with no surrounding
# whitespace.
CRITICAL_API_ERRORS = ['invalid_api_key', 'account_deactivated']
# Module-level alias for the provider-key manager from `db.providerkeys`;
# `respond` calls `rate_limit_key` / `deactivate_key` on it when an upstream
# provider rejects a key.
keymanager = providerkeys . manager
2023-10-06 09:45:50 +02:00
# Strong references to fire-and-forget tasks that are still running. Per the
# asyncio documentation the event loop only keeps weak references to tasks, so
# a task nobody else references may be garbage-collected before it finishes.
background_tasks: Set[asyncio.Task[Any]] = set()


def create_background_task(coro: Coroutine[Any, Any, Any]) -> None:
    """Schedule *coro* on the running event loop without awaiting it.

    The created task is stored in ``background_tasks`` until it completes.
    Must be called while an event loop is running, because it uses
    ``asyncio.create_task``.

    Args:
        coro: The coroutine object to run in the background.
    """
    task = asyncio.create_task(coro)
    background_tasks.add(task)
    # Remove the reference once the task is done so the set does not grow
    # without bound.
    task.add_done_callback(background_tasks.discard)
2023-09-11 02:47:21 +02:00
async def respond(
    path: str = '/v1/chat/completions',
    user: dict = None,
    payload: dict = None,
    credits_cost: int = 0,
    input_tokens: int = 0,
    incoming_request: starlette.requests.Request = None,
):
    """Forward a request to an upstream provider and yield its response.

    This is an async generator. It makes up to 20 attempts; each attempt asks
    ``load_balancing`` for a target request and sends it through a fresh
    ``aiohttp`` session that uses the proxy connector from ``proxies``.

    * ``text/event-stream`` responses are relayed chunk by chunk, each chunk
      stripped and followed by a blank line.
    * Successful ``application/json`` responses are yielded once, serialised
      with ``json.dumps``.
    * Failures yield whatever ``errors.yield_error`` returns and end the
      generator.

    After a successful attempt, ``after_request.after_request`` is scheduled
    as a background task.

    Args:
        path: Request path; chat handling is enabled when it contains
            ``chat/completions``.
        user: Passed through to ``after_request.after_request``.
        payload: Request body. Must contain ``'model'`` for chat requests.
        credits_cost: Passed through to ``after_request.after_request``.
        input_tokens: Passed through to ``after_request.after_request``.
        incoming_request: The original request; its method and cookies are
            forwarded for non-chat requests.

    Yields:
        Response chunks, a JSON string, or an error payload.
    """
    # Function-scope import: only the streaming branch needs it.
    import codecs

    is_chat = 'chat/completions' in path
    model = payload['model'] if is_chat else None

    server_json_response = {}
    is_stream = False

    headers = {
        'Content-Type': 'application/json'
    }

    for _ in range(20):
        # Reset per attempt so a key from a previous attempt can never be
        # rate-limited or deactivated on behalf of this one.
        provider_name = None
        provider_key = None

        # Load balancing: ask for a suitable provider for this attempt.
        try:
            if is_chat:
                target_request = await load_balancing.balance_chat_request(payload)
            else:
                target_request = await load_balancing.balance_organic_request({
                    'method': incoming_request.method,
                    'path': path,
                    'payload': payload,
                    'headers': headers,
                    'cookies': incoming_request.cookies
                })
        except ValueError:
            yield await errors.yield_error(500, f'Sorry, the API has no active API keys for {model}.', 'Please use a different model.')
            return

        # `provider_auth` has the form '<provider name>><key>'.
        provider_auth = target_request.get('provider_auth')
        if provider_auth:
            auth_parts = provider_auth.split('>')
            provider_name = auth_parts[0]
            provider_key = auth_parts[1]

        if provider_key == '--NO_KEY--':
            print(f'No key for {provider_name}')
            yield await errors.yield_error(500,
                'Sorry, our API seems to have issues connecting to our provider(s).',
                'This most likely isn\'t your fault. Please try again later.'
            )
            return

        has_provider_key = provider_key is not None

        # The original line updated the headers dict with itself, which did
        # nothing except raise KeyError when no headers were present. Just
        # make sure the key exists.
        target_request.setdefault('headers', {})

        if target_request.get('method') == 'GET' and not payload:
            target_request['payload'] = None

        async with aiohttp.ClientSession(connector=proxies.get_proxy().connector) as session:
            try:
                async with session.request(
                    method=target_request.get('method', 'POST'),
                    url=target_request['url'],
                    data=target_request.get('data'),
                    json=target_request.get('payload'),
                    headers=target_request.get('headers', {}),
                    cookies=target_request.get('cookies'),
                    # NOTE(review): ssl=False turns off TLS certificate
                    # verification for the upstream request. Confirm this is
                    # intentional.
                    ssl=False,
                    timeout=aiohttp.ClientTimeout(
                        connect=1.0,
                        total=float(os.getenv('TRANSFER_TIMEOUT', '500'))
                    )
                ) as response:
                    is_stream = response.content_type == 'text/event-stream'

                    if response.content_type == 'application/json':
                        client_json_response = await response.json()

                        # The body may have no 'error' entry, or one that is
                        # not a mapping (TypeError).
                        try:
                            error_code = client_json_response['error']['code']
                        except (KeyError, TypeError):
                            error_code = ''

                        if error_code == 'method_not_supported':
                            yield await errors.yield_error(400, 'Sorry, this endpoint does not support this method.', 'Please use a different method.')
                            # Stop here: retrying would only emit further
                            # error payloads after this one.
                            return

                        if error_code == 'insufficient_quota':
                            print('[!] insufficient quota')
                            if has_provider_key:
                                await keymanager.rate_limit_key(provider_name, provider_key, 86400)
                            continue

                        if error_code == 'billing_not_active':
                            print('[!] billing not active')
                            if has_provider_key:
                                await keymanager.deactivate_key(provider_name, provider_key, 'billing_not_active')
                            continue

                        # Critical errors are detected by substring search
                        # over the stringified response.
                        response_text = str(client_json_response)
                        critical_errors = [
                            code for code in CRITICAL_API_ERRORS
                            if code in response_text
                        ]
                        if critical_errors:
                            if has_provider_key:
                                for critical_error in critical_errors:
                                    await keymanager.deactivate_key(provider_name, provider_key, critical_error)
                            print('[!] critical error')
                            continue

                        if not response.ok:
                            continue

                        server_json_response = client_json_response

                    if is_stream:
                        try:
                            response.raise_for_status()
                        except Exception as exc:
                            if 'Too Many Requests' in str(exc):
                                print('[!] too many requests')
                                continue
                            # Any other status error is deliberately not
                            # re-raised, matching the original behavior.

                        # Decode incrementally: a multi-byte UTF-8 character
                        # may be split across two network chunks, and
                        # decoding each chunk on its own would then raise
                        # mid-stream and trigger a retry after partial output.
                        decoder = codecs.getincrementaldecoder('utf8')()
                        async for chunk in response.content.iter_any():
                            yield decoder.decode(chunk).strip() + '\n\n'

                    break

            except Exception as exc:
                # Any failure during this attempt falls through to the next.
                print('[!] exception', exc)
                continue

    else:
        # The loop ended without `break`: every attempt failed.
        yield await errors.yield_error(500, 'Sorry, our API seems to have issues connecting to our provider(s).', 'This most likely isn\'t your fault. Please try again later.')
        return

    if (not is_stream) and server_json_response:
        yield json.dumps(server_json_response)

    # Post-request work runs in the background so it does not delay the
    # response.
    create_background_task(
        after_request.after_request(
            incoming_request=incoming_request,
            target_request=target_request,
            user=user,
            credits_cost=credits_cost,
            input_tokens=input_tokens,
            path=path,
            is_chat=is_chat,
            model=model,
        )
    )