mirror of
https://github.com/NovaOSS/nova-api.git
synced 2024-11-25 16:13:58 +01:00
implemented key ratelimit checks
This commit is contained in:
parent
007050e9fe
commit
7a22c1726f
|
@ -82,8 +82,5 @@
|
|||
|
||||
# # ====================================================================================
|
||||
|
||||
# def prune():
|
||||
# # gets all users from
|
||||
|
||||
# if __name__ == '__main__':
|
||||
# launch()
|
||||
|
|
|
@ -41,7 +41,7 @@ async def key_is_rated(key: str) -> bool:
|
|||
async def cached_key_is_rated(key: str) -> bool:
|
||||
path = os.path.join(os.getcwd(), 'cache', 'rate_limited_keys.json')
|
||||
|
||||
with open(path, 'r') as file:
|
||||
with open(path, 'r', encoding='utf8') as file:
|
||||
keys = json.load(file)
|
||||
|
||||
return key in keys
|
||||
|
@ -49,8 +49,6 @@ async def cached_key_is_rated(key: str) -> bool:
|
|||
async def remove_rated_keys() -> None:
|
||||
"""Removes all keys that have been rate limited for more than a day."""
|
||||
|
||||
a_day = 86400
|
||||
|
||||
client = AsyncIOMotorClient(MONGO_URI)
|
||||
collection = client['Liabilities']['rate-limited-keys']
|
||||
|
||||
|
@ -58,7 +56,7 @@ async def remove_rated_keys() -> None:
|
|||
|
||||
marked_for_removal = []
|
||||
for key in keys:
|
||||
if int(time.time()) - key['timestamp_added'] > a_day:
|
||||
if int(time.time()) - key['timestamp_added'] > 86400:
|
||||
marked_for_removal.append(key['_id'])
|
||||
|
||||
query = {
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import random
|
||||
import asyncio
|
||||
from db.key_validation import cached_key_is_rated
|
||||
|
||||
import providers
|
||||
|
||||
|
@ -32,16 +31,6 @@ async def balance_chat_request(payload: dict) -> dict:
|
|||
|
||||
provider = random.choice(providers_available)
|
||||
target = await provider.chat_completion(**payload)
|
||||
|
||||
while True:
|
||||
key = target.get('provider_auth')
|
||||
|
||||
if not await cached_key_is_rated(key):
|
||||
break
|
||||
|
||||
else:
|
||||
target = await provider.chat_completion(**payload)
|
||||
|
||||
module_name = await _get_module_name(provider)
|
||||
target['module'] = module_name
|
||||
|
||||
|
|
|
@ -49,7 +49,7 @@ async def invalidate_key(provider_and_key: str) -> None:
|
|||
with open(f'secret/{provider}.invalid.txt', 'a', encoding='utf8') as f:
|
||||
f.write(key + '\n')
|
||||
|
||||
await invalidation_webhook(provider_and_key)
|
||||
# await invalidation_webhook(provider_and_key)
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(invalidate_key('closed>demo-...'))
|
||||
|
|
|
@ -14,7 +14,8 @@ import provider_auth
|
|||
import after_request
|
||||
import load_balancing
|
||||
|
||||
from helpers import network, chat, errors
|
||||
from helpers import errors
|
||||
|
||||
from db import key_validation
|
||||
|
||||
load_dotenv()
|
||||
|
@ -48,15 +49,10 @@ async def respond(
|
|||
|
||||
for _ in range(10):
|
||||
# Load balancing: randomly selecting a suitable provider
|
||||
# If the request is a chat completion, then we need to load balance between chat providers
|
||||
# If the request is an organic request, then we need to load balance between organic providers
|
||||
try:
|
||||
if is_chat:
|
||||
target_request = await load_balancing.balance_chat_request(payload)
|
||||
else:
|
||||
|
||||
# In this case we are doing a organic request. "organic" means that it's not using a reverse engineered front-end, but rather ClosedAI's API directly
|
||||
# churchless.tech is an example of an organic provider, because it redirects the request to ClosedAI.
|
||||
target_request = await load_balancing.balance_organic_request({
|
||||
'method': incoming_request.method,
|
||||
'path': path,
|
||||
|
@ -93,6 +89,7 @@ async def respond(
|
|||
is_stream = response.content_type == 'text/event-stream'
|
||||
|
||||
if response.status == 429:
|
||||
await key_validation.log_rated_key(target_request.get('provider_auth'))
|
||||
continue
|
||||
|
||||
if response.content_type == 'application/json':
|
||||
|
@ -118,7 +115,6 @@ async def respond(
|
|||
await errors.error(500, 'Sorry, this endpoint does not support this method.', data['error']['message'])
|
||||
|
||||
if 'invalid_api_key' in str(data) or 'account_deactivated' in str(data):
|
||||
print('[!] invalid api key', target_request.get('provider_auth'))
|
||||
await provider_auth.invalidate_key(target_request.get('provider_auth'))
|
||||
continue
|
||||
|
||||
|
@ -140,13 +136,14 @@ async def respond(
|
|||
|
||||
async for chunk in response.content.iter_any():
|
||||
chunk = chunk.decode('utf8').strip()
|
||||
print(1)
|
||||
yield chunk + '\n\n'
|
||||
|
||||
break
|
||||
|
||||
except Exception as exc:
|
||||
print('[!] exception', exc)
|
||||
if 'too many requests' in str(exc):
|
||||
await key_validation.log_rated_key(key)
|
||||
|
||||
continue
|
||||
|
||||
else:
|
||||
|
|
|
@ -208,8 +208,8 @@ async def demo():
|
|||
else:
|
||||
raise ConnectionError('API Server is not running.')
|
||||
|
||||
print('[lightblue]Checking if function calling works...')
|
||||
print(await test_function_calling())
|
||||
# print('[lightblue]Checking if function calling works...')
|
||||
# print(await test_function_calling())
|
||||
|
||||
print('Checking non-streamed chat completions...')
|
||||
print(await test_chat_non_stream_gpt4())
|
||||
|
@ -220,8 +220,8 @@ async def demo():
|
|||
# print('[lightblue]Checking if image generation works...')
|
||||
# print(await test_image_generation())
|
||||
|
||||
print('Checking the models endpoint...')
|
||||
print(await test_models())
|
||||
# print('Checking the models endpoint...')
|
||||
# print(await test_models())
|
||||
|
||||
except Exception as exc:
|
||||
print('[red]Error: ' + str(exc))
|
||||
|
|
Loading…
Reference in a new issue