From 1e2a596df3d18b6dc8615bed140089f5a19d7002 Mon Sep 17 00:00:00 2001 From: henceiusegentoo Date: Sat, 23 Sep 2023 21:41:48 +0200 Subject: [PATCH] Added key validation by API-key instead of IP Added rate limited keys getting logged in a database --- .gitignore | 5 +- PUSH_TO_PRODUCTION.sh | 2 +- api/after_request.py | 4 +- api/backup_manager/main.py | 4 +- api/cache/rate_limited_keys.json | 1 + api/core.py | 2 + api/db/key_validation.py | 87 ++++++++++++++++++++++++++++++++ api/handler.py | 15 ++++-- api/helpers/network.py | 24 +++------ api/load_balancing.py | 10 ++++ api/main.py | 13 ++++- api/proxies.py | 4 ++ api/responder.py | 38 +++++++++----- checks/client.py | 2 +- 14 files changed, 165 insertions(+), 46 deletions(-) create mode 100644 api/cache/rate_limited_keys.json create mode 100644 api/db/key_validation.py diff --git a/.gitignore b/.gitignore index e16827a..5d3ade6 100644 --- a/.gitignore +++ b/.gitignore @@ -180,6 +180,7 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +.idea/ -backups/ \ No newline at end of file +backups/ +cache/ \ No newline at end of file diff --git a/PUSH_TO_PRODUCTION.sh b/PUSH_TO_PRODUCTION.sh index d15b66c..e0f699c 100755 --- a/PUSH_TO_PRODUCTION.sh +++ b/PUSH_TO_PRODUCTION.sh @@ -4,7 +4,7 @@ # git commit -am "Auto-trigger - Production server started" && git push origin Production # backup database -/usr/local/bin/python /home/nova-api/api/backup_manager/main.py pre_prodpush +# /usr/local/bin/python /home/nova-api/api/backup_manager/main.py pre_prodpush # Kill production server fuser -k 2333/tcp diff --git a/api/after_request.py b/api/after_request.py index 6eec61e..2e34132 100644 --- a/api/after_request.py +++ b/api/after_request.py @@ -1,4 +1,4 @@ -from db import logs, stats, users +from db import logs, stats, users, key_validation from helpers import network async def after_request( @@ -23,6 +23,8 @@ async def after_request( await stats.manager.add_ip_address(ip_address) await stats.manager.add_path(path) await stats.manager.add_target(target_request['url']) + await key_validation.remove_rated_keys() + await key_validation.cache_all_keys() if is_chat: await stats.manager.add_model(model) diff --git a/api/backup_manager/main.py b/api/backup_manager/main.py index 4ecb1cf..d7db852 100644 --- a/api/backup_manager/main.py +++ b/api/backup_manager/main.py @@ -33,7 +33,7 @@ async def make_backup(output_dir: str): os.mkdir(f'{output_dir}/{database}') for collection in databases[database]: - print(f'Making backup for {database}/{collection}') + print(f'Initiated database backup for {database}/{collection}') await make_backup_for_collection(database, collection, output_dir) async def make_backup_for_collection(database, collection, output_dir): @@ -52,4 +52,4 @@ if __name__ == '__main__': exit(1) output_dir = argv[1] - asyncio.run(main(output_dir)) \ No newline at end of file + asyncio.run(main(output_dir)) diff --git a/api/cache/rate_limited_keys.json b/api/cache/rate_limited_keys.json new file mode 100644 index 0000000..0637a08 --- /dev/null +++ b/api/cache/rate_limited_keys.json @@ -0,0 +1 @@ +[] \ No newline at end of file diff --git a/api/core.py b/api/core.py index c88b732..f802e20 100644 --- a/api/core.py +++ b/api/core.py @@ -194,4 +194,6 @@ async def get_finances(incoming_request: fastapi.Request): amount_in_usd = await get_crypto_price(currency) * amount transactions[table][transactions[table].index(transaction)]['amount_usd'] = amount_in_usd + transactions['timestamp'] = time.time() + return transactions diff --git a/api/db/key_validation.py b/api/db/key_validation.py new file mode 100644 index 0000000..c1fcebf --- /dev/null +++ b/api/db/key_validation.py @@ -0,0 +1,87 @@ +import os +import time +import asyncio +import json + +from dotenv import load_dotenv +from motor.motor_asyncio import AsyncIOMotorClient + +load_dotenv() +MONGO_URI = os.getenv("MONGO_URI") + + +async def log_rated_key(key: str) -> None: + """Logs a key that has been rate limited to the database.""" + + client = AsyncIOMotorClient(MONGO_URI) + + scheme = { + "key": key, + "timestamp_added": int(time.time()) + } + + collection = client['Liabilities']['rate-limited-keys'] + await collection.insert_one(scheme) + + +async def key_is_rated(key: str) -> bool: + """Checks if a key is rate limited.""" + + client = AsyncIOMotorClient(MONGO_URI) + collection = client['Liabilities']['rate-limited-keys'] + + query = { + "key": key + } + + result = await collection.find_one(query) + return result is not None + + +async def cached_key_is_rated(key: str) -> bool: + path = os.path.join(os.getcwd(), "cache", "rate_limited_keys.json") + + with open(path, "r") as file: + keys = json.load(file) + + return key in keys + +async def remove_rated_keys() -> None: + """Removes all keys that have been rate limited for more than a day.""" + + a_day = 86400 + + client = AsyncIOMotorClient(MONGO_URI) + collection = client['Liabilities']['rate-limited-keys'] + + keys = await collection.find().to_list(length=None) + + marked_for_removal = [] + for key in keys: + if int(time.time()) - key['timestamp_added'] > a_day: + marked_for_removal.append(key['_id']) + + query = { + "_id": { + "$in": marked_for_removal + } + } + + await collection.delete_many(query) + + +async def cache_all_keys() -> None: + """Clones all keys from the database to the cache.""" + + client = AsyncIOMotorClient(MONGO_URI) + collection = client['Liabilities']['rate-limited-keys'] + + keys = await collection.find().to_list(length=None) + keys = [key['key'] for key in keys] + + path = os.path.join(os.getcwd(), "cache", "rate_limited_keys.json") + with open(path, "w") as file: + json.dump(keys, file) + +if __name__ == "__main__": + asyncio.run(remove_rated_keys()) diff --git a/api/handler.py b/api/handler.py index e6d5993..dac026a 100644 --- a/api/handler.py +++ b/api/handler.py @@ -28,15 +28,15 @@ with open('config/config.yml', encoding='utf8') as f: moderation_debug_key_key = os.getenv('MODERATION_DEBUG_KEY') async def handle(incoming_request: fastapi.Request): - """ - ### Transfer a streaming response + """Transfer a streaming response Takes the request from the incoming request to the target endpoint. Checks method, token amount, auth and cost along with if request is NSFW. """ - path = incoming_request.url.path.replace('v1/v1', 'v1').replace('//', '/') + + path = incoming_request.url.path + path = path.replace('/v1/v1', '/v1') ip_address = await network.get_ip(incoming_request) - print(f'[bold green]>{ip_address}[/bold green]') if '/models' in path: return fastapi.responses.JSONResponse(content=models_list) @@ -65,12 +65,17 @@ async def handle(incoming_request: fastapi.Request): return await errors.error(418, 'Invalid or inactive NovaAI API key!', 'Create a new NovaOSS API key or reactivate your account.') if user.get('auth', {}).get('discord'): - print(f'[bold green]>Discord[/bold green] {user["auth"]["discord"]}') + print(f'[bold green]>{ip_address} ({user["auth"]["discord"]})[/bold green]') ban_reason = user['status']['ban_reason'] if ban_reason: return await errors.error(403, f'Your NovaAI account has been banned. Reason: \'{ban_reason}\'.', 'Contact the staff for an appeal.') + # Checking for enterprise status + enterprise_keys = os.environ.get('NO_RATELIMIT_KEYS') + if '/enterprise' in path and user.get('api_key') not in enterprise_keys: + return await errors.error(403, 'Enterprise API is not available.', 'Contact the staff for an upgrade.') + if 'account/credits' in path: return fastapi.responses.JSONResponse({'credits': user['credits']}) diff --git a/api/helpers/network.py b/api/helpers/network.py index e4b4849..71fab28 100644 --- a/api/helpers/network.py +++ b/api/helpers/network.py @@ -1,7 +1,7 @@ import os import time - from dotenv import load_dotenv +from slowapi.util import get_remote_address load_dotenv() @@ -24,22 +24,10 @@ async def get_ip(request) -> str: def get_ratelimit_key(request) -> str: """Get the IP address of the incoming request.""" + custom = os.environ('NO_RATELIMIT_IPS') + ip = get_remote_address(request) - xff = None - if request.headers.get('x-forwarded-for'): - xff, *_ = request.headers['x-forwarded-for'].split(', ') + if ip in custom: + return f'enterprise_{ip}' - possible_ips = [ - xff, - request.headers.get('cf-connecting-ip'), - request.client.host - ] - - detected_ip = next((i for i in possible_ips if i), None) - - for whitelisted_ip in os.getenv('NO_RATELIMIT_IPS', '').split(): - if whitelisted_ip in detected_ip: - custom_key = f'whitelisted-{time.time()}' - return custom_key - - return detected_ip + return ip \ No newline at end of file diff --git a/api/load_balancing.py b/api/load_balancing.py index 0e18e02..2bb37fe 100644 --- a/api/load_balancing.py +++ b/api/load_balancing.py @@ -1,5 +1,6 @@ import random import asyncio +from db.key_validation import cached_key_is_rated import providers @@ -32,6 +33,15 @@ async def balance_chat_request(payload: dict) -> dict: provider = random.choice(providers_available) target = await provider.chat_completion(**payload) + while True: + key = target.get('provider_auth') + + if not await cached_key_is_rated(key): + break + + else: + target = await provider.chat_completion(**payload) + module_name = await _get_module_name(provider) target['module'] = module_name diff --git a/api/main.py b/api/main.py index 8290360..a3ed7e3 100644 --- a/api/main.py +++ b/api/main.py @@ -11,6 +11,7 @@ from bson.objectid import ObjectId from slowapi.errors import RateLimitExceeded from slowapi.middleware import SlowAPIMiddleware from fastapi.middleware.cors import CORSMiddleware +from slowapi.util import get_remote_address from slowapi import Limiter, _rate_limit_exceeded_handler from helpers import network @@ -34,11 +35,13 @@ app.include_router(core.router) limiter = Limiter( swallow_errors=True, - key_func=network.get_ratelimit_key, default_limits=[ + key_func=get_remote_address, + default_limits=[ '2/second', '20/minute', '300/hour' ]) + app.state.limiter = limiter app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) app.add_middleware(SlowAPIMiddleware) @@ -66,4 +69,10 @@ async def root(): @app.route('/v1/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH']) async def v1_handler(request: fastapi.Request): res = await handler.handle(incoming_request=request) - return res \ No newline at end of file + return res + +@limiter.limit('100/second') +@app.route('/enterprise/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH']) +async def enterprise_handler(request: fastapi.Request): + res = await handler.handle(incoming_request=request) + return res diff --git a/api/proxies.py b/api/proxies.py index 0dd84b2..bca633f 100644 --- a/api/proxies.py +++ b/api/proxies.py @@ -125,3 +125,7 @@ def get_proxy() -> Proxy: username=os.getenv('PROXY_USER'), password=os.getenv('PROXY_PASS') ) + +if __name__ == '__main__': + print(get_proxy().url) + print(get_proxy().connector) diff --git a/api/responder.py b/api/responder.py index 047083b..f0eb6dc 100644 --- a/api/responder.py +++ b/api/responder.py @@ -2,9 +2,7 @@ import os import json -import yaml -import dhooks -import asyncio +import random import aiohttp import starlette @@ -17,6 +15,7 @@ import after_request import load_balancing from helpers import network, chat, errors +from db import key_validation load_dotenv() @@ -44,11 +43,10 @@ async def respond( json_response = {} headers = { - 'Content-Type': 'application/json', - 'User-Agent': 'axios/0.21.1', + 'Content-Type': 'application/json' } - for _ in range(10): + for i in range(20): # Load balancing: randomly selecting a suitable provider # If the request is a chat completion, then we need to load balance between chat providers # If the request is an organic request, then we need to load balance between organic providers @@ -67,10 +65,7 @@ async def respond( 'cookies': incoming_request.cookies }) except ValueError as exc: - if model in ['gpt-3.5-turbo', 'gpt-4', 'gpt-4-32k']: - webhook = dhooks.Webhook(os.environ['DISCORD_WEBHOOK__API_ISSUE']) - webhook.send(content=f'API Issue: **`{exc}`**\nhttps://i.imgflip.com/7uv122.jpg') - yield await errors.yield_error(500, 'Sorry, the API has no working keys anymore.', 'The admins have been messaged automatically.') + yield await errors.yield_error(500, f'Sorry, the API has no active API keys for {model}.', 'Please use a different model.') return target_request['headers'].update(target_request.get('headers', {})) @@ -91,7 +86,7 @@ async def respond( cookies=target_request.get('cookies'), ssl=False, timeout=aiohttp.ClientTimeout( - connect=0.3, + connect=0.5, total=float(os.getenv('TRANSFER_TIMEOUT', '500')) ), ) as response: @@ -103,6 +98,21 @@ async def respond( if response.content_type == 'application/json': data = await response.json() + error = data.get('error') + match error: + case None: + pass + + case _: + match error.get('code'): + case "insufficient_quota": + key = target_request.get('provider_auth') + await key_validation.log_rated_key(key) + continue + + case _: + pass + if 'method_not_supported' in str(data): await errors.error(500, 'Sorry, this endpoint does not support this method.', data['error']['message']) @@ -119,6 +129,7 @@ async def respond( response.raise_for_status() except Exception as exc: if 'Too Many Requests' in str(exc): + print('[!] too many requests') continue async for chunk in response.content.iter_any(): @@ -134,14 +145,13 @@ async def respond( print('[!] chat response is empty') continue else: - yield await errors.yield_error(500, 'Sorry, the provider is not responding. We\'re possibly getting rate-limited.', 'Please try again later.') + print('[!] no response') + yield await errors.yield_error(500, 'Sorry, the provider is not responding.', 'Please try again later.') return if (not is_stream) and json_response: yield json.dumps(json_response) - print(f'[+] {path} -> {model or ""}') - await after_request.after_request( incoming_request=incoming_request, target_request=target_request, diff --git a/checks/client.py b/checks/client.py index de321f6..06a24df 100644 --- a/checks/client.py +++ b/checks/client.py @@ -164,7 +164,7 @@ async def test_function_calling(): url=f'{api_endpoint}/chat/completions', headers=HEADERS, json=json_data, - timeout=10, + timeout=15, ) response.raise_for_status()