Added key validation by API-key instead of IP

Added rate limited keys getting logged in a database
This commit is contained in:
henceiusegentoo 2023-09-23 21:41:48 +02:00
parent d3d9ead8f4
commit 1e2a596df3
14 changed files with 165 additions and 46 deletions

3
.gitignore vendored
View file

@ -180,6 +180,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear # and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder. # option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/ .idea/
backups/ backups/
cache/

View file

@ -4,7 +4,7 @@
# git commit -am "Auto-trigger - Production server started" && git push origin Production # git commit -am "Auto-trigger - Production server started" && git push origin Production
# backup database # backup database
/usr/local/bin/python /home/nova-api/api/backup_manager/main.py pre_prodpush # /usr/local/bin/python /home/nova-api/api/backup_manager/main.py pre_prodpush
# Kill production server # Kill production server
fuser -k 2333/tcp fuser -k 2333/tcp

View file

@ -1,4 +1,4 @@
from db import logs, stats, users from db import logs, stats, users, key_validation
from helpers import network from helpers import network
async def after_request( async def after_request(
@ -23,6 +23,8 @@ async def after_request(
await stats.manager.add_ip_address(ip_address) await stats.manager.add_ip_address(ip_address)
await stats.manager.add_path(path) await stats.manager.add_path(path)
await stats.manager.add_target(target_request['url']) await stats.manager.add_target(target_request['url'])
await key_validation.remove_rated_keys()
await key_validation.cache_all_keys()
if is_chat: if is_chat:
await stats.manager.add_model(model) await stats.manager.add_model(model)

View file

@ -33,7 +33,7 @@ async def make_backup(output_dir: str):
os.mkdir(f'{output_dir}/{database}') os.mkdir(f'{output_dir}/{database}')
for collection in databases[database]: for collection in databases[database]:
print(f'Making backup for {database}/{collection}') print(f'Initiated database backup for {database}/{collection}')
await make_backup_for_collection(database, collection, output_dir) await make_backup_for_collection(database, collection, output_dir)
async def make_backup_for_collection(database, collection, output_dir): async def make_backup_for_collection(database, collection, output_dir):

1
api/cache/rate_limited_keys.json vendored Normal file
View file

@ -0,0 +1 @@
[]

View file

@ -194,4 +194,6 @@ async def get_finances(incoming_request: fastapi.Request):
amount_in_usd = await get_crypto_price(currency) * amount amount_in_usd = await get_crypto_price(currency) * amount
transactions[table][transactions[table].index(transaction)]['amount_usd'] = amount_in_usd transactions[table][transactions[table].index(transaction)]['amount_usd'] = amount_in_usd
transactions['timestamp'] = time.time()
return transactions return transactions

87
api/db/key_validation.py Normal file
View file

@ -0,0 +1,87 @@
import os
import time
import asyncio
import json
from dotenv import load_dotenv
from motor.motor_asyncio import AsyncIOMotorClient
load_dotenv()
MONGO_URI = os.getenv("MONGO_URI")
async def log_rated_key(key: str) -> None:
"""Logs a key that has been rate limited to the database."""
client = AsyncIOMotorClient(MONGO_URI)
scheme = {
"key": key,
"timestamp_added": int(time.time())
}
collection = client['Liabilities']['rate-limited-keys']
await collection.insert_one(scheme)
async def key_is_rated(key: str) -> bool:
"""Checks if a key is rate limited."""
client = AsyncIOMotorClient(MONGO_URI)
collection = client['Liabilities']['rate-limited-keys']
query = {
"key": key
}
result = await collection.find_one(query)
return result is not None
async def cached_key_is_rated(key: str) -> bool:
path = os.path.join(os.getcwd(), "cache", "rate_limited_keys.json")
with open(path, "r") as file:
keys = json.load(file)
return key in keys
async def remove_rated_keys() -> None:
"""Removes all keys that have been rate limited for more than a day."""
a_day = 86400
client = AsyncIOMotorClient(MONGO_URI)
collection = client['Liabilities']['rate-limited-keys']
keys = await collection.find().to_list(length=None)
marked_for_removal = []
for key in keys:
if int(time.time()) - key['timestamp_added'] > a_day:
marked_for_removal.append(key['_id'])
query = {
"_id": {
"$in": marked_for_removal
}
}
await collection.delete_many(query)
async def cache_all_keys() -> None:
"""Clones all keys from the database to the cache."""
client = AsyncIOMotorClient(MONGO_URI)
collection = client['Liabilities']['rate-limited-keys']
keys = await collection.find().to_list(length=None)
keys = [key['key'] for key in keys]
path = os.path.join(os.getcwd(), "cache", "rate_limited_keys.json")
with open(path, "w") as file:
json.dump(keys, file)
if __name__ == "__main__":
asyncio.run(remove_rated_keys())

View file

@ -28,15 +28,15 @@ with open('config/config.yml', encoding='utf8') as f:
moderation_debug_key_key = os.getenv('MODERATION_DEBUG_KEY') moderation_debug_key_key = os.getenv('MODERATION_DEBUG_KEY')
async def handle(incoming_request: fastapi.Request): async def handle(incoming_request: fastapi.Request):
""" """Transfer a streaming response
### Transfer a streaming response
Takes the request from the incoming request to the target endpoint. Takes the request from the incoming request to the target endpoint.
Checks method, token amount, auth and cost along with if request is NSFW. Checks method, token amount, auth and cost along with if request is NSFW.
""" """
path = incoming_request.url.path.replace('v1/v1', 'v1').replace('//', '/')
path = incoming_request.url.path
path = path.replace('/v1/v1', '/v1')
ip_address = await network.get_ip(incoming_request) ip_address = await network.get_ip(incoming_request)
print(f'[bold green]>{ip_address}[/bold green]')
if '/models' in path: if '/models' in path:
return fastapi.responses.JSONResponse(content=models_list) return fastapi.responses.JSONResponse(content=models_list)
@ -65,12 +65,17 @@ async def handle(incoming_request: fastapi.Request):
return await errors.error(418, 'Invalid or inactive NovaAI API key!', 'Create a new NovaOSS API key or reactivate your account.') return await errors.error(418, 'Invalid or inactive NovaAI API key!', 'Create a new NovaOSS API key or reactivate your account.')
if user.get('auth', {}).get('discord'): if user.get('auth', {}).get('discord'):
print(f'[bold green]>Discord[/bold green] {user["auth"]["discord"]}') print(f'[bold green]>{ip_address} ({user["auth"]["discord"]})[/bold green]')
ban_reason = user['status']['ban_reason'] ban_reason = user['status']['ban_reason']
if ban_reason: if ban_reason:
return await errors.error(403, f'Your NovaAI account has been banned. Reason: \'{ban_reason}\'.', 'Contact the staff for an appeal.') return await errors.error(403, f'Your NovaAI account has been banned. Reason: \'{ban_reason}\'.', 'Contact the staff for an appeal.')
# Checking for enterprise status
enterprise_keys = os.environ.get('NO_RATELIMIT_KEYS')
if '/enterprise' in path and user.get('api_key') not in enterprise_keys:
return await errors.error(403, 'Enterprise API is not available.', 'Contact the staff for an upgrade.')
if 'account/credits' in path: if 'account/credits' in path:
return fastapi.responses.JSONResponse({'credits': user['credits']}) return fastapi.responses.JSONResponse({'credits': user['credits']})

View file

@ -1,7 +1,7 @@
import os import os
import time import time
from dotenv import load_dotenv from dotenv import load_dotenv
from slowapi.util import get_remote_address
load_dotenv() load_dotenv()
@ -24,22 +24,10 @@ async def get_ip(request) -> str:
def get_ratelimit_key(request) -> str: def get_ratelimit_key(request) -> str:
"""Get the IP address of the incoming request.""" """Get the IP address of the incoming request."""
custom = os.environ('NO_RATELIMIT_IPS')
ip = get_remote_address(request)
xff = None if ip in custom:
if request.headers.get('x-forwarded-for'): return f'enterprise_{ip}'
xff, *_ = request.headers['x-forwarded-for'].split(', ')
possible_ips = [ return ip
xff,
request.headers.get('cf-connecting-ip'),
request.client.host
]
detected_ip = next((i for i in possible_ips if i), None)
for whitelisted_ip in os.getenv('NO_RATELIMIT_IPS', '').split():
if whitelisted_ip in detected_ip:
custom_key = f'whitelisted-{time.time()}'
return custom_key
return detected_ip

View file

@ -1,5 +1,6 @@
import random import random
import asyncio import asyncio
from db.key_validation import cached_key_is_rated
import providers import providers
@ -32,6 +33,15 @@ async def balance_chat_request(payload: dict) -> dict:
provider = random.choice(providers_available) provider = random.choice(providers_available)
target = await provider.chat_completion(**payload) target = await provider.chat_completion(**payload)
while True:
key = target.get('provider_auth')
if not await cached_key_is_rated(key):
break
else:
target = await provider.chat_completion(**payload)
module_name = await _get_module_name(provider) module_name = await _get_module_name(provider)
target['module'] = module_name target['module'] = module_name

View file

@ -11,6 +11,7 @@ from bson.objectid import ObjectId
from slowapi.errors import RateLimitExceeded from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware from slowapi.middleware import SlowAPIMiddleware
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from slowapi.util import get_remote_address
from slowapi import Limiter, _rate_limit_exceeded_handler from slowapi import Limiter, _rate_limit_exceeded_handler
from helpers import network from helpers import network
@ -34,11 +35,13 @@ app.include_router(core.router)
limiter = Limiter( limiter = Limiter(
swallow_errors=True, swallow_errors=True,
key_func=network.get_ratelimit_key, default_limits=[ key_func=get_remote_address,
default_limits=[
'2/second', '2/second',
'20/minute', '20/minute',
'300/hour' '300/hour'
]) ])
app.state.limiter = limiter app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.add_middleware(SlowAPIMiddleware) app.add_middleware(SlowAPIMiddleware)
@ -67,3 +70,9 @@ async def root():
async def v1_handler(request: fastapi.Request): async def v1_handler(request: fastapi.Request):
res = await handler.handle(incoming_request=request) res = await handler.handle(incoming_request=request)
return res return res
@limiter.limit('100/second')
@app.route('/enterprise/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH'])
async def enterprise_handler(request: fastapi.Request):
res = await handler.handle(incoming_request=request)
return res

View file

@ -125,3 +125,7 @@ def get_proxy() -> Proxy:
username=os.getenv('PROXY_USER'), username=os.getenv('PROXY_USER'),
password=os.getenv('PROXY_PASS') password=os.getenv('PROXY_PASS')
) )
if __name__ == '__main__':
print(get_proxy().url)
print(get_proxy().connector)

View file

@ -2,9 +2,7 @@
import os import os
import json import json
import yaml import random
import dhooks
import asyncio
import aiohttp import aiohttp
import starlette import starlette
@ -17,6 +15,7 @@ import after_request
import load_balancing import load_balancing
from helpers import network, chat, errors from helpers import network, chat, errors
from db import key_validation
load_dotenv() load_dotenv()
@ -44,11 +43,10 @@ async def respond(
json_response = {} json_response = {}
headers = { headers = {
'Content-Type': 'application/json', 'Content-Type': 'application/json'
'User-Agent': 'axios/0.21.1',
} }
for _ in range(10): for i in range(20):
# Load balancing: randomly selecting a suitable provider # Load balancing: randomly selecting a suitable provider
# If the request is a chat completion, then we need to load balance between chat providers # If the request is a chat completion, then we need to load balance between chat providers
# If the request is an organic request, then we need to load balance between organic providers # If the request is an organic request, then we need to load balance between organic providers
@ -67,10 +65,7 @@ async def respond(
'cookies': incoming_request.cookies 'cookies': incoming_request.cookies
}) })
except ValueError as exc: except ValueError as exc:
if model in ['gpt-3.5-turbo', 'gpt-4', 'gpt-4-32k']: yield await errors.yield_error(500, f'Sorry, the API has no active API keys for {model}.', 'Please use a different model.')
webhook = dhooks.Webhook(os.environ['DISCORD_WEBHOOK__API_ISSUE'])
webhook.send(content=f'API Issue: **`{exc}`**\nhttps://i.imgflip.com/7uv122.jpg')
yield await errors.yield_error(500, 'Sorry, the API has no working keys anymore.', 'The admins have been messaged automatically.')
return return
target_request['headers'].update(target_request.get('headers', {})) target_request['headers'].update(target_request.get('headers', {}))
@ -91,7 +86,7 @@ async def respond(
cookies=target_request.get('cookies'), cookies=target_request.get('cookies'),
ssl=False, ssl=False,
timeout=aiohttp.ClientTimeout( timeout=aiohttp.ClientTimeout(
connect=0.3, connect=0.5,
total=float(os.getenv('TRANSFER_TIMEOUT', '500')) total=float(os.getenv('TRANSFER_TIMEOUT', '500'))
), ),
) as response: ) as response:
@ -103,6 +98,21 @@ async def respond(
if response.content_type == 'application/json': if response.content_type == 'application/json':
data = await response.json() data = await response.json()
error = data.get('error')
match error:
case None:
pass
case _:
match error.get('code'):
case "insufficient_quota":
key = target_request.get('provider_auth')
await key_validation.log_rated_key(key)
continue
case _:
pass
if 'method_not_supported' in str(data): if 'method_not_supported' in str(data):
await errors.error(500, 'Sorry, this endpoint does not support this method.', data['error']['message']) await errors.error(500, 'Sorry, this endpoint does not support this method.', data['error']['message'])
@ -119,6 +129,7 @@ async def respond(
response.raise_for_status() response.raise_for_status()
except Exception as exc: except Exception as exc:
if 'Too Many Requests' in str(exc): if 'Too Many Requests' in str(exc):
print('[!] too many requests')
continue continue
async for chunk in response.content.iter_any(): async for chunk in response.content.iter_any():
@ -134,14 +145,13 @@ async def respond(
print('[!] chat response is empty') print('[!] chat response is empty')
continue continue
else: else:
yield await errors.yield_error(500, 'Sorry, the provider is not responding. We\'re possibly getting rate-limited.', 'Please try again later.') print('[!] no response')
yield await errors.yield_error(500, 'Sorry, the provider is not responding.', 'Please try again later.')
return return
if (not is_stream) and json_response: if (not is_stream) and json_response:
yield json.dumps(json_response) yield json.dumps(json_response)
print(f'[+] {path} -> {model or ""}')
await after_request.after_request( await after_request.after_request(
incoming_request=incoming_request, incoming_request=incoming_request,
target_request=target_request, target_request=target_request,

View file

@ -164,7 +164,7 @@ async def test_function_calling():
url=f'{api_endpoint}/chat/completions', url=f'{api_endpoint}/chat/completions',
headers=HEADERS, headers=HEADERS,
json=json_data, json=json_data,
timeout=10, timeout=15,
) )
response.raise_for_status() response.raise_for_status()