Compare commits

...

2 commits

Author SHA1 Message Date
monosans 03f8c9ad0a
Fix dangling asyncio tasks 2023-10-07 09:14:51 +03:00
nsde 719f29fb29 Added azure endpoints 2023-10-06 23:05:38 +02:00
10 changed files with 154 additions and 58 deletions

View file

@ -6,9 +6,10 @@ costs:
other: 5
chat-models:
gpt-4-32k: 100
gpt-4: 30
gpt-3: 3
gpt-4-32k-azure: 100
gpt-4: 50
gpt-4-azure: 10
gpt-3: 5
## Roles Explanation

View file

@ -1,12 +1,17 @@
import os
import time
import random
import asyncio
from aiocache import cached
from dotenv import load_dotenv
from cachetools import TTLCache
from motor.motor_asyncio import AsyncIOMotorClient
load_dotenv()
cache = TTLCache(maxsize=100, ttl=10)
class KeyManager:
def __init__(self):
self.conn = AsyncIOMotorClient(os.environ['MONGO_URI'])
@ -24,27 +29,34 @@ class KeyManager:
'source': source,
})
async def get_key(self, provider: str):
async def get_possible_keys(self, provider: str):
db = await self._get_collection('providerkeys')
key = await db.find_one({
keys = await db.find({
'provider': provider,
'inactive_reason': None,
'$or': [
{'rate_limited_since': None},
{'rate_limited_since': {'$lte': time.time() - 86400}}
{'rate_limited_until': None},
{'rate_limited_until': {'$lte': time.time()}}
]
})
}).to_list(length=None)
if key is None:
return keys
async def get_key(self, provider: str):
keys = await self.get_possible_keys(provider)
if not keys:
return '--NO_KEY--'
return key['key']
key = random.choice(keys)
api_key = key['key']
return api_key
async def rate_limit_key(self, provider: str, key: str):
async def rate_limit_key(self, provider: str, key: str, duration: int):
db = await self._get_collection('providerkeys')
await db.update_one({'provider': provider, 'key': key}, {
'$set': {
'rate_limited_since': time.time()
'rate_limited_until': time.time() + duration
}
})
@ -70,8 +82,6 @@ class KeyManager:
await db.insert_one({
'provider': filename.split('.')[0],
'key': line.strip(),
'rate_limited_since': None,
'inactive_reason': None,
'source': 'import'
})
num += 1
@ -86,5 +96,9 @@ class KeyManager:
manager = KeyManager()
async def main():
keys = await manager.get_possible_keys('closed')
print(len(keys))
if __name__ == '__main__':
asyncio.run(manager.import_all())
asyncio.run(main())

View file

@ -69,8 +69,8 @@ async def handle(incoming_request: fastapi.Request):
return await errors.error(403, f'Your NovaAI account has been banned. Reason: \'{ban_reason}\'.', 'Contact the staff for an appeal.')
# Checking for enterprise status
enterprise_keys = os.environ.get('NO_RATELIMIT_KEYS')
if '/enterprise' in path and user.get('api_key') not in enterprise_keys:
enterprise_keys = os.environ.get('ENTERPRISE_KEYS')
if path.startswith('/enterprise/v1') and user.get('api_key') not in enterprise_keys.split():
return await errors.error(403, 'Enterprise API is not available.', 'Contact the staff for an upgrade.')
if 'account/credits' in path:

View file

