diff --git a/.gitignore b/.gitignore index 1084296..d806065 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,7 @@ last_update.txt *.log.json /logs /log +*.log .log *.log.* @@ -188,4 +189,3 @@ cython_debug/ backups/ cache/ -api/cache/rate_limited_keys.json diff --git a/PUSH_TO_PRODUCTION.sh b/PUSH_TO_PRODUCTION.sh index 249faaf..84d148e 100755 --- a/PUSH_TO_PRODUCTION.sh +++ b/PUSH_TO_PRODUCTION.sh @@ -22,4 +22,4 @@ cp env/.prod.env /home/nova-prod/.env cd /home/nova-prod # Start screen -screen -L -S nova-api python run prod && sleep 5 +screen -L -Logfile .z.log -S nova-api python run prod && sleep 5 diff --git a/api/after_request.py b/api/after_request.py index 2e34132..6eec61e 100644 --- a/api/after_request.py +++ b/api/after_request.py @@ -1,4 +1,4 @@ -from db import logs, stats, users, key_validation +from db import logs, stats, users from helpers import network async def after_request( @@ -23,8 +23,6 @@ async def after_request( await stats.manager.add_ip_address(ip_address) await stats.manager.add_path(path) await stats.manager.add_target(target_request['url']) - await key_validation.remove_rated_keys() - await key_validation.cache_all_keys() if is_chat: await stats.manager.add_model(model) diff --git a/api/db/key_validation.py b/api/db/key_validation.py deleted file mode 100644 index f902300..0000000 --- a/api/db/key_validation.py +++ /dev/null @@ -1,84 +0,0 @@ -import os -import time -import asyncio -import json - -from dotenv import load_dotenv -from motor.motor_asyncio import AsyncIOMotorClient - -load_dotenv() - -MONGO_URI = os.getenv('MONGO_URI') - -async def log_rated_key(key: str) -> None: - """Logs a key that has been rate limited to the database.""" - - client = AsyncIOMotorClient(MONGO_URI) - - scheme = { - 'key': key, - 'timestamp_added': int(time.time()) - } - - collection = client['Liabilities']['rate-limited-keys'] - await collection.insert_one(scheme) - - -async def key_is_rated(key: str) -> bool: - """Checks if a key is rate limited.""" - 
- client = AsyncIOMotorClient(MONGO_URI) - collection = client['Liabilities']['rate-limited-keys'] - - query = { - 'key': key - } - - result = await collection.find_one(query) - return result is not None - - -async def cached_key_is_rated(key: str) -> bool: - path = os.path.join(os.getcwd(), 'cache', 'rate_limited_keys.json') - - with open(path, 'r', encoding='utf8') as file: - keys = json.load(file) - - return key in keys - -async def remove_rated_keys() -> None: - """Removes all keys that have been rate limited for more than a day.""" - - client = AsyncIOMotorClient(MONGO_URI) - collection = client['Liabilities']['rate-limited-keys'] - - keys = await collection.find().to_list(length=None) - - marked_for_removal = [] - for key in keys: - if int(time.time()) - key['timestamp_added'] > 86400: - marked_for_removal.append(key['_id']) - - query = { - '_id': { - '$in': marked_for_removal - } - } - - await collection.delete_many(query) - -async def cache_all_keys() -> None: - """Clones all keys from the database to the cache.""" - - client = AsyncIOMotorClient(MONGO_URI) - collection = client['Liabilities']['rate-limited-keys'] - - keys = await collection.find().to_list(length=None) - keys = [key['key'] for key in keys] - - path = os.path.join(os.getcwd(), 'cache', 'rate_limited_keys.json') - with open(path, 'w') as file: - json.dump(keys, file) - -if __name__ == "__main__": - asyncio.run(remove_rated_keys()) diff --git a/api/db/providerkeys.py b/api/db/providerkeys.py new file mode 100644 index 0000000..84930de --- /dev/null +++ b/api/db/providerkeys.py @@ -0,0 +1,90 @@ +import os +import time +import asyncio + +from dotenv import load_dotenv +from motor.motor_asyncio import AsyncIOMotorClient + +load_dotenv() + +class KeyManager: + def __init__(self): + self.conn = AsyncIOMotorClient(os.environ['MONGO_URI']) + + async def _get_collection(self, collection_name: str): + return self.conn[os.getenv('MONGO_NAME', 'nova-test')][collection_name] + + async def add_key(self, 
provider: str, key: str, source: str='?'): + db = await self._get_collection('providerkeys') + await db.insert_one({ + 'provider': provider, + 'key': key, + 'rate_limited_since': None, + 'inactive_reason': None, + 'source': source, + }) + + async def get_key(self, provider: str): + db = await self._get_collection('providerkeys') + key = await db.find_one({ + 'provider': provider, + 'inactive_reason': None, + '$or': [ + {'rate_limited_since': None}, + {'rate_limited_since': {'$lte': time.time() - 86400}} + ] + }) + + if key is None: + raise ValueError('No keys available for this provider!') + + return key['key'] + + async def rate_limit_key(self, provider: str, key: str): + db = await self._get_collection('providerkeys') + await db.update_one({'provider': provider, 'key': key}, { + '$set': { + 'rate_limited_since': time.time() + } + }) + + async def deactivate_key(self, provider: str, key: str, reason: str): + db = await self._get_collection('providerkeys') + await db.update_one({'provider': provider, 'key': key}, { + '$set': { + 'inactive_reason': reason + } + }) + + async def import_all(self): + db = await self._get_collection('providerkeys') + num = 0 + + for filename in os.listdir('api/secret'): + if filename.endswith('.txt'): + with open(f'api/secret/{filename}') as f: + for line in f.readlines(): + if not line.strip(): + continue + + await db.insert_one({ + 'provider': filename.split('.')[0], + 'key': line.strip(), + 'rate_limited_since': None, + 'inactive_reason': None, + 'source': 'import' + }) + num += 1 + + print(f'[+] Imported {num} keys') + + print('[+] Done importing keys!') + + async def delete_empty_keys(self): + db = await self._get_collection('providerkeys') + await db.delete_many({'key': ''}) + +manager = KeyManager() + +if __name__ == '__main__': + asyncio.run(manager.delete_empty_keys()) diff --git a/api/handler.py b/api/handler.py index dac026a..d13b893 100644 --- a/api/handler.py +++ b/api/handler.py @@ -64,9 +64,6 @@ async def
handle(incoming_request: fastapi.Request): if not user or not user['status']['active']: return await errors.error(418, 'Invalid or inactive NovaAI API key!', 'Create a new NovaOSS API key or reactivate your account.') - if user.get('auth', {}).get('discord'): - print(f'[bold green]>{ip_address} ({user["auth"]["discord"]})[/bold green]') - ban_reason = user['status']['ban_reason'] if ban_reason: return await errors.error(403, f'Your NovaAI account has been banned. Reason: \'{ban_reason}\'.', 'Contact the staff for an appeal.') diff --git a/api/helpers/chat.py b/api/helpers/chat.py index 6bca378..cddb06a 100644 --- a/api/helpers/chat.py +++ b/api/helpers/chat.py @@ -1,7 +1,6 @@ import json import string import random -import asyncio from rich import print diff --git a/api/provider_auth.py b/api/provider_auth.py deleted file mode 100644 index 504dc1c..0000000 --- a/api/provider_auth.py +++ /dev/null @@ -1,55 +0,0 @@ -"""This module contains functions for authenticating with providers.""" - -import os -import asyncio - -from dotenv import load_dotenv -from dhooks import Webhook, Embed - -load_dotenv() - -async def invalidation_webhook(provider_and_key: str) -> None: - """Runs when a new user is created.""" - - dhook = Webhook(os.environ['DISCORD_WEBHOOK__API_ISSUE']) - - embed = Embed( - description='Key Invalidated', - color=0xffee90, - ) - - embed.add_field(name='Provider', value=provider_and_key.split('>')[0]) - embed.add_field(name='Key (censored)', value=f'||{provider_and_key.split(">")[1][:10]}...||', inline=False) - - dhook.send(embed=embed) - -async def invalidate_key(provider_and_key: str) -> None: - """ - - Invalidates a key stored in the secret/ folder by storing it in the associated .invalid.txt file. - The schmea in which should be passed is: - , e.g. - closed4>cd-... 
- - """ - - if not provider_and_key: - return - - provider = provider_and_key.split('>')[0] - provider_file = f'secret/{provider}.txt' - key = provider_and_key.split('>')[1] - - with open(provider_file, encoding='utf8') as f_in: - text = f_in.read() - - with open(provider_file, 'w', encoding='utf8') as f_out: - f_out.write(text.replace(key, '')) - - with open(f'secret/{provider}.invalid.txt', 'a', encoding='utf8') as f: - f.write(key + '\n') - - # await invalidation_webhook(provider_and_key) - -if __name__ == '__main__': - asyncio.run(invalidate_key('closed>demo-...')) diff --git a/api/responder.py b/api/responder.py index fc34adf..657a223 100644 --- a/api/responder.py +++ b/api/responder.py @@ -2,24 +2,27 @@ import os import json -import random +import logging import aiohttp +import asyncio import starlette from rich import print from dotenv import load_dotenv import proxies -import provider_auth import after_request import load_balancing from helpers import errors - -from db import key_validation +from db import providerkeys load_dotenv() +CRITICAL_API_ERRORS = ['invalid_api_key', 'account_deactivated'] + +keymanager = providerkeys.manager + async def respond( path: str='/v1/chat/completions', user: dict=None, @@ -41,13 +44,13 @@ async def respond( is_chat = True model = payload['model'] - json_response = {} + server_json_response = {} headers = { 'Content-Type': 'application/json' } - for _ in range(10): + for _ in range(20): # Load balancing: randomly selecting a suitable provider try: if is_chat: @@ -60,17 +63,21 @@ async def respond( 'headers': headers, 'cookies': incoming_request.cookies }) - except ValueError as exc: + except ValueError: yield await errors.yield_error(500, f'Sorry, the API has no active API keys for {model}.', 'Please use a different model.') return + provider_auth = target_request.get('provider_auth') + + if provider_auth: + provider_name = provider_auth.split('>')[0] + provider_key = provider_auth.split('>')[1] + 
target_request['headers'].update(target_request.get('headers', {})) if target_request['method'] == 'GET' and not payload: target_request['payload'] = None - # We haven't done any requests as of right now, everything until now was just preparation - # Here, we process the request async with aiohttp.ClientSession(connector=proxies.get_proxy().connector) as session: try: async with session.request( @@ -89,43 +96,30 @@ async def respond( is_stream = response.content_type == 'text/event-stream' if response.status == 429: - await key_validation.log_rated_key(target_request.get('provider_auth')) + await keymanager.rate_limit_key(provider_name, provider_key) continue if response.content_type == 'application/json': - data = await response.json() + client_json_response = await response.json() - error = data.get('error') - match error: - case None: - pass - - case _: - key = target_request.get('provider_auth') - - match error.get('code'): - case 'invalid_api_key': - await key_validation.log_rated_key(key) - print('[!] invalid key', key) - - case _: - print('[!] unknown error with key: ', key, error) - - if 'method_not_supported' in str(data): + if 'method_not_supported' in str(client_json_response): - await errors.error(500, 'Sorry, this endpoint does not support this method.', data['error']['message']) + await errors.error(500, 'Sorry, this endpoint does not support this method.', client_json_response['error']['message']) - if 'invalid_api_key' in str(data) or 'account_deactivated' in str(data): - await provider_auth.invalidate_key(target_request.get('provider_auth')) + critical_error = False + for error in CRITICAL_API_ERRORS: + if error in str(client_json_response): + await keymanager.deactivate_key(provider_name, provider_key, error) + critical_error = True + + if critical_error: continue if response.ok: - json_response = data + server_json_response = client_json_response else: - print('[!] error', data) continue - if is_stream: try: response.raise_for_status() @@ -141,8 +135,10 @@ async def respond( break except Exception as exc: + print('[!]
exception', exc) if 'too many requests' in str(exc): - await key_validation.log_rated_key(key) + #!TODO + pass continue @@ -150,16 +146,18 @@ async def respond( yield await errors.yield_error(500, 'Sorry, our API seems to have issues connecting to our provider(s).', 'This most likely isn\'t your fault. Please try again later.') return - if (not is_stream) and json_response: - yield json.dumps(json_response) + if (not is_stream) and server_json_response: + yield json.dumps(server_json_response) - await after_request.after_request( - incoming_request=incoming_request, - target_request=target_request, - user=user, - credits_cost=credits_cost, - input_tokens=input_tokens, - path=path, - is_chat=is_chat, - model=model, + asyncio.create_task( + after_request.after_request( + incoming_request=incoming_request, + target_request=target_request, + user=user, + credits_cost=credits_cost, + input_tokens=input_tokens, + path=path, + is_chat=is_chat, + model=model, + ) )