Compare commits

...

4 commits

Author SHA1 Message Date
nsde 7a22c1726f implemented key ratelimit checks 2023-10-02 21:09:39 +02:00
nsde 007050e9fe Fixes 2023-10-02 20:06:38 +02:00
nsde 577cdc0d0b Key rate-limit system thanks to Leander, full proxy list support 2023-10-02 20:06:18 +02:00
henceiusegentoo 1e2a596df3 Added key validation by API-key instead of IP
Added rate limited keys getting logged in a database
2023-09-23 21:41:48 +02:00
19 changed files with 238 additions and 1828 deletions

28
.gitignore vendored
View file

@ -1,7 +1,20 @@
*.zip # Environments
# !!! KEEP THESE ENTRIES AT THE TOP OF THIS FILE BECAUSE THEY CONTAIN THE MOST SENSITIVE DATA !!!
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
env.old/
.prod.env
# !!! KEEP THESE ENTRIES AT THE TOP OF THIS FILE BECAUSE THEY CONTAIN THE MOST SENSITIVE DATA !!!
rate_limited_keys.json
last_update.txt last_update.txt
*.zip
*.log.json *.log.json
/logs /logs
/log /log
@ -142,15 +155,6 @@ celerybeat.pid
# SageMath parsed files # SageMath parsed files
*.sage.py *.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings # Spyder project settings
.spyderproject .spyderproject
.spyproject .spyproject
@ -180,6 +184,8 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear # and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder. # option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/ .idea/
backups/ backups/
cache/
api/cache/rate_limited_keys.json

View file

@ -7,6 +7,7 @@
"**/.DS_Store": true, "**/.DS_Store": true,
"**/Thumbs.db": true, "**/Thumbs.db": true,
"**/__pycache__": true, "**/__pycache__": true,
"**/*.css.map": true,
"**/.vscode": true, "**/.vscode": true,
"**/*.map": true, "**/*.map": true,
"tests/__pycache__": true "tests/__pycache__": true

View file

@ -4,7 +4,7 @@
# git commit -am "Auto-trigger - Production server started" && git push origin Production # git commit -am "Auto-trigger - Production server started" && git push origin Production
# backup database # backup database
/usr/local/bin/python /home/nova-api/api/backup_manager/main.py pre_prodpush # /usr/local/bin/python /home/nova-api/api/backup_manager/main.py pre_prodpush
# Kill production server # Kill production server
fuser -k 2333/tcp fuser -k 2333/tcp
@ -22,4 +22,4 @@ cp env/.prod.env /home/nova-prod/.env
cd /home/nova-prod cd /home/nova-prod
# Start screen # Start screen
screen -S nova-api python run prod && sleep 5 screen -L -S nova-api python run prod && sleep 5

View file

