mirror of
https://github.com/NovaOSS/nova-api.git
synced 2024-11-25 18:43:57 +01:00
implemented key ratelimit checks
This commit is contained in:
parent
007050e9fe
commit
7a22c1726f
|
@ -82,8 +82,5 @@
|
||||||
|
|
||||||
# # ====================================================================================
|
# # ====================================================================================
|
||||||
|
|
||||||
# def prune():
|
|
||||||
# # gets all users from
|
|
||||||
|
|
||||||
# if __name__ == '__main__':
|
# if __name__ == '__main__':
|
||||||
# launch()
|
# launch()
|
||||||
|
|
|
@ -41,7 +41,7 @@ async def key_is_rated(key: str) -> bool:
|
||||||
async def cached_key_is_rated(key: str) -> bool:
|
async def cached_key_is_rated(key: str) -> bool:
|
||||||
path = os.path.join(os.getcwd(), 'cache', 'rate_limited_keys.json')
|
path = os.path.join(os.getcwd(), 'cache', 'rate_limited_keys.json')
|
||||||
|
|
||||||
with open(path, 'r') as file:
|
with open(path, 'r', encoding='utf8') as file:
|
||||||
keys = json.load(file)
|
keys = json.load(file)
|
||||||
|
|
||||||
return key in keys
|
return key in keys
|
||||||
|
@ -49,8 +49,6 @@ async def cached_key_is_rated(key: str) -> bool:
|
||||||
async def remove_rated_keys() -> None:
|
async def remove_rated_keys() -> None:
|
||||||
"""Removes all keys that have been rate limited for more than a day."""
|
"""Removes all keys that have been rate limited for more than a day."""
|
||||||
|
|
||||||
a_day = 86400
|
|
||||||
|
|
||||||
client = AsyncIOMotorClient(MONGO_URI)
|
client = AsyncIOMotorClient(MONGO_URI)
|
||||||
collection = client['Liabilities']['rate-limited-keys']
|
collection = client['Liabilities']['rate-limited-keys']
|
||||||
|
|
||||||
|
@ -58,7 +56,7 @@ async def remove_rated_keys() -> None:
|
||||||
|
|
||||||
marked_for_removal = []
|
marked_for_removal = []
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if int(time.time()) - key['timestamp_added'] > a_day:
|
if int(time.time()) - key['timestamp_added'] > 86400:
|
||||||
marked_for_removal.append(key['_id'])
|
marked_for_removal.append(key['_id'])
|
||||||
|
|
||||||
query = {
|
query = {
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
import random
|
import random
|
||||||
import asyncio
|
import asyncio
|
||||||
from db.key_validation import cached_key_is_rated
|
|
||||||
|
|
||||||
import providers
|
import providers
|
||||||
|
|
||||||
|
@ -32,16 +31,6 @@ async def balance_chat_request(payload: dict) -> dict:
|
||||||
|
|
||||||
provider = random.choice(providers_available)
|
provider = random.choice(providers_available)
|
||||||
target = await provider.chat_completion(**payload)
|
target = await provider.chat_completion(**payload)
|
||||||
|
|
||||||
while True:
|
|
||||||
key = target.get('provider_auth')
|
|
||||||
|
|
||||||
if not await cached_key_is_rated(key):
|
|
||||||
break
|
|
||||||
|
|
||||||
else:
|
|
||||||
target = await provider.chat_completion(**payload)
|
|
||||||
|
|
||||||
module_name = await _get_module_name(provider)
|
module_name = await _get_module_name(provider)
|
||||||
target['module'] = module_name
|
target['module'] = module_name
|
||||||
|
|
||||||
|
|
|
@ -49,7 +49,7 @@ async def invalidate_key(provider_and_key: str) -> None:
|
||||||
with open(f'secret/{provider}.invalid.txt', 'a', encoding='utf8') as f:
|
with open(f'secret/{provider}.invalid.txt', 'a', encoding='utf8') as f:
|
||||||
f.write(key + '\n')
|
f.write(key + '\n')
|
||||||
|
|
||||||
await invalidation_webhook(provider_and_key)
|
# await invalidation_webhook(provider_and_key)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
asyncio.run(invalidate_key('closed>demo-...'))
|
asyncio.run(invalidate_key('closed>demo-...'))
|
||||||
|
|
|
@ -14,7 +14,8 @@ import provider_auth
|
||||||
import after_request
|
import after_request
|
||||||
import load_balancing
|
import load_balancing
|
||||||
|
|
||||||
from helpers import network, chat, errors
|
from helpers import errors
|
||||||
|
|
||||||
from db import key_validation
|
from db import key_validation
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
@ -48,15 +49,10 @@ async def respond(
|
||||||
|
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
# Load balancing: randomly selecting a suitable provider
|
# Load balancing: randomly selecting a suitable provider
|
||||||
# If the request is a chat completion, then we need to load balance between chat providers
|
|
||||||
# If the request is an organic request, then we need to load balance between organic providers
|
|
||||||
try:
|
try:
|
||||||
if is_chat:
|
if is_chat:
|
||||||
target_request = await load_balancing.balance_chat_request(payload)
|
target_request = await load_balancing.balance_chat_request(payload)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
# In this case we are doing a organic request. "organic" means that it's not using a reverse engineered front-end, but rather ClosedAI's API directly
|
|
||||||
# churchless.tech is an example of an organic provider, because it redirects the request to ClosedAI.
|
|
||||||
target_request = await load_balancing.balance_organic_request({
|
target_request = await load_balancing.balance_organic_request({
|
||||||
'method': incoming_request.method,
|
'method': incoming_request.method,
|
||||||
'path': path,
|
'path': path,
|
||||||
|
@ -93,6 +89,7 @@ async def respond(
|
||||||
is_stream = response.content_type == 'text/event-stream'
|
is_stream = response.content_type == 'text/event-stream'
|
||||||
|
|
||||||
if response.status == 429:
|
if response.status == 429:
|
||||||
|
await key_validation.log_rated_key(target_request.get('provider_auth'))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if response.content_type == 'application/json':
|
if response.content_type == 'application/json':
|
||||||
|
@ -118,7 +115,6 @@ async def respond(
|
||||||
await errors.error(500, 'Sorry, this endpoint does not support this method.', data['error']['message'])
|
await errors.error(500, 'Sorry, this endpoint does not support this method.', data['error']['message'])
|
||||||
|
|
||||||
if 'invalid_api_key' in str(data) or 'account_deactivated' in str(data):
|
if 'invalid_api_key' in str(data) or 'account_deactivated' in str(data):
|
||||||
print('[!] invalid api key', target_request.get('provider_auth'))
|
|
||||||
await provider_auth.invalidate_key(target_request.get('provider_auth'))
|
await provider_auth.invalidate_key(target_request.get('provider_auth'))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -140,13 +136,14 @@ async def respond(
|
||||||
|
|
||||||
async for chunk in response.content.iter_any():
|
async for chunk in response.content.iter_any():
|
||||||
chunk = chunk.decode('utf8').strip()
|
chunk = chunk.decode('utf8').strip()
|
||||||
print(1)
|
|
||||||
yield chunk + '\n\n'
|
yield chunk + '\n\n'
|
||||||
|
|
||||||
break
|
break
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
print('[!] exception', exc)
|
if 'too many requests' in str(exc):
|
||||||
|
await key_validation.log_rated_key(key)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -208,8 +208,8 @@ async def demo():
|
||||||
else:
|
else:
|
||||||
raise ConnectionError('API Server is not running.')
|
raise ConnectionError('API Server is not running.')
|
||||||
|
|
||||||
print('[lightblue]Checking if function calling works...')
|
# print('[lightblue]Checking if function calling works...')
|
||||||
print(await test_function_calling())
|
# print(await test_function_calling())
|
||||||
|
|
||||||
print('Checking non-streamed chat completions...')
|
print('Checking non-streamed chat completions...')
|
||||||
print(await test_chat_non_stream_gpt4())
|
print(await test_chat_non_stream_gpt4())
|
||||||
|
@ -220,8 +220,8 @@ async def demo():
|
||||||
# print('[lightblue]Checking if image generation works...')
|
# print('[lightblue]Checking if image generation works...')
|
||||||
# print(await test_image_generation())
|
# print(await test_image_generation())
|
||||||
|
|
||||||
print('Checking the models endpoint...')
|
# print('Checking the models endpoint...')
|
||||||
print(await test_models())
|
# print(await test_models())
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
print('[red]Error: ' + str(exc))
|
print('[red]Error: ' + str(exc))
|
||||||
|
|
Loading…
Reference in a new issue