@ -1,9 +1,11 @@
from . import \
azure, \
closed, \
closed4
# closed432
MODULES = [
azure,
closed,
closed4,
# closed432,

View file

@ -1,5 +1,8 @@
import os
import sys
import aiohttp
import asyncio
import importlib
from rich import print
@ -12,41 +15,43 @@ def remove_duplicate_keys(file):
with open(file, 'w', encoding='utf8') as f:
f.writelines(unique_lines)
try:
provider_name = sys.argv[1]
async def main():
try:
provider_name = sys.argv[1]
if provider_name == '--clear':
for file in os.listdir('secret/'):
if file.endswith('.txt'):
remove_duplicate_keys(f'secret/{file}')
except IndexError:
print('List of available providers:')
exit()
for file_name in os.listdir(os.path.dirname(__file__)):
if file_name.endswith('.py') and not file_name.startswith('_'):
print(file_name.split('.')[0])
except IndexError:
print('List of available providers:')
sys.exit(0)
for file_name in os.listdir(os.path.dirname(__file__)):
if file_name.endswith('.py') and not file_name.startswith('_'):
print(file_name.split('.')[0])
try:
provider = importlib.import_module(f'.{provider_name}', 'providers')
except ModuleNotFoundError as exc:
print(exc)
sys.exit(1)
sys.exit(0)
if len(sys.argv) > 2:
model = sys.argv[2] # choose a specific model
else:
model = provider.MODELS[-1] # choose best model
try:
provider = __import__(provider_name)
except ModuleNotFoundError as exc:
print(f'Provider "{provider_name}" not found.')
print('Available providers:')
for file_name in os.listdir(os.path.dirname(__file__)):
if file_name.endswith('.py') and not file_name.startswith('_'):
print(file_name.split('.')[0])
sys.exit(1)
print(f'{provider_name} @ {model}')
req = await provider.chat_completion(model=model, messages=[{'role': 'user', 'content': '1+1='}])
print(req)
if len(sys.argv) > 2:
model = sys.argv[2]
else:
model = provider.MODELS[-1]
# launch aiohttp
async with aiohttp.ClientSession() as session:
async with session.request(
method=req['method'],
url=req['url'],
headers=req['headers'],
json=req['payload'],
) as response:
res_json = await response.json()
print(response.status, res_json)
print(f'{provider_name} @ {model}')
comp = provider.chat_completion(model=model)
print(comp)
asyncio.run(main())

32
api/providers/azure.py Normal file
View file

@ -0,0 +1,32 @@
from .helpers import utils
AUTH = True
ORGANIC = False
CONTEXT = True
STREAMING = True
MODERATIONS = False
ENDPOINT = 'https://nova-00001.openai.azure.com'
MODELS = [
'gpt-3.5-turbo',
'gpt-3.5-turbo-16k',
'gpt-4',
'gpt-4-32k'
]
MODELS = [f'{model}-azure' for model in MODELS]
AZURE_API = '2023-07-01-preview'
async def chat_completion(**payload):
key = await utils.random_secret_for('azure-nva1')
deployment = payload['model'].replace('.', '').replace('-azure', '')
return {
'method': 'POST',
'url': f'{ENDPOINT}/openai/deployments/{deployment}/chat/completions?api-version={AZURE_API}',
'payload': payload,
'headers': {
'api-key': key
},
'provider_auth': f'azure-nva1>{key}'
}

View file

@ -1,4 +1,7 @@
from db import providerkeys
try:
from db import providerkeys
except ModuleNotFoundError:
from ...db import providerkeys
GPT_3 = [
'gpt-3.5-turbo',

21
api/providers/mandrill.py Normal file
View file

@ -0,0 +1,21 @@
from .helpers import utils
AUTH = True
ORGANIC = False
CONTEXT = True
STREAMING = True
MODELS = ['llama-2-7b-chat']
async def chat_completion(**kwargs):
payload = kwargs
key = await utils.random_secret_for('mandrill')
return {
'method': 'POST',
'url': f'https://api.mandrillai.tech/v1/chat/completions',
'payload': payload,
'headers': {
'Authorization': f'Bearer {key}'
},
'provider_auth': f'mandrill>{key}'
}

View file

@ -7,6 +7,7 @@ import aiohttp
import asyncio
import starlette
from typing import Set
from rich import print
from dotenv import load_dotenv
@ -23,6 +24,8 @@ CRITICAL_API_ERRORS = ['invalid_api_key', 'account_deactivated']
keymanager = providerkeys.manager
background_tasks: Set[asyncio.Task] = set()
async def respond(
path: str='/v1/chat/completions',
user: dict=None,
@ -49,7 +52,8 @@ async def respond(
'Content-Type': 'application/json'
}
for _ in range(20):
for i in range(20):
print(i)
# Load balancing: randomly selecting a suitable provider
try:
if is_chat:
@ -62,6 +66,7 @@ async def respond(
'headers': headers,
'cookies': incoming_request.cookies
})
except ValueError:
yield await errors.yield_error(500, f'Sorry, the API has no active API keys for {model}.', 'Please use a different model.')
return
@ -73,6 +78,7 @@ async def respond(
provider_key = provider_auth.split('>')[1]
if provider_key == '--NO_KEY--':
print(f'No key for {provider_name}')
yield await errors.yield_error(500,
'Sorry, our API seems to have issues connecting to our provider(s).',
'This most likely isn\'t your fault. Please try again later.'
@ -101,16 +107,26 @@ async def respond(
) as response:
is_stream = response.content_type == 'text/event-stream'
if response.status == 429:
print('[!] rate limit')
# await keymanager.rate_limit_key(provider_name, provider_key)
continue
if response.content_type == 'application/json':
client_json_response = await response.json()
if 'method_not_supported' in str(client_json_response):
await errors.error(500, 'Sorry, this endpoint does not support this method.', data['error']['message'])
try:
error_code = client_json_response['error']['code']
except KeyError:
error_code = ''
if error_code == 'method_not_supported':
yield await errors.yield_error(400, 'Sorry, this endpoint does not support this method.', 'Please use a different method.')
if error_code == 'insufficient_quota':
print('[!] insufficient quota')
await keymanager.rate_limit_key(provider_name, provider_key, 86400)
continue
if error_code == 'billing_not_active':
print('[!] billing not active')
await keymanager.deactivate_key(provider_name, provider_key, 'billing_not_active')
continue
critical_error = False
for error in CRITICAL_API_ERRORS:
@ -126,7 +142,6 @@ async def respond(
server_json_response = client_json_response
else:
print('[!] non-ok response', client_json_response)
continue
if is_stream:
@ -154,7 +169,7 @@ async def respond(
if (not is_stream) and server_json_response:
yield json.dumps(server_json_response)
asyncio.create_task(
task = asyncio.create_task(
after_request.after_request(
incoming_request=incoming_request,
target_request=target_request,
@ -166,3 +181,5 @@ async def respond(
model=model,
)
)
background_tasks.add(task)
task.add_done_callback(background_tasks.discard)

View file

@ -25,6 +25,7 @@ MESSAGES = [
]
api_endpoint = os.getenv('CHECKS_ENDPOINT', 'http://localhost:2332/v1')
# api_endpoint = 'http://localhost:2333/v1'
async def _response_base_check(response: httpx.Response) -> None:
try: