From 7a22c1726fcc34df9d88ae35fea2f3645760b403 Mon Sep 17 00:00:00 2001 From: nsde Date: Mon, 2 Oct 2023 21:09:39 +0200 Subject: [PATCH] implemented key ratelimit checks --- admintools/pruner.py | 3 --- api/db/key_validation.py | 6 ++---- api/load_balancing.py | 11 ----------- api/provider_auth.py | 2 +- api/responder.py | 15 ++++++--------- checks/client.py | 8 ++++---- 6 files changed, 13 insertions(+), 32 deletions(-) diff --git a/admintools/pruner.py b/admintools/pruner.py index f71858d..a4ed7c0 100644 --- a/admintools/pruner.py +++ b/admintools/pruner.py @@ -82,8 +82,5 @@ # # ==================================================================================== -# def prune(): -# # gets all users from - # if __name__ == '__main__': # launch() diff --git a/api/db/key_validation.py b/api/db/key_validation.py index 1a7f85d..f902300 100644 --- a/api/db/key_validation.py +++ b/api/db/key_validation.py @@ -41,7 +41,7 @@ async def key_is_rated(key: str) -> bool: async def cached_key_is_rated(key: str) -> bool: path = os.path.join(os.getcwd(), 'cache', 'rate_limited_keys.json') - with open(path, 'r') as file: + with open(path, 'r', encoding='utf8') as file: keys = json.load(file) return key in keys @@ -49,8 +49,6 @@ async def cached_key_is_rated(key: str) -> bool: async def remove_rated_keys() -> None: """Removes all keys that have been rate limited for more than a day.""" - a_day = 86400 - client = AsyncIOMotorClient(MONGO_URI) collection = client['Liabilities']['rate-limited-keys'] @@ -58,7 +56,7 @@ async def remove_rated_keys() -> None: marked_for_removal = [] for key in keys: - if int(time.time()) - key['timestamp_added'] > a_day: + if int(time.time()) - key['timestamp_added'] > 86400: marked_for_removal.append(key['_id']) query = { diff --git a/api/load_balancing.py b/api/load_balancing.py index 2bb37fe..e32c770 100644 --- a/api/load_balancing.py +++ b/api/load_balancing.py @@ -1,6 +1,5 @@ import random import asyncio -from db.key_validation import cached_key_is_rated import providers @@ -32,16 +31,6 @@ async def balance_chat_request(payload: dict) -> dict: provider = random.choice(providers_available) target = await provider.chat_completion(**payload) - - while True: - key = target.get('provider_auth') - - if not await cached_key_is_rated(key): - break - - else: - target = await provider.chat_completion(**payload) - module_name = await _get_module_name(provider) target['module'] = module_name diff --git a/api/provider_auth.py b/api/provider_auth.py index 72d2cfd..504dc1c 100644 --- a/api/provider_auth.py +++ b/api/provider_auth.py @@ -49,7 +49,7 @@ async def invalidate_key(provider_and_key: str) -> None: with open(f'secret/{provider}.invalid.txt', 'a', encoding='utf8') as f: f.write(key + '\n') - await invalidation_webhook(provider_and_key) + # await invalidation_webhook(provider_and_key) if __name__ == '__main__': asyncio.run(invalidate_key('closed>demo-...')) diff --git a/api/responder.py b/api/responder.py index 3933afe..fc34adf 100644 --- a/api/responder.py +++ b/api/responder.py @@ -14,7 +14,8 @@ import provider_auth import after_request import load_balancing -from helpers import network, chat, errors +from helpers import errors + from db import key_validation load_dotenv() @@ -48,15 +49,10 @@ async def respond( for _ in range(10): # Load balancing: randomly selecting a suitable provider - # If the request is a chat completion, then we need to load balance between chat providers - # If the request is an organic request, then we need to load balance between organic providers try: if is_chat: target_request = await load_balancing.balance_chat_request(payload) else: - - # In this case we are doing a organic request. "organic" means that it's not using a reverse engineered front-end, but rather ClosedAI's API directly - # churchless.tech is an example of an organic provider, because it redirects the request to ClosedAI. target_request = await load_balancing.balance_organic_request({ 'method': incoming_request.method, 'path': path, @@ -93,6 +89,7 @@ async def respond( is_stream = response.content_type == 'text/event-stream' if response.status == 429: + await key_validation.log_rated_key(target_request.get('provider_auth')) continue if response.content_type == 'application/json': @@ -118,7 +115,6 @@ async def respond( await errors.error(500, 'Sorry, this endpoint does not support this method.', data['error']['message']) if 'invalid_api_key' in str(data) or 'account_deactivated' in str(data): - print('[!] invalid api key', target_request.get('provider_auth')) await provider_auth.invalidate_key(target_request.get('provider_auth')) continue @@ -140,13 +136,14 @@ async def respond( async for chunk in response.content.iter_any(): chunk = chunk.decode('utf8').strip() - print(1) yield chunk + '\n\n' break except Exception as exc: - print('[!] exception', exc) + if 'too many requests' in str(exc): + await key_validation.log_rated_key(key) + continue else: diff --git a/checks/client.py b/checks/client.py index 06a24df..f194d23 100644 --- a/checks/client.py +++ b/checks/client.py @@ -208,8 +208,8 @@ async def demo(): else: raise ConnectionError('API Server is not running.') - print('[lightblue]Checking if function calling works...') - print(await test_function_calling()) + # print('[lightblue]Checking if function calling works...') + # print(await test_function_calling()) print('Checking non-streamed chat completions...') print(await test_chat_non_stream_gpt4()) @@ -220,8 +220,8 @@ async def demo(): # print('[lightblue]Checking if image generation works...') # print(await test_image_generation()) - print('Checking the models endpoint...') - print(await test_models()) + # print('Checking the models endpoint...') + # print(await test_models()) except Exception as exc: print('[red]Error: ' + str(exc))