@ -125,14 +125,45 @@ Set up a MongoDB database and set `MONGO_URI` to the MongoDB database connection
Want to use a proxy list? See the according section! Want to use a proxy list? See the according section!
Keep in mind to set `USE_PROXY_LIST` to `True`! Otherwise, the proxy list won't be used. Keep in mind to set `USE_PROXY_LIST` to `True`! Otherwise, the proxy list won't be used.
### `ACTUAL_IPS` (optional) ### Proxy Lists
To use proxy lists, navigate to `api/secret/proxies/` and create the following files:
- `http.txt`
- `socks4.txt`
- `socks5.txt`
Then, paste your proxies in the following format:
```
[username:password@]host:port
```
Whereas anything inside of `[]` is optional and the host can be an IP address or a hostname. Always specify the port.
If you're using [iproyal.com](https://iproyal.com?r=307932)<sup>affiliate link</sup>, follow the following steps:
- Order any type of proxy or proxies
- In the *Product Info* tab:
- set *Select port* to `SOCKS5`
- and *Select format* to `USER:PASS@IP:PORT`
#### Proxy List Examples
```
1.2.3.4:8080
user:pass@127.0.0.1:1337
aaaaaaaaaaaaa:bbbbbbbbbb@1.2.3.4:5555
```
In the proxy credential files, can use comments just like in Python.
**Important:** to activate the proxy lists, you need to change the `USE_PROXY_LIST` environment variable to `True`!
### ~~`ACTUAL_IPS` (optional)~~ (deprecated, might come back in the future)
This is a security measure to make sure a proxy, VPN, Tor or any other IP hiding service is used by the host when accessing "Closed"AI's API. This is a security measure to make sure a proxy, VPN, Tor or any other IP hiding service is used by the host when accessing "Closed"AI's API.
It is a space separated list of IP addresses that are allowed to access the API. It is a space separated list of IP addresses that are allowed to access the API.
You can also just add the *beginning* of an API address, like `12.123.` (without an asterisk!) to allow all IPs starting with `12.123.`. You can also just add the *beginning* of an API address, like `12.123.` (without an asterisk!) to allow all IPs starting with `12.123.`.
> To disable the warning if you don't have this feature enabled, set `ACTUAL_IPS` to `None`. > To disable the warning if you don't have this feature enabled, set `ACTUAL_IPS` to `None`.
### Timeout ### Timeout
`TRANSFER_TIMEOUT` seconds to wait until the program throws an exception for if the request takes too long. We recommend rather long times like `120` for two minutes. `TRANSFER_TIMEOUT` seconds to wait until the program throws an exception for if the request takes too long. We recommend rather long times like `500` for 500 seconds.
### Core Keys ### Core Keys
`CORE_API_KEY` specifies the **very secret key** for which need to access the entire user database etc. `CORE_API_KEY` specifies the **very secret key** for which need to access the entire user database etc.
@ -145,29 +176,6 @@ You can also just add the *beginning* of an API address, like `12.123.` (without
### Other ### Other
`KEYGEN_INFIX` can be almost any string (avoid spaces or special characters) - this string will be put in the middle of every NovaAI API key which is generated. This is useful for identifying the source of the key using e.g. RegEx. `KEYGEN_INFIX` can be almost any string (avoid spaces or special characters) - this string will be put in the middle of every NovaAI API key which is generated. This is useful for identifying the source of the key using e.g. RegEx.
## Proxy Lists
To use proxy lists, navigate to `api/secret/proxies/` and create the following files:
- `http.txt`
- `socks4.txt`
- `socks5.txt`
Then, paste your proxies in the following format:
```
[username:password@]host:port
```
e.g.
```
1.2.3.4:8080
user:pass@127.0.0.1:1337
```
You can use comments just like in Python.
**Important:** to use the proxy lists, you need to change the `USE_PROXY_LIST` environment variable to `True`!
## Run ## Run
> **Warning:** read the according section for production usage! > **Warning:** read the according section for production usage!
@ -186,17 +194,20 @@ python run 1337
``` ```
## Adding a provider ## Adding a provider
To be documented!]
## Test if it works ## Run tests
Make sure the API server is running on the port you specified and run:
`python checks` `python checks`
## Ports ## Default Ports
```yml ```yml
2332: Developement (default) 2332: Developement
2333: Production 2333: Production
``` ```
## Production ## Production
Make sure your server is secure and up to date. Make sure your server is secure and up to date.
Check everything. Check everything.

View file

@ -82,8 +82,5 @@
# # ==================================================================================== # # ====================================================================================
# def prune():
# # gets all users from
# if __name__ == '__main__': # if __name__ == '__main__':
# launch() # launch()

View file

