yoooo everything works now 100% 🔥

This commit is contained in:
nsde 2023-10-04 23:24:55 +02:00
parent 7a22c1726f
commit 01aa41b6b1
9 changed files with 136 additions and 193 deletions

2
.gitignore vendored
View file

@ -18,6 +18,7 @@ last_update.txt
*.log.json *.log.json
/logs /logs
/log /log
*.log
.log .log
*.log.* *.log.*
@ -188,4 +189,3 @@ cython_debug/
backups/ backups/
cache/ cache/
api/cache/rate_limited_keys.json

View file

@ -22,4 +22,4 @@ cp env/.prod.env /home/nova-prod/.env
cd /home/nova-prod cd /home/nova-prod
# Start screen # Start screen
screen -L -S nova-api python run prod && sleep 5 screen -L -Logfile .z.log -S nova-api python run prod && sleep 5

View file

@ -1,4 +1,4 @@
from db import logs, stats, users, key_validation from db import logs, stats, users
from helpers import network from helpers import network
async def after_request( async def after_request(
@ -23,8 +23,6 @@ async def after_request(
await stats.manager.add_ip_address(ip_address) await stats.manager.add_ip_address(ip_address)
await stats.manager.add_path(path) await stats.manager.add_path(path)
await stats.manager.add_target(target_request['url']) await stats.manager.add_target(target_request['url'])
await key_validation.remove_rated_keys()
await key_validation.cache_all_keys()
if is_chat: if is_chat:
await stats.manager.add_model(model) await stats.manager.add_model(model)

View file

@ -1,84 +0,0 @@
import os
import time
import asyncio
import json
from dotenv import load_dotenv
from motor.motor_asyncio import AsyncIOMotorClient
load_dotenv()
MONGO_URI = os.getenv('MONGO_URI')
async def log_rated_key(key: str) -> None:
"""Logs a key that has been rate limited to the database."""
client = AsyncIOMotorClient(MONGO_URI)
scheme = {
'key': key,
'timestamp_added': int(time.time())
}
collection = client['Liabilities']['rate-limited-keys']
await collection.insert_one(scheme)
async def key_is_rated(key: str) -> bool:
"""Checks if a key is rate limited."""
client = AsyncIOMotorClient(MONGO_URI)
collection = client['Liabilities']['rate-limited-keys']
query = {
'key': key
}
result = await collection.find_one(query)
return result is not None
async def cached_key_is_rated(key: str) -> bool:
path = os.path.join(os.getcwd(), 'cache', 'rate_limited_keys.json')
with open(path, 'r', encoding='utf8') as file:
keys = json.load(file)
return key in keys
async def remove_rated_keys() -> None:
"""Removes all keys that have been rate limited for more than a day."""
client = AsyncIOMotorClient(MONGO_URI)
collection = client['Liabilities']['rate-limited-keys']
keys = await collection.find().to_list(length=None)
marked_for_removal = []
for key in keys:
if int(time.time()) - key['timestamp_added'] > 86400:
marked_for_removal.append(key['_id'])
query = {
'_id': {
'$in': marked_for_removal
}
}
await collection.delete_many(query)
async def cache_all_keys() -> None:
"""Clones all keys from the database to the cache."""
client = AsyncIOMotorClient(MONGO_URI)
collection = client['Liabilities']['rate-limited-keys']
keys = await collection.find().to_list(length=None)
keys = [key['key'] for key in keys]
path = os.path.join(os.getcwd(), 'cache', 'rate_limited_keys.json')
with open(path, 'w') as file:
json.dump(keys, file)
if __name__ == "__main__":
asyncio.run(remove_rated_keys())

90
api/db/providerkeys.py Normal file
View file

@ -0,0 +1,90 @@
import os
import time
import asyncio
from dotenv import load_dotenv
from motor.motor_asyncio import AsyncIOMotorClient
load_dotenv()
class KeyManager:
def __init__(self):
self.conn = AsyncIOMotorClient(os.environ['MONGO_URI'])
async def _get_collection(self, collection_name: str):
return self.conn[os.getenv('MONGO_NAME', 'nova-test')][collection_name]
async def add_key(self, provider: str, key: str, source: str='?'):
db = await self._get_collection('providerkeys')
await db.insert_one({
'provider': provider,
'key': key,
'rate_limited_since': None,
'inactive_reason': None,
'source': source,
})
async def get_key(self, provider: str):
db = await self._get_collection('providerkeys')
key = await db.find_one({
'provider': provider,
'inactive_reason': None,
'$or': [
{'rate_limited_since': None},
{'rate_limited_since': {'$lte': time.time() - 86400}}
]
})
if key is None:
return ValueError('No keys available for this provider!')
return key['key']
async def rate_limit_key(self, provider: str, key: str):
db = await self._get_collection('providerkeys')
await db.update_one({'provider': provider, 'key': key}, {
'$set': {
'rate_limited_since': time.time()
}
})
async def deactivate_key(self, provider: str, key: str, reason: str):
db = await self._get_collection('providerkeys')
await db.update_one({'provider': provider, 'key': key}, {
'$set': {
'inactive_reason': reason
}
})
async def import_all(self):
db = await self._get_collection('providerkeys')
num = 0
for filename in os.listdir('api/secret'):
if filename.endswith('.txt'):
with open(f'api/secret/{filename}') as f:
for line in f.readlines():
if not line.strip():
continue
await db.insert_one({
'provider': filename.split('.')[0],
'key': line.strip(),
'rate_limited_since': None,
'inactive_reason': None,
'source': 'import'
})
num += 1
print(f'[+] Imported {num} keys')
print('[+] Done importing keys!')
async def delete_empty_keys(self):
db = await self._get_collection('providerkeys')
await db.delete_many({'key': ''})
manager = KeyManager()
if __name__ == '__main__':
asyncio.run(manager.delete_empty_keys())

View file

@ -64,9 +64,6 @@ async def handle(incoming_request: fastapi.Request):
if not user or not user['status']['active']: if not user or not user['status']['active']:
return await errors.error(418, 'Invalid or inactive NovaAI API key!', 'Create a new NovaOSS API key or reactivate your account.') return await errors.error(418, 'Invalid or inactive NovaAI API key!', 'Create a new NovaOSS API key or reactivate your account.')
if user.get('auth', {}).get('discord'):
print(f'[bold green]>{ip_address} ({user["auth"]["discord"]})[/bold green]')
ban_reason = user['status']['ban_reason'] ban_reason = user['status']['ban_reason']
if ban_reason: if ban_reason:
return await errors.error(403, f'Your NovaAI account has been banned. Reason: \'{ban_reason}\'.', 'Contact the staff for an appeal.') return await errors.error(403, f'Your NovaAI account has been banned. Reason: \'{ban_reason}\'.', 'Contact the staff for an appeal.')

View file

@ -1,7 +1,6 @@
import json import json
import string import string
import random import random
import asyncio
from rich import print from rich import print

View file

@ -1,55 +0,0 @@
"""This module contains functions for authenticating with providers."""
import os
import asyncio
from dotenv import load_dotenv
from dhooks import Webhook, Embed
load_dotenv()
async def invalidation_webhook(provider_and_key: str) -> None:
"""Runs when a new user is created."""
dhook = Webhook(os.environ['DISCORD_WEBHOOK__API_ISSUE'])
embed = Embed(
description='Key Invalidated',
color=0xffee90,
)
embed.add_field(name='Provider', value=provider_and_key.split('>')[0])
embed.add_field(name='Key (censored)', value=f'||{provider_and_key.split(">")[1][:10]}...||', inline=False)
dhook.send(embed=embed)
async def invalidate_key(provider_and_key: str) -> None:
"""
Invalidates a key stored in the secret/ folder by storing it in the associated .invalid.txt file.
The schmea in which <provider_and_key> should be passed is:
<provider_name><key>, e.g.
closed4>cd-...
"""
if not provider_and_key:
return
provider = provider_and_key.split('>')[0]
provider_file = f'secret/{provider}.txt'
key = provider_and_key.split('>')[1]
with open(provider_file, encoding='utf8') as f_in:
text = f_in.read()
with open(provider_file, 'w', encoding='utf8') as f_out:
f_out.write(text.replace(key, ''))
with open(f'secret/{provider}.invalid.txt', 'a', encoding='utf8') as f:
f.write(key + '\n')
# await invalidation_webhook(provider_and_key)
if __name__ == '__main__':
asyncio.run(invalidate_key('closed>demo-...'))

View file

@ -2,24 +2,27 @@
import os import os
import json import json
import random import logging
import aiohttp import aiohttp
import asyncio
import starlette import starlette
from rich import print from rich import print
from dotenv import load_dotenv from dotenv import load_dotenv
import proxies import proxies
import provider_auth
import after_request import after_request
import load_balancing import load_balancing
from helpers import errors from helpers import errors
from db import providerkeys
from db import key_validation
load_dotenv() load_dotenv()
CRITICAL_API_ERRORS = ['invalid_api_key', 'account_deactivated']
keymanager = providerkeys.manager
async def respond( async def respond(
path: str='/v1/chat/completions', path: str='/v1/chat/completions',
user: dict=None, user: dict=None,
@ -41,13 +44,13 @@ async def respond(
is_chat = True is_chat = True
model = payload['model'] model = payload['model']
json_response = {} server_json_response = {}
headers = { headers = {
'Content-Type': 'application/json' 'Content-Type': 'application/json'
} }
for _ in range(10): for _ in range(20):
# Load balancing: randomly selecting a suitable provider # Load balancing: randomly selecting a suitable provider
try: try:
if is_chat: if is_chat:
@ -60,17 +63,21 @@ async def respond(
'headers': headers, 'headers': headers,
'cookies': incoming_request.cookies 'cookies': incoming_request.cookies
}) })
except ValueError as exc: except ValueError:
yield await errors.yield_error(500, f'Sorry, the API has no active API keys for {model}.', 'Please use a different model.') yield await errors.yield_error(500, f'Sorry, the API has no active API keys for {model}.', 'Please use a different model.')
return return
provider_auth = target_request.get('provider_auth')
if provider_auth:
provider_name = provider_auth.split('>')[0]
provider_key = provider_auth.split('>')[1]
target_request['headers'].update(target_request.get('headers', {})) target_request['headers'].update(target_request.get('headers', {}))
if target_request['method'] == 'GET' and not payload: if target_request['method'] == 'GET' and not payload:
target_request['payload'] = None target_request['payload'] = None
# We haven't done any requests as of right now, everything until now was just preparation
# Here, we process the request
async with aiohttp.ClientSession(connector=proxies.get_proxy().connector) as session: async with aiohttp.ClientSession(connector=proxies.get_proxy().connector) as session:
try: try:
async with session.request( async with session.request(
@ -89,43 +96,30 @@ async def respond(
is_stream = response.content_type == 'text/event-stream' is_stream = response.content_type == 'text/event-stream'
if response.status == 429: if response.status == 429:
await key_validation.log_rated_key(target_request.get('provider_auth')) await keymanager.rate_limit_key(provider_name, provider_key)
continue continue
if response.content_type == 'application/json': if response.content_type == 'application/json':
data = await response.json() client_json_response = await response.json()
error = data.get('error') if 'method_not_supported' in str(client_json_response):
match error:
case None:
pass
case _:
key = target_request.get('provider_auth')
match error.get('code'):
case 'invalid_api_key':
await key_validation.log_rated_key(key)
print('[!] invalid key', key)
case _:
print('[!] unknown error with key: ', key, error)
if 'method_not_supported' in str(data):
await errors.error(500, 'Sorry, this endpoint does not support this method.', data['error']['message']) await errors.error(500, 'Sorry, this endpoint does not support this method.', data['error']['message'])
if 'invalid_api_key' in str(data) or 'account_deactivated' in str(data): critical_error = False
await provider_auth.invalidate_key(target_request.get('provider_auth')) for error in CRITICAL_API_ERRORS:
if error in str(client_json_response):
await keymanager.deactivate_key(provider_name, provider_key, error)
critical_error = True
if critical_error:
continue continue
if response.ok: if response.ok:
json_response = data server_json_response = client_json_response
else: else:
print('[!] error', data)
continue continue
if is_stream: if is_stream:
try: try:
response.raise_for_status() response.raise_for_status()
@ -141,8 +135,10 @@ async def respond(
break break
except Exception as exc: except Exception as exc:
print('[!] exception', exc)
if 'too many requests' in str(exc): if 'too many requests' in str(exc):
await key_validation.log_rated_key(key) #!TODO
pass
continue continue
@ -150,10 +146,11 @@ async def respond(
yield await errors.yield_error(500, 'Sorry, our API seems to have issues connecting to our provider(s).', 'This most likely isn\'t your fault. Please try again later.') yield await errors.yield_error(500, 'Sorry, our API seems to have issues connecting to our provider(s).', 'This most likely isn\'t your fault. Please try again later.')
return return
if (not is_stream) and json_response: if (not is_stream) and server_json_response:
yield json.dumps(json_response) yield json.dumps(server_json_response)
await after_request.after_request( asyncio.create_task(
after_request.after_request(
incoming_request=incoming_request, incoming_request=incoming_request,
target_request=target_request, target_request=target_request,
user=user, user=user,
@ -163,3 +160,4 @@ async def respond(
is_chat=is_chat, is_chat=is_chat,
model=model, model=model,
) )
)