mirror of
https://github.com/NovaOSS/nova-api.git
synced 2024-11-25 14:33:57 +01:00
Added key validation by API-key instead of IP
Added rate limited keys getting logged in a database
This commit is contained in:
parent
d3d9ead8f4
commit
1e2a596df3
5
.gitignore
vendored
5
.gitignore
vendored
|
@ -180,6 +180,7 @@ cython_debug/
|
|||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
.idea/
|
||||
|
||||
backups/
|
||||
backups/
|
||||
cache/
|
|
@ -4,7 +4,7 @@
|
|||
# git commit -am "Auto-trigger - Production server started" && git push origin Production
|
||||
|
||||
# backup database
|
||||
/usr/local/bin/python /home/nova-api/api/backup_manager/main.py pre_prodpush
|
||||
# /usr/local/bin/python /home/nova-api/api/backup_manager/main.py pre_prodpush
|
||||
|
||||
# Kill production server
|
||||
fuser -k 2333/tcp
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from db import logs, stats, users
|
||||
from db import logs, stats, users, key_validation
|
||||
from helpers import network
|
||||
|
||||
async def after_request(
|
||||
|
@ -23,6 +23,8 @@ async def after_request(
|
|||
await stats.manager.add_ip_address(ip_address)
|
||||
await stats.manager.add_path(path)
|
||||
await stats.manager.add_target(target_request['url'])
|
||||
await key_validation.remove_rated_keys()
|
||||
await key_validation.cache_all_keys()
|
||||
|
||||
if is_chat:
|
||||
await stats.manager.add_model(model)
|
||||
|
|
|
@ -33,7 +33,7 @@ async def make_backup(output_dir: str):
|
|||
os.mkdir(f'{output_dir}/{database}')
|
||||
|
||||
for collection in databases[database]:
|
||||
print(f'Making backup for {database}/{collection}')
|
||||
print(f'Initiated database backup for {database}/{collection}')
|
||||
await make_backup_for_collection(database, collection, output_dir)
|
||||
|
||||
async def make_backup_for_collection(database, collection, output_dir):
|
||||
|
@ -52,4 +52,4 @@ if __name__ == '__main__':
|
|||
exit(1)
|
||||
|
||||
output_dir = argv[1]
|
||||
asyncio.run(main(output_dir))
|
||||
asyncio.run(main(output_dir))
|
||||
|
|
1
api/cache/rate_limited_keys.json
vendored
Normal file
1
api/cache/rate_limited_keys.json
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
[]
|
|
@ -194,4 +194,6 @@ async def get_finances(incoming_request: fastapi.Request):
|
|||
amount_in_usd = await get_crypto_price(currency) * amount
|
||||
transactions[table][transactions[table].index(transaction)]['amount_usd'] = amount_in_usd
|
||||
|
||||
transactions['timestamp'] = time.time()
|
||||
|
||||
return transactions
|
||||
|
|
87
api/db/key_validation.py
Normal file
87
api/db/key_validation.py
Normal file
|
@ -0,0 +1,87 @@
|
|||
import os
|
||||
import time
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
|
||||
load_dotenv()
|
||||
MONGO_URI = os.getenv("MONGO_URI")
|
||||
|
||||
|
||||
async def log_rated_key(key: str) -> None:
|
||||
"""Logs a key that has been rate limited to the database."""
|
||||
|
||||
client = AsyncIOMotorClient(MONGO_URI)
|
||||
|
||||
scheme = {
|
||||
"key": key,
|
||||
"timestamp_added": int(time.time())
|
||||
}
|
||||
|
||||
collection = client['Liabilities']['rate-limited-keys']
|
||||
await collection.insert_one(scheme)
|
||||
|
||||
|
||||
async def key_is_rated(key: str) -> bool:
|
||||
"""Checks if a key is rate limited."""
|
||||
|
||||
client = AsyncIOMotorClient(MONGO_URI)
|
||||
collection = client['Liabilities']['rate-limited-keys']
|
||||
|
||||
query = {
|
||||
"key": key
|
||||
}
|
||||
|
||||
result = await collection.find_one(query)
|
||||
return result is not None
|
||||
|
||||
|
||||
async def cached_key_is_rated(key: str) -> bool:
|
||||
path = os.path.join(os.getcwd(), "cache", "rate_limited_keys.json")
|
||||
|
||||
with open(path, "r") as file:
|
||||
keys = json.load(file)
|
||||
|
||||
return key in keys
|
||||
|
||||
async def remove_rated_keys() -> None:
|
||||
"""Removes all keys that have been rate limited for more than a day."""
|
||||
|
||||
a_day = 86400
|
||||
|
||||
client = AsyncIOMotorClient(MONGO_URI)
|
||||
collection = client['Liabilities']['rate-limited-keys']
|
||||
|
||||
keys = await collection.find().to_list(length=None)
|
||||
|
||||
marked_for_removal = []
|
||||
for key in keys:
|
||||
if int(time.time()) - key['timestamp_added'] > a_day:
|
||||
marked_for_removal.append(key['_id'])
|
||||
|
||||
query = {
|
||||
"_id": {
|
||||
"$in": marked_for_removal
|
||||
}
|
||||
}
|
||||
|
||||
await collection.delete_many(query)
|
||||
|
||||
|
||||
async def cache_all_keys() -> None:
|
||||
"""Clones all keys from the database to the cache."""
|
||||
|
||||
client = AsyncIOMotorClient(MONGO_URI)
|
||||
collection = client['Liabilities']['rate-limited-keys']
|
||||
|
||||
keys = await collection.find().to_list(length=None)
|
||||
keys = [key['key'] for key in keys]
|
||||
|
||||
path = os.path.join(os.getcwd(), "cache", "rate_limited_keys.json")
|
||||
with open(path, "w") as file:
|
||||
json.dump(keys, file)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(remove_rated_keys())
|
|
@ -28,15 +28,15 @@ with open('config/config.yml', encoding='utf8') as f:
|
|||
moderation_debug_key_key = os.getenv('MODERATION_DEBUG_KEY')
|
||||
|
||||
async def handle(incoming_request: fastapi.Request):
|
||||
"""
|
||||
### Transfer a streaming response
|
||||
"""Transfer a streaming response
|
||||
Takes the request from the incoming request to the target endpoint.
|
||||
Checks method, token amount, auth and cost along with if request is NSFW.
|
||||
"""
|
||||
path = incoming_request.url.path.replace('v1/v1', 'v1').replace('//', '/')
|
||||
|
||||
path = incoming_request.url.path
|
||||
path = path.replace('/v1/v1', '/v1')
|
||||
|
||||
ip_address = await network.get_ip(incoming_request)
|
||||
print(f'[bold green]>{ip_address}[/bold green]')
|
||||
|
||||
if '/models' in path:
|
||||
return fastapi.responses.JSONResponse(content=models_list)
|
||||
|
@ -65,12 +65,17 @@ async def handle(incoming_request: fastapi.Request):
|
|||
return await errors.error(418, 'Invalid or inactive NovaAI API key!', 'Create a new NovaOSS API key or reactivate your account.')
|
||||
|
||||
if user.get('auth', {}).get('discord'):
|
||||
print(f'[bold green]>Discord[/bold green] {user["auth"]["discord"]}')
|
||||
print(f'[bold green]>{ip_address} ({user["auth"]["discord"]})[/bold green]')
|
||||
|
||||
ban_reason = user['status']['ban_reason']
|
||||
if ban_reason:
|
||||
return await errors.error(403, f'Your NovaAI account has been banned. Reason: \'{ban_reason}\'.', 'Contact the staff for an appeal.')
|
||||
|
||||
# Checking for enterprise status
|
||||
enterprise_keys = os.environ.get('NO_RATELIMIT_KEYS')
|
||||
if '/enterprise' in path and user.get('api_key') not in enterprise_keys:
|
||||
return await errors.error(403, 'Enterprise API is not available.', 'Contact the staff for an upgrade.')
|
||||
|
||||
if 'account/credits' in path:
|
||||
return fastapi.responses.JSONResponse({'credits': user['credits']})
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import os
|
||||
import time
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
@ -24,22 +24,10 @@ async def get_ip(request) -> str:
|
|||
|
||||
def get_ratelimit_key(request) -> str:
|
||||
"""Get the IP address of the incoming request."""
|
||||
custom = os.environ('NO_RATELIMIT_IPS')
|
||||
ip = get_remote_address(request)
|
||||
|
||||
xff = None
|
||||
if request.headers.get('x-forwarded-for'):
|
||||
xff, *_ = request.headers['x-forwarded-for'].split(', ')
|
||||
if ip in custom:
|
||||
return f'enterprise_{ip}'
|
||||
|
||||
possible_ips = [
|
||||
xff,
|
||||
request.headers.get('cf-connecting-ip'),
|
||||
request.client.host
|
||||
]
|
||||
|
||||
detected_ip = next((i for i in possible_ips if i), None)
|
||||
|
||||
for whitelisted_ip in os.getenv('NO_RATELIMIT_IPS', '').split():
|
||||
if whitelisted_ip in detected_ip:
|
||||
custom_key = f'whitelisted-{time.time()}'
|
||||
return custom_key
|
||||
|
||||
return detected_ip
|
||||
return ip
|
|
@ -1,5 +1,6 @@
|
|||
import random
|
||||
import asyncio
|
||||
from db.key_validation import cached_key_is_rated
|
||||
|
||||
import providers
|
||||
|
||||
|
@ -32,6 +33,15 @@ async def balance_chat_request(payload: dict) -> dict:
|
|||
provider = random.choice(providers_available)
|
||||
target = await provider.chat_completion(**payload)
|
||||
|
||||
while True:
|
||||
key = target.get('provider_auth')
|
||||
|
||||
if not await cached_key_is_rated(key):
|
||||
break
|
||||
|
||||
else:
|
||||
target = await provider.chat_completion(**payload)
|
||||
|
||||
module_name = await _get_module_name(provider)
|
||||
target['module'] = module_name
|
||||
|
||||
|
|
13
api/main.py
13
api/main.py
|
@ -11,6 +11,7 @@ from bson.objectid import ObjectId
|
|||
from slowapi.errors import RateLimitExceeded
|
||||
from slowapi.middleware import SlowAPIMiddleware
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from slowapi.util import get_remote_address
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
|
||||
from helpers import network
|
||||
|
@ -34,11 +35,13 @@ app.include_router(core.router)
|
|||
|
||||
limiter = Limiter(
|
||||
swallow_errors=True,
|
||||
key_func=network.get_ratelimit_key, default_limits=[
|
||||
key_func=get_remote_address,
|
||||
default_limits=[
|
||||
'2/second',
|
||||
'20/minute',
|
||||
'300/hour'
|
||||
])
|
||||
|
||||
app.state.limiter = limiter
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
app.add_middleware(SlowAPIMiddleware)
|
||||
|
@ -66,4 +69,10 @@ async def root():
|
|||
@app.route('/v1/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH'])
|
||||
async def v1_handler(request: fastapi.Request):
|
||||
res = await handler.handle(incoming_request=request)
|
||||
return res
|
||||
return res
|
||||
|
||||
@limiter.limit('100/second')
|
||||
@app.route('/enterprise/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH'])
|
||||
async def enterprise_handler(request: fastapi.Request):
|
||||
res = await handler.handle(incoming_request=request)
|
||||
return res
|
||||
|
|
|
@ -125,3 +125,7 @@ def get_proxy() -> Proxy:
|
|||
username=os.getenv('PROXY_USER'),
|
||||
password=os.getenv('PROXY_PASS')
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(get_proxy().url)
|
||||
print(get_proxy().connector)
|
||||
|
|
|
@ -2,9 +2,7 @@
|
|||
|
||||
import os
|
||||
import json
|
||||
import yaml
|
||||
import dhooks
|
||||
import asyncio
|
||||
import random
|
||||
import aiohttp
|
||||
import starlette
|
||||
|
||||
|
@ -17,6 +15,7 @@ import after_request
|
|||
import load_balancing
|
||||
|
||||
from helpers import network, chat, errors
|
||||
from db import key_validation
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
@ -44,11 +43,10 @@ async def respond(
|
|||
json_response = {}
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'User-Agent': 'axios/0.21.1',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
for _ in range(10):
|
||||
for i in range(20):
|
||||
# Load balancing: randomly selecting a suitable provider
|
||||
# If the request is a chat completion, then we need to load balance between chat providers
|
||||
# If the request is an organic request, then we need to load balance between organic providers
|
||||
|
@ -67,10 +65,7 @@ async def respond(
|
|||
'cookies': incoming_request.cookies
|
||||
})
|
||||
except ValueError as exc:
|
||||
if model in ['gpt-3.5-turbo', 'gpt-4', 'gpt-4-32k']:
|
||||
webhook = dhooks.Webhook(os.environ['DISCORD_WEBHOOK__API_ISSUE'])
|
||||
webhook.send(content=f'API Issue: **`{exc}`**\nhttps://i.imgflip.com/7uv122.jpg')
|
||||
yield await errors.yield_error(500, 'Sorry, the API has no working keys anymore.', 'The admins have been messaged automatically.')
|
||||
yield await errors.yield_error(500, f'Sorry, the API has no active API keys for {model}.', 'Please use a different model.')
|
||||
return
|
||||
|
||||
target_request['headers'].update(target_request.get('headers', {}))
|
||||
|
@ -91,7 +86,7 @@ async def respond(
|
|||
cookies=target_request.get('cookies'),
|
||||
ssl=False,
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
connect=0.3,
|
||||
connect=0.5,
|
||||
total=float(os.getenv('TRANSFER_TIMEOUT', '500'))
|
||||
),
|
||||
) as response:
|
||||
|
@ -103,6 +98,21 @@ async def respond(
|
|||
if response.content_type == 'application/json':
|
||||
data = await response.json()
|
||||
|
||||
error = data.get('error')
|
||||
match error:
|
||||
case None:
|
||||
pass
|
||||
|
||||
case _:
|
||||
match error.get('code'):
|
||||
case "insufficient_quota":
|
||||
key = target_request.get('provider_auth')
|
||||
await key_validation.log_rated_key(key)
|
||||
continue
|
||||
|
||||
case _:
|
||||
pass
|
||||
|
||||
if 'method_not_supported' in str(data):
|
||||
await errors.error(500, 'Sorry, this endpoint does not support this method.', data['error']['message'])
|
||||
|
||||
|
@ -119,6 +129,7 @@ async def respond(
|
|||
response.raise_for_status()
|
||||
except Exception as exc:
|
||||
if 'Too Many Requests' in str(exc):
|
||||
print('[!] too many requests')
|
||||
continue
|
||||
|
||||
async for chunk in response.content.iter_any():
|
||||
|
@ -134,14 +145,13 @@ async def respond(
|
|||
print('[!] chat response is empty')
|
||||
continue
|
||||
else:
|
||||
yield await errors.yield_error(500, 'Sorry, the provider is not responding. We\'re possibly getting rate-limited.', 'Please try again later.')
|
||||
print('[!] no response')
|
||||
yield await errors.yield_error(500, 'Sorry, the provider is not responding.', 'Please try again later.')
|
||||
return
|
||||
|
||||
if (not is_stream) and json_response:
|
||||
yield json.dumps(json_response)
|
||||
|
||||
print(f'[+] {path} -> {model or ""}')
|
||||
|
||||
await after_request.after_request(
|
||||
incoming_request=incoming_request,
|
||||
target_request=target_request,
|
||||
|
|
|
@ -164,7 +164,7 @@ async def test_function_calling():
|
|||
url=f'{api_endpoint}/chat/completions',
|
||||
headers=HEADERS,
|
||||
json=json_data,
|
||||
timeout=10,
|
||||
timeout=15,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
|
|
Loading…
Reference in a new issue