@ -1,4 +1,4 @@
from db import logs, stats, users from db import logs, stats, users, key_validation
from helpers import network from helpers import network
async def after_request( async def after_request(
@ -23,6 +23,8 @@ async def after_request(
await stats.manager.add_ip_address(ip_address) await stats.manager.add_ip_address(ip_address)
await stats.manager.add_path(path) await stats.manager.add_path(path)
await stats.manager.add_target(target_request['url']) await stats.manager.add_target(target_request['url'])
await key_validation.remove_rated_keys()
await key_validation.cache_all_keys()
if is_chat: if is_chat:
await stats.manager.add_model(model) await stats.manager.add_model(model)

View file

@ -33,7 +33,7 @@ async def make_backup(output_dir: str):
os.mkdir(f'{output_dir}/{database}') os.mkdir(f'{output_dir}/{database}')
for collection in databases[database]: for collection in databases[database]:
print(f'Making backup for {database}/{collection}') print(f'Initiated database backup for {database}/{collection}')
await make_backup_for_collection(database, collection, output_dir) await make_backup_for_collection(database, collection, output_dir)
async def make_backup_for_collection(database, collection, output_dir): async def make_backup_for_collection(database, collection, output_dir):

View file

@ -1 +0,0 @@
{"LTC": 64.665, "_last_updated": 1695334741.4905503, "BTC": 26583.485, "MATIC": 0.52075, "XMR": 146.46058828041404, "ADA": 0.2455, "USDT": 1.000005, "ETH": 1586.115, "USD": 1.0, "EUR": 1.0662838016640013}

1709
api/cache/models.json vendored

File diff suppressed because it is too large Load diff

View file

@ -194,4 +194,6 @@ async def get_finances(incoming_request: fastapi.Request):
amount_in_usd = await get_crypto_price(currency) * amount amount_in_usd = await get_crypto_price(currency) * amount
transactions[table][transactions[table].index(transaction)]['amount_usd'] = amount_in_usd transactions[table][transactions[table].index(transaction)]['amount_usd'] = amount_in_usd
transactions['timestamp'] = time.time()
return transactions return transactions

84
api/db/key_validation.py Normal file
View file

@ -0,0 +1,84 @@
import os
import time
import asyncio
import json
from dotenv import load_dotenv
from motor.motor_asyncio import AsyncIOMotorClient
load_dotenv()
MONGO_URI = os.getenv('MONGO_URI')
async def log_rated_key(key: str) -> None:
"""Logs a key that has been rate limited to the database."""
client = AsyncIOMotorClient(MONGO_URI)
scheme = {
'key': key,
'timestamp_added': int(time.time())
}
collection = client['Liabilities']['rate-limited-keys']
await collection.insert_one(scheme)
async def key_is_rated(key: str) -> bool:
"""Checks if a key is rate limited."""
client = AsyncIOMotorClient(MONGO_URI)
collection = client['Liabilities']['rate-limited-keys']
query = {
'key': key
}
result = await collection.find_one(query)
return result is not None
async def cached_key_is_rated(key: str) -> bool:
path = os.path.join(os.getcwd(), 'cache', 'rate_limited_keys.json')
with open(path, 'r', encoding='utf8') as file:
keys = json.load(file)
return key in keys
async def remove_rated_keys() -> None:
"""Removes all keys that have been rate limited for more than a day."""
client = AsyncIOMotorClient(MONGO_URI)
collection = client['Liabilities']['rate-limited-keys']
keys = await collection.find().to_list(length=None)
marked_for_removal = []
for key in keys:
if int(time.time()) - key['timestamp_added'] > 86400:
marked_for_removal.append(key['_id'])
query = {
'_id': {
'$in': marked_for_removal
}
}
await collection.delete_many(query)
async def cache_all_keys() -> None:
"""Clones all keys from the database to the cache."""
client = AsyncIOMotorClient(MONGO_URI)
collection = client['Liabilities']['rate-limited-keys']
keys = await collection.find().to_list(length=None)
keys = [key['key'] for key in keys]
path = os.path.join(os.getcwd(), 'cache', 'rate_limited_keys.json')
with open(path, 'w') as file:
json.dump(keys, file)
if __name__ == "__main__":
asyncio.run(remove_rated_keys())

View file

@ -28,15 +28,15 @@ with open('config/config.yml', encoding='utf8') as f:
moderation_debug_key_key = os.getenv('MODERATION_DEBUG_KEY') moderation_debug_key_key = os.getenv('MODERATION_DEBUG_KEY')
async def handle(incoming_request: fastapi.Request): async def handle(incoming_request: fastapi.Request):
""" """Transfer a streaming response
### Transfer a streaming response
Takes the request from the incoming request to the target endpoint. Takes the request from the incoming request to the target endpoint.
Checks method, token amount, auth and cost along with if request is NSFW. Checks method, token amount, auth and cost along with if request is NSFW.
""" """
path = incoming_request.url.path.replace('v1/v1', 'v1').replace('//', '/')
path = incoming_request.url.path
path = path.replace('/v1/v1', '/v1')
ip_address = await network.get_ip(incoming_request) ip_address = await network.get_ip(incoming_request)
print(f'[bold green]>{ip_address}[/bold green]')
if '/models' in path: if '/models' in path:
return fastapi.responses.JSONResponse(content=models_list) return fastapi.responses.JSONResponse(content=models_list)
@ -65,12 +65,17 @@ async def handle(incoming_request: fastapi.Request):
return await errors.error(418, 'Invalid or inactive NovaAI API key!', 'Create a new NovaOSS API key or reactivate your account.') return await errors.error(418, 'Invalid or inactive NovaAI API key!', 'Create a new NovaOSS API key or reactivate your account.')
if user.get('auth', {}).get('discord'): if user.get('auth', {}).get('discord'):
print(f'[bold green]>Discord[/bold green] {user["auth"]["discord"]}') print(f'[bold green]>{ip_address} ({user["auth"]["discord"]})[/bold green]')
ban_reason = user['status']['ban_reason'] ban_reason = user['status']['ban_reason']
if ban_reason: if ban_reason:
return await errors.error(403, f'Your NovaAI account has been banned. Reason: \'{ban_reason}\'.', 'Contact the staff for an appeal.') return await errors.error(403, f'Your NovaAI account has been banned. Reason: \'{ban_reason}\'.', 'Contact the staff for an appeal.')
# Checking for enterprise status
enterprise_keys = os.environ.get('NO_RATELIMIT_KEYS')
if '/enterprise' in path and user.get('api_key') not in enterprise_keys:
return await errors.error(403, 'Enterprise API is not available.', 'Contact the staff for an upgrade.')
if 'account/credits' in path: if 'account/credits' in path:
return fastapi.responses.JSONResponse({'credits': user['credits']}) return fastapi.responses.JSONResponse({'credits': user['credits']})

View file

@ -1,7 +1,7 @@
import os import os
import time import time
from dotenv import load_dotenv from dotenv import load_dotenv
from slowapi.util import get_remote_address
load_dotenv() load_dotenv()
@ -24,22 +24,10 @@ async def get_ip(request) -> str:
def get_ratelimit_key(request) -> str: def get_ratelimit_key(request) -> str:
"""Get the IP address of the incoming request.""" """Get the IP address of the incoming request."""
custom = os.environ('NO_RATELIMIT_IPS')
ip = get_remote_address(request)
xff = None if ip in custom:
if request.headers.get('x-forwarded-for'): return f'enterprise_{ip}'
xff, *_ = request.headers['x-forwarded-for'].split(', ')
possible_ips = [ return ip
xff,
request.headers.get('cf-connecting-ip'),
request.client.host
]
detected_ip = next((i for i in possible_ips if i), None)
for whitelisted_ip in os.getenv('NO_RATELIMIT_IPS', '').split():
if whitelisted_ip in detected_ip:
custom_key = f'whitelisted-{time.time()}'
return custom_key
return detected_ip

View file

@ -31,7 +31,6 @@ async def balance_chat_request(payload: dict) -> dict:
provider = random.choice(providers_available) provider = random.choice(providers_available)
target = await provider.chat_completion(**payload) target = await provider.chat_completion(**payload)
module_name = await _get_module_name(provider) module_name = await _get_module_name(provider)
target['module'] = module_name target['module'] = module_name

View file

@ -11,10 +11,9 @@ from bson.objectid import ObjectId
from slowapi.errors import RateLimitExceeded from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware from slowapi.middleware import SlowAPIMiddleware
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from slowapi.util import get_remote_address
from slowapi import Limiter, _rate_limit_exceeded_handler from slowapi import Limiter, _rate_limit_exceeded_handler
from helpers import network
import core import core
import handler import handler
@ -34,11 +33,13 @@ app.include_router(core.router)
limiter = Limiter( limiter = Limiter(
swallow_errors=True, swallow_errors=True,
key_func=network.get_ratelimit_key, default_limits=[ key_func=get_remote_address,
'2/second', default_limits=[
'1/second',
'20/minute', '20/minute',
'300/hour' '300/hour'
]) ])
app.state.limiter = limiter app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.add_middleware(SlowAPIMiddleware) app.add_middleware(SlowAPIMiddleware)
@ -67,3 +68,9 @@ async def root():
async def v1_handler(request: fastapi.Request): async def v1_handler(request: fastapi.Request):
res = await handler.handle(incoming_request=request) res = await handler.handle(incoming_request=request)
return res return res
@limiter.limit('100/minute', '1000/hour')
@app.route('/enterprise/v1/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH'])
async def enterprise_handler(request: fastapi.Request):
res = await handler.handle(incoming_request=request)
return res

View file

@ -49,7 +49,7 @@ async def invalidate_key(provider_and_key: str) -> None:
with open(f'secret/{provider}.invalid.txt', 'a', encoding='utf8') as f: with open(f'secret/{provider}.invalid.txt', 'a', encoding='utf8') as f:
f.write(key + '\n') f.write(key + '\n')
await invalidation_webhook(provider_and_key) # await invalidation_webhook(provider_and_key)
if __name__ == '__main__': if __name__ == '__main__':
asyncio.run(invalidate_key('closed>demo-...')) asyncio.run(invalidate_key('closed>demo-...'))

View file

@ -37,15 +37,20 @@ class Proxy:
url = url.split('://')[1] url = url.split('://')[1]
if '@' in url: if '@' in url:
username = url.split('@')[1].split(':')[0] username = url.split('@')[0].split(':')[0]
password = url.split('@')[1].split(':')[1] password = url.split('@')[0].split(':')[1]
host_or_ip = url.split(':')[0] host_or_ip = url.split('@')[-1].split(':')[0]
port = url.split(':')[1] port = int(url.split('@')[-1].split(':')[1])
self.proxy_type = proxy_type self.proxy_type = proxy_type
self.host_or_ip = host_or_ip self.host_or_ip = host_or_ip
self.ip_address = socket.gethostbyname(self.host_or_ip) # get ip address from host
try:
self.ip_address = socket.gethostbyname(self.host_or_ip) # get ip address from host
except socket.gaierror:
self.ip_address = self.host_or_ip
self.host = self.host_or_ip self.host = self.host_or_ip
self.port = port self.port = port
self.username = username self.username = username
@ -78,7 +83,7 @@ class Proxy:
return aiohttp_socks.ProxyConnector( return aiohttp_socks.ProxyConnector(
proxy_type=proxy_types[self.proxy_type], proxy_type=proxy_types[self.proxy_type],
host=self.ip_address, host=self.host,
port=self.port, port=self.port,
rdns=False, rdns=False,
username=self.username, username=self.username,
@ -89,16 +94,15 @@ class Proxy:
proxies_in_files = [] proxies_in_files = []
try: for proxy_type in ['http', 'socks4', 'socks5']:
for proxy_type in ['http', 'socks4', 'socks5']: try:
with open(f'secret/proxies/{proxy_type}.txt') as f: with open(f'secret/proxies/{proxy_type}.txt') as f:
for line in f: for line in f:
clean_line = line.split('#', 1)[0].strip() clean_line = line.split('#', 1)[0].strip()
if clean_line: if clean_line:
proxies_in_files.append(f'{proxy_type}://{clean_line}') proxies_in_files.append(f'{proxy_type}://{clean_line}')
except FileNotFoundError:
except FileNotFoundError: pass
pass
## Manages the proxy list ## Manages the proxy list
@ -125,3 +129,6 @@ def get_proxy() -> Proxy:
username=os.getenv('PROXY_USER'), username=os.getenv('PROXY_USER'),
password=os.getenv('PROXY_PASS') password=os.getenv('PROXY_PASS')
) )
if __name__ == '__main__':
print(get_proxy().url)

View file

@ -2,9 +2,7 @@
import os import os
import json import json
import yaml import random
import dhooks
import asyncio
import aiohttp import aiohttp
import starlette import starlette
@ -16,7 +14,9 @@ import provider_auth
import after_request import after_request
import load_balancing import load_balancing
from helpers import network, chat, errors from helpers import errors
from db import key_validation
load_dotenv() load_dotenv()
@ -44,21 +44,15 @@ async def respond(
json_response = {} json_response = {}
headers = { headers = {
'Content-Type': 'application/json', 'Content-Type': 'application/json'
'User-Agent': 'axios/0.21.1',
} }
for _ in range(10): for _ in range(10):
# Load balancing: randomly selecting a suitable provider # Load balancing: randomly selecting a suitable provider
# If the request is a chat completion, then we need to load balance between chat providers
# If the request is an organic request, then we need to load balance between organic providers
try: try:
if is_chat: if is_chat:
target_request = await load_balancing.balance_chat_request(payload) target_request = await load_balancing.balance_chat_request(payload)
else: else:
# In this case we are doing a organic request. "organic" means that it's not using a reverse engineered front-end, but rather ClosedAI's API directly
# churchless.tech is an example of an organic provider, because it redirects the request to ClosedAI.
target_request = await load_balancing.balance_organic_request({ target_request = await load_balancing.balance_organic_request({
'method': incoming_request.method, 'method': incoming_request.method,
'path': path, 'path': path,
@ -67,10 +61,7 @@ async def respond(
'cookies': incoming_request.cookies 'cookies': incoming_request.cookies
}) })
except ValueError as exc: except ValueError as exc:
if model in ['gpt-3.5-turbo', 'gpt-4', 'gpt-4-32k']: yield await errors.yield_error(500, f'Sorry, the API has no active API keys for {model}.', 'Please use a different model.')
webhook = dhooks.Webhook(os.environ['DISCORD_WEBHOOK__API_ISSUE'])
webhook.send(content=f'API Issue: **`{exc}`**\nhttps://i.imgflip.com/7uv122.jpg')
yield await errors.yield_error(500, 'Sorry, the API has no working keys anymore.', 'The admins have been messaged automatically.')
return return
target_request['headers'].update(target_request.get('headers', {})) target_request['headers'].update(target_request.get('headers', {}))
@ -91,34 +82,56 @@ async def respond(
cookies=target_request.get('cookies'), cookies=target_request.get('cookies'),
ssl=False, ssl=False,
timeout=aiohttp.ClientTimeout( timeout=aiohttp.ClientTimeout(
connect=0.3, connect=1.0,
total=float(os.getenv('TRANSFER_TIMEOUT', '500')) total=float(os.getenv('TRANSFER_TIMEOUT', '500'))
), ),
) as response: ) as response:
is_stream = response.content_type == 'text/event-stream' is_stream = response.content_type == 'text/event-stream'
if response.status == 429: if response.status == 429:
await key_validation.log_rated_key(target_request.get('provider_auth'))
continue continue
if response.content_type == 'application/json': if response.content_type == 'application/json':
data = await response.json() data = await response.json()
error = data.get('error')
match error:
case None:
pass
case _:
key = target_request.get('provider_auth')
match error.get('code'):
case 'invalid_api_key':
await key_validation.log_rated_key(key)
print('[!] invalid key', key)
case _:
print('[!] unknown error with key: ', key, error)
if 'method_not_supported' in str(data): if 'method_not_supported' in str(data):
await errors.error(500, 'Sorry, this endpoint does not support this method.', data['error']['message']) await errors.error(500, 'Sorry, this endpoint does not support this method.', data['error']['message'])
if 'invalid_api_key' in str(data) or 'account_deactivated' in str(data): if 'invalid_api_key' in str(data) or 'account_deactivated' in str(data):
print('[!] invalid api key', target_request.get('provider_auth'))
await provider_auth.invalidate_key(target_request.get('provider_auth')) await provider_auth.invalidate_key(target_request.get('provider_auth'))
continue continue
if response.ok: if response.ok:
json_response = data json_response = data
else:
print('[!] error', data)
continue
if is_stream: if is_stream:
try: try:
response.raise_for_status() response.raise_for_status()
except Exception as exc: except Exception as exc:
if 'Too Many Requests' in str(exc): if 'Too Many Requests' in str(exc):
print('[!] too many requests')
continue continue
async for chunk in response.content.iter_any(): async for chunk in response.content.iter_any():
@ -128,20 +141,18 @@ async def respond(
break break
except Exception as exc: except Exception as exc:
if 'too many requests' in str(exc):
await key_validation.log_rated_key(key)
continue continue
if (not json_response) and is_chat:
print('[!] chat response is empty')
continue
else: else:
yield await errors.yield_error(500, 'Sorry, the provider is not responding. We\'re possibly getting rate-limited.', 'Please try again later.') yield await errors.yield_error(500, 'Sorry, our API seems to have issues connecting to our provider(s).', 'This most likely isn\'t your fault. Please try again later.')
return return
if (not is_stream) and json_response: if (not is_stream) and json_response:
yield json.dumps(json_response) yield json.dumps(json_response)
print(f'[+] {path} -> {model or ""}')
await after_request.after_request( await after_request.after_request(
incoming_request=incoming_request, incoming_request=incoming_request,
target_request=target_request, target_request=target_request,

View file

@ -164,7 +164,7 @@ async def test_function_calling():
url=f'{api_endpoint}/chat/completions', url=f'{api_endpoint}/chat/completions',
headers=HEADERS, headers=HEADERS,
json=json_data, json=json_data,
timeout=10, timeout=15,
) )
response.raise_for_status() response.raise_for_status()
@ -208,8 +208,8 @@ async def demo():
else: else:
raise ConnectionError('API Server is not running.') raise ConnectionError('API Server is not running.')
print('[lightblue]Checking if function calling works...') # print('[lightblue]Checking if function calling works...')
print(await test_function_calling()) # print(await test_function_calling())
print('Checking non-streamed chat completions...') print('Checking non-streamed chat completions...')
print(await test_chat_non_stream_gpt4()) print(await test_chat_non_stream_gpt4())
@ -220,8 +220,8 @@ async def demo():
# print('[lightblue]Checking if image generation works...') # print('[lightblue]Checking if image generation works...')
# print(await test_image_generation()) # print(await test_image_generation())
print('Checking the models endpoint...') # print('Checking the models endpoint...')
print(await test_models()) # print(await test_models())
except Exception as exc: except Exception as exc:
print('[red]Error: ' + str(exc)) print('[red]Error: ' + str(exc))