implemented key ratelimit checks

This commit is contained in:
nsde 2023-10-02 21:09:39 +02:00
parent 007050e9fe
commit 7a22c1726f
6 changed files with 13 additions and 32 deletions

View file

@ -82,8 +82,5 @@
# # ==================================================================================== # # ====================================================================================
# def prune():
# # gets all users from
# if __name__ == '__main__': # if __name__ == '__main__':
# launch() # launch()

View file

@ -41,7 +41,7 @@ async def key_is_rated(key: str) -> bool:
async def cached_key_is_rated(key: str) -> bool: async def cached_key_is_rated(key: str) -> bool:
path = os.path.join(os.getcwd(), 'cache', 'rate_limited_keys.json') path = os.path.join(os.getcwd(), 'cache', 'rate_limited_keys.json')
with open(path, 'r') as file: with open(path, 'r', encoding='utf8') as file:
keys = json.load(file) keys = json.load(file)
return key in keys return key in keys
@ -49,8 +49,6 @@ async def cached_key_is_rated(key: str) -> bool:
async def remove_rated_keys() -> None: async def remove_rated_keys() -> None:
"""Removes all keys that have been rate limited for more than a day.""" """Removes all keys that have been rate limited for more than a day."""
a_day = 86400
client = AsyncIOMotorClient(MONGO_URI) client = AsyncIOMotorClient(MONGO_URI)
collection = client['Liabilities']['rate-limited-keys'] collection = client['Liabilities']['rate-limited-keys']
@ -58,7 +56,7 @@ async def remove_rated_keys() -> None:
marked_for_removal = [] marked_for_removal = []
for key in keys: for key in keys:
if int(time.time()) - key['timestamp_added'] > a_day: if int(time.time()) - key['timestamp_added'] > 86400:
marked_for_removal.append(key['_id']) marked_for_removal.append(key['_id'])
query = { query = {

View file

@ -1,6 +1,5 @@
import random import random
import asyncio import asyncio
from db.key_validation import cached_key_is_rated
import providers import providers
@ -32,16 +31,6 @@ async def balance_chat_request(payload: dict) -> dict:
provider = random.choice(providers_available) provider = random.choice(providers_available)
target = await provider.chat_completion(**payload) target = await provider.chat_completion(**payload)
while True:
key = target.get('provider_auth')
if not await cached_key_is_rated(key):
break
else:
target = await provider.chat_completion(**payload)
module_name = await _get_module_name(provider) module_name = await _get_module_name(provider)
target['module'] = module_name target['module'] = module_name

View file

@ -49,7 +49,7 @@ async def invalidate_key(provider_and_key: str) -> None:
with open(f'secret/{provider}.invalid.txt', 'a', encoding='utf8') as f: with open(f'secret/{provider}.invalid.txt', 'a', encoding='utf8') as f:
f.write(key + '\n') f.write(key + '\n')
await invalidation_webhook(provider_and_key) # await invalidation_webhook(provider_and_key)
if __name__ == '__main__': if __name__ == '__main__':
asyncio.run(invalidate_key('closed>demo-...')) asyncio.run(invalidate_key('closed>demo-...'))

View file

@ -14,7 +14,8 @@ import provider_auth
import after_request import after_request
import load_balancing import load_balancing
from helpers import network, chat, errors from helpers import errors
from db import key_validation from db import key_validation
load_dotenv() load_dotenv()
@ -48,15 +49,10 @@ async def respond(
for _ in range(10): for _ in range(10):
# Load balancing: randomly selecting a suitable provider # Load balancing: randomly selecting a suitable provider
# If the request is a chat completion, then we need to load balance between chat providers
# If the request is an organic request, then we need to load balance between organic providers
try: try:
if is_chat: if is_chat:
target_request = await load_balancing.balance_chat_request(payload) target_request = await load_balancing.balance_chat_request(payload)
else: else:
# In this case we are doing a organic request. "organic" means that it's not using a reverse engineered front-end, but rather ClosedAI's API directly
# churchless.tech is an example of an organic provider, because it redirects the request to ClosedAI.
target_request = await load_balancing.balance_organic_request({ target_request = await load_balancing.balance_organic_request({
'method': incoming_request.method, 'method': incoming_request.method,
'path': path, 'path': path,
@ -93,6 +89,7 @@ async def respond(
is_stream = response.content_type == 'text/event-stream' is_stream = response.content_type == 'text/event-stream'
if response.status == 429: if response.status == 429:
await key_validation.log_rated_key(target_request.get('provider_auth'))
continue continue
if response.content_type == 'application/json': if response.content_type == 'application/json':
@ -118,7 +115,6 @@ async def respond(
await errors.error(500, 'Sorry, this endpoint does not support this method.', data['error']['message']) await errors.error(500, 'Sorry, this endpoint does not support this method.', data['error']['message'])
if 'invalid_api_key' in str(data) or 'account_deactivated' in str(data): if 'invalid_api_key' in str(data) or 'account_deactivated' in str(data):
print('[!] invalid api key', target_request.get('provider_auth'))
await provider_auth.invalidate_key(target_request.get('provider_auth')) await provider_auth.invalidate_key(target_request.get('provider_auth'))
continue continue
@ -140,13 +136,14 @@ async def respond(
async for chunk in response.content.iter_any(): async for chunk in response.content.iter_any():
chunk = chunk.decode('utf8').strip() chunk = chunk.decode('utf8').strip()
print(1)
yield chunk + '\n\n' yield chunk + '\n\n'
break break
except Exception as exc: except Exception as exc:
print('[!] exception', exc) if 'too many requests' in str(exc):
await key_validation.log_rated_key(key)
continue continue
else: else:

View file

@ -208,8 +208,8 @@ async def demo():
else: else:
raise ConnectionError('API Server is not running.') raise ConnectionError('API Server is not running.')
print('[lightblue]Checking if function calling works...') # print('[lightblue]Checking if function calling works...')
print(await test_function_calling()) # print(await test_function_calling())
print('Checking non-streamed chat completions...') print('Checking non-streamed chat completions...')
print(await test_chat_non_stream_gpt4()) print(await test_chat_non_stream_gpt4())
@ -220,8 +220,8 @@ async def demo():
# print('[lightblue]Checking if image generation works...') # print('[lightblue]Checking if image generation works...')
# print(await test_image_generation()) # print(await test_image_generation())
print('Checking the models endpoint...') # print('Checking the models endpoint...')
print(await test_models()) # print(await test_models())
except Exception as exc: except Exception as exc:
print('[red]Error: ' + str(exc)) print('[red]Error: ' + str(exc))