diff --git a/api/after_request.py b/api/after_request.py
new file mode 100644
index 0000000..6eec61e
--- /dev/null
+++ b/api/after_request.py
@@ -0,0 +1,29 @@
+from db import logs, stats, users
+from helpers import network
+
+async def after_request(
+    incoming_request: dict,
+    target_request: dict,
+    user: dict,
+    credits_cost: int,
+    input_tokens: int,
+    path: str,
+    is_chat: bool,
+    model: str,
+) -> None:
+    if user and incoming_request:
+        await logs.log_api_request(user=user, incoming_request=incoming_request, target_url=target_request['url'])
+
+    if credits_cost and user:
+        await users.manager.update_by_id(user['_id'], {'$inc': {'credits': -credits_cost}})
+
+    ip_address = await network.get_ip(incoming_request)
+
+    await stats.manager.add_date()
+    await stats.manager.add_ip_address(ip_address)
+    await stats.manager.add_path(path)
+    await stats.manager.add_target(target_request['url'])
+
+    if is_chat:
+        await stats.manager.add_model(model)
+        await stats.manager.add_tokens(input_tokens, model)
diff --git a/api/db/stats.py b/api/db/stats.py
index 53bc17c..c5ce07b 100644
--- a/api/db/stats.py
+++ b/api/db/stats.py
@@ -61,6 +61,8 @@ class StatsManager:
         db = await self._get_collection('stats')
         return await db.find_one({obj_filter})
 
+manager = StatsManager()
+
 if __name__ == '__main__':
     stats = StatsManager()
     asyncio.run(stats.add_date())
diff --git a/api/db/users.py b/api/db/users.py
index e1150a6..2e325d4 100644
--- a/api/db/users.py
+++ b/api/db/users.py
@@ -101,6 +101,8 @@ class UserManager:
         db = await self._get_collection('users')
         await db.delete_one({'_id': user_id})
 
+manager = UserManager()
+
 async def demo():
     user = await UserManager().create(69420)
     print(user)
diff --git a/api/handler.py b/api/handler.py
index ff725dc..443765a 100644
--- a/api/handler.py
+++ b/api/handler.py
@@ -40,11 +40,6 @@ async def handle(incoming_request: fastapi.Request):
     except json.decoder.JSONDecodeError:
         payload = {}
 
-    try:
-        input_tokens = await tokens.count_for_messages(payload.get('messages', []))
-    except (KeyError, TypeError):
-        input_tokens = 0
-
     received_key = incoming_request.headers.get('Authorization')
 
     if not received_key or not received_key.startswith('Bearer '):
@@ -70,10 +65,16 @@
     policy_violation = False
 
     if '/moderations' not in path:
-        if '/chat/completions' in path or ('input' in payload or 'prompt' in payload):
+        inp = ''
+
+        if 'input' in payload or 'prompt' in payload:
             inp = payload.get('input', payload.get('prompt', ''))
-            if inp and len(inp) > 2 and not inp.isnumeric():
-                policy_violation = await moderation.is_policy_violated(inp)
+
+        if isinstance(payload.get('messages'), list):
+            inp = '\n'.join([message['content'] for message in payload['messages']])
+
+        if inp and len(inp) > 2 and not inp.isnumeric():
+            policy_violation = await moderation.is_policy_violated(inp)
 
     if policy_violation:
         return await errors.error(
@@ -104,7 +105,7 @@
             path=path,
             payload=payload,
             credits_cost=cost,
-            input_tokens=input_tokens,
+            input_tokens=-1,
             incoming_request=incoming_request,
         ),
         media_type=media_type
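The handler.py change above stops keying the moderation check off the URL path and instead inspects the payload itself: organic requests contribute 'input'/'prompt', chat requests contribute the joined message contents. A minimal standalone sketch of that extraction logic, for illustration only (the helper name extract_moderation_input is hypothetical and not part of this patch):

# Illustrative sketch: mirrors the input-extraction logic the handler.py hunk
# above now performs inline before calling moderation.is_policy_violated().
def extract_moderation_input(payload: dict) -> str:
    inp = ''

    # Organic requests carry their text under 'input' or 'prompt'.
    if 'input' in payload or 'prompt' in payload:
        inp = payload.get('input', payload.get('prompt', ''))

    # Chat completions carry a list of messages; join their contents.
    if isinstance(payload.get('messages'), list):
        inp = '\n'.join([message['content'] for message in payload['messages']])

    return inp

# Both request shapes now reach the same moderation gate.
assert extract_moderation_input({'prompt': 'hello'}) == 'hello'
assert extract_moderation_input({'messages': [{'role': 'user', 'content': 'hi'}]}) == 'hi'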
diff --git a/api/helpers/tokens.py b/api/helpers/tokens.py
index f39d15a..848cac8 100644
--- a/api/helpers/tokens.py
+++ b/api/helpers/tokens.py
@@ -1,3 +1,5 @@
+import time
+import asyncio
 import tiktoken
 
 async def count_for_messages(messages: list, model: str='gpt-3.5-turbo-0613') -> int:
@@ -57,3 +59,15 @@ for information on how messages are converted to tokens.""")
 
     num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
     return num_tokens
+
+if __name__ == '__main__':
+    start = time.perf_counter()
+
+    messages = [
+        {
+            'role': 'user',
+            'content': '1+1='
+        }
+    ]
+    print(asyncio.run(count_for_messages(messages)))
+    print(f'Took {(time.perf_counter() - start) * 1000}ms')
diff --git a/api/main.py b/api/main.py
index fc76e8a..349854c 100644
--- a/api/main.py
+++ b/api/main.py
@@ -63,4 +63,5 @@ async def root():
 
 @app.route('/v1/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH'])
 async def v1_handler(request: fastapi.Request):
-    return await handler.handle(request)
+    res = await handler.handle(request)
+    return res
diff --git a/api/moderation.py b/api/moderation.py
index bcbff1d..538c8c7 100644
--- a/api/moderation.py
+++ b/api/moderation.py
@@ -30,8 +30,6 @@ async def is_policy_violated(inp: Union[str, list]) -> bool:
     else:
         text = '\n'.join(inp)
 
-    print(f'[i] checking moderation for {text}')
-
     for _ in range(3):
         req = await load_balancing.balance_organic_request(
             {
@@ -39,7 +37,6 @@
                 'payload': {'input': text}
             }
         )
-        print(f'[i] moderation request sent to {req["url"]}')
 
         async with aiohttp.ClientSession(connector=proxies.get_proxy().connector) as session:
             try:
@@ -52,16 +49,13 @@
                     headers=req.get('headers'),
                     cookies=req.get('cookies'),
                     ssl=False,
-                    timeout=aiohttp.ClientTimeout(total=2),
+                    timeout=aiohttp.ClientTimeout(total=3),
                 ) as res:
                     res.raise_for_status()
                     json_response = await res.json()
 
-                    print(json_response)
 
                     categories = json_response['results'][0]['category_scores']
 
-                    print(f'[i] moderation check took {time.perf_counter() - start:.2f}s')
-
                     if json_response['results'][0]['flagged']:
                         return max(categories, key=categories.get)
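When a moderation response comes back flagged, is_policy_violated returns the name of the highest-scoring category via max(categories, key=categories.get). A small sketch with made-up scores (illustrative only, not project data):

# Illustrative only: made-up category scores showing what a flagged
# moderation response resolves to.
categories = {'hate': 0.01, 'violence': 0.87, 'self-harm': 0.02}
flagged = True

if flagged:
    # key=categories.get makes max() compare scores but return the category name.
    print(max(categories, key=categories.get))  # -> 'violence'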
diff --git a/api/streaming.py b/api/streaming.py
index d6a32a4..4f4fa5a 100644
--- a/api/streaming.py
+++ b/api/streaming.py
@@ -10,16 +10,13 @@ import starlette
 from rich import print
 from dotenv import load_dotenv
-from python_socks._errors import ProxyError
 
 import chunks
 import proxies
 import provider_auth
+import after_request
 import load_balancing
 
-from db import logs
-from db.users import UserManager
-from db.stats import StatsManager
 
 from helpers import network, chat, errors
 
 load_dotenv()
@@ -54,13 +51,11 @@ async def stream(
     If not streaming, it sends the result in its entirety.
     """
 
-    ## Setup managers
-    db = UserManager()
-    stats = StatsManager()
-
     is_chat = False
     is_stream = payload.get('stream', False)
 
+    model = None
+
     if 'chat/completions' in path:
         is_chat = True
         model = payload['model']
@@ -78,7 +73,6 @@
     }
 
     for _ in range(5):
-
         # Load balancing: randomly selecting a suitable provider
         # If the request is a chat completion, then we need to load balance between chat providers
         # If the request is an organic request, then we need to load balance between organic providers
@@ -120,17 +114,19 @@
                     cookies=target_request.get('cookies'),
                     ssl=False,
                     timeout=aiohttp.ClientTimeout(
-                        connect=60,
+                        connect=2,
                         total=float(os.getenv('TRANSFER_TIMEOUT', '120'))
                     ),
                 ) as response:
+
                     if response.status == 429:
                         continue
 
                     if response.content_type == 'application/json':
                         data = await response.json()
 
-                        if data.get('code') == 'invalid_api_key':
+                        if 'invalid_api_key' in str(data) or 'account_deactivated' in str(data):
+                            print('[!] invalid api key', target_request.get('provider_auth'))
                             await provider_auth.invalidate_key(target_request.get('provider_auth'))
                             continue
 
@@ -155,19 +151,11 @@
 
                     break
 
-        except ProxyError as exc:
-            print('[!] aiohttp ProxyError')
+        except Exception as exc:
+            print(f'[!] {type(exc)} - {exc}')
             continue
 
-        except ConnectionResetError as exc:
-            print('[!] aiohttp ConnectionResetError')
-            continue
-
-        except aiohttp.client_exceptions.ClientConnectionError:
-            print('[!] aiohttp ClientConnectionError')
-            continue
-
-        if not json_response and is_chat and is_stream:
+        if (not json_response) and is_chat:
             print('[!] chat response is empty')
             continue
 
@@ -178,20 +166,16 @@
     if not is_stream and json_response:
         yield json.dumps(json_response)
 
-    if user and incoming_request:
-        await logs.log_api_request(user=user, incoming_request=incoming_request, target_url=target_request['url'])
-
-    if credits_cost and user:
-        await db.update_by_id(user['_id'], {'$inc': {'credits': -credits_cost}})
-
-    ip_address = await network.get_ip(incoming_request)
-    await stats.add_date()
-    await stats.add_ip_address(ip_address)
-    await stats.add_path(path)
-    await stats.add_target(target_request['url'])
-    if is_chat:
-        await stats.add_model(model)
-        await stats.add_tokens(input_tokens, model)
+    await after_request.after_request(
+        incoming_request=incoming_request,
+        target_request=target_request,
+        user=user,
+        credits_cost=credits_cost,
+        input_tokens=input_tokens,
+        path=path,
+        is_chat=is_chat,
+        model=model,
+    )
 
 if __name__ == '__main__':
     asyncio.run(stream())
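streaming.py keeps its five-attempt failover loop but now retries when a provider answers 429 and collapses the proxy- and connection-specific handlers into one broad except. Reduced to its shape as a runnable sketch (fetch_once and the provider names are hypothetical stand-ins, not project code):

# Sketch of the retry/failover pattern stream() uses above, reduced to its shape.
import random
import asyncio

async def fetch_once(provider: str) -> str:
    if random.random() < 0.5:  # stand-in for 429s, revoked keys, proxy failures
        raise RuntimeError(f'{provider} failed')
    return f'response from {provider}'

async def fetch_with_failover(providers: list, attempts: int = 5):
    for _ in range(attempts):
        provider = random.choice(providers)  # load balancing: pick a suitable provider at random
        try:
            return await fetch_once(provider)
        except Exception as exc:  # mirrors the broad retry in streaming.py
            print(f'[!] {type(exc)} - {exc}')
            continue
    return None

print(asyncio.run(fetch_with_failover(['provider-a', 'provider-b'])))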