Compare commits

...

2 commits

Author      SHA1        Message                      Date
monosans    03f8c9ad0a  Fix dangling asyncio tasks   2023-10-07 09:14:51 +03:00
nsde        719f29fb29  Added azure endpoints        2023-10-06 23:05:38 +02:00
10 changed files with 154 additions and 58 deletions

View file

@@ -6,9 +6,10 @@ costs:
   other: 5
 
 chat-models:
-  gpt-4-32k: 100
-  gpt-4: 30
-  gpt-3: 3
+  gpt-4-32k-azure: 100
+  gpt-4: 50
+  gpt-4-azure: 10
+  gpt-3: 5
 
 ## Roles Explanation

View file

@@ -1,12 +1,17 @@
 import os
 import time
+import random
 import asyncio
 
+from aiocache import cached
 from dotenv import load_dotenv
+from cachetools import TTLCache
 from motor.motor_asyncio import AsyncIOMotorClient
 
 load_dotenv()
 
+cache = TTLCache(maxsize=100, ttl=10)
+
 class KeyManager:
     def __init__(self):
         self.conn = AsyncIOMotorClient(os.environ['MONGO_URI'])
@@ -24,27 +29,34 @@ class KeyManager:
             'source': source,
         })
 
-    async def get_key(self, provider: str):
+    async def get_possible_keys(self, provider: str):
         db = await self._get_collection('providerkeys')
-        key = await db.find_one({
+        keys = await db.find({
             'provider': provider,
             'inactive_reason': None,
             '$or': [
-                {'rate_limited_since': None},
-                {'rate_limited_since': {'$lte': time.time() - 86400}}
+                {'rate_limited_until': None},
+                {'rate_limited_until': {'$lte': time.time()}}
             ]
-        })
+        }).to_list(length=None)
 
-        if key is None:
+        return keys
+
+    async def get_key(self, provider: str):
+        keys = await self.get_possible_keys(provider)
+
+        if not keys:
             return '--NO_KEY--'
 
-        return key['key']
+        key = random.choice(keys)
+        api_key = key['key']
+
+        return api_key
 
-    async def rate_limit_key(self, provider: str, key: str):
+    async def rate_limit_key(self, provider: str, key: str, duration: int):
         db = await self._get_collection('providerkeys')
         await db.update_one({'provider': provider, 'key': key}, {
             '$set': {
-                'rate_limited_since': time.time()
+                'rate_limited_until': time.time() + duration
             }
         })
@@ -70,8 +82,6 @@ class KeyManager:
                 await db.insert_one({
                     'provider': filename.split('.')[0],
                     'key': line.strip(),
-                    'rate_limited_since': None,
-                    'inactive_reason': None,
                     'source': 'import'
                 })
                 num += 1
@@ -86,5 +96,9 @@ class KeyManager:
 
 manager = KeyManager()
 
+async def main():
+    keys = await manager.get_possible_keys('closed')
+    print(len(keys))
+
 if __name__ == '__main__':
-    asyncio.run(manager.import_all())
+    asyncio.run(main())
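
A note on how the reworked rotation above is meant to be used: get_key() now draws a random key from get_possible_keys(), which only returns keys whose rate_limited_until is unset or already in the past, and rate_limit_key() pushes that timestamp into the future by the given duration. A minimal caller sketch, assuming MONGO_URI is set, db.providerkeys is importable from the working directory, and a provider named 'closed' has imported keys; the demo flow is illustrative, not part of this diff:

import asyncio

from db.providerkeys import manager  # the module-level KeyManager() from the diff above

async def demo():
    # Picks a random key whose rate_limited_until is None or already expired.
    key = await manager.get_key('closed')

    if key == '--NO_KEY--':
        print('no usable key')
        return

    # Pretend the upstream returned 429: pause this key for one hour.
    await manager.rate_limit_key('closed', key, duration=3600)

if __name__ == '__main__':
    asyncio.run(demo())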

View file

@@ -69,8 +69,8 @@ async def handle(incoming_request: fastapi.Request):
         return await errors.error(403, f'Your NovaAI account has been banned. Reason: \'{ban_reason}\'.', 'Contact the staff for an appeal.')
 
     # Checking for enterprise status
-    enterprise_keys = os.environ.get('NO_RATELIMIT_KEYS')
-    if '/enterprise' in path and user.get('api_key') not in enterprise_keys:
+    enterprise_keys = os.environ.get('ENTERPRISE_KEYS')
+    if path.startswith('/enterprise/v1') and user.get('api_key') not in enterprise_keys.split():
         return await errors.error(403, 'Enterprise API is not available.', 'Contact the staff for an upgrade.')
 
     if 'account/credits' in path:
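
One detail worth noting in the enterprise gate above: ENTERPRISE_KEYS is treated as a whitespace-separated list, since the check calls .split() on the raw environment value. A tiny illustration with made-up key values (not from the diff):

import os

# Hypothetical values; the real keys would live in the environment or .env file.
os.environ['ENTERPRISE_KEYS'] = 'nv-abc123 nv-def456'

enterprise_keys = os.environ.get('ENTERPRISE_KEYS')
print('nv-abc123' in enterprise_keys.split())   # True
print('nv-missing' in enterprise_keys.split())  # False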

View file

@@ -1,9 +1,11 @@
 from . import \
+    azure, \
     closed, \
     closed4
     # closed432
 
 MODULES = [
+    azure,
     closed,
     closed4,
     # closed432,

View file

@@ -1,5 +1,8 @@
 import os
 import sys
+import aiohttp
+import asyncio
+import importlib
 
 from rich import print
@@ -12,17 +15,11 @@ def remove_duplicate_keys(file):
     with open(file, 'w', encoding='utf8') as f:
         f.writelines(unique_lines)
 
-try:
-    provider_name = sys.argv[1]
-
-    if provider_name == '--clear':
-        for file in os.listdir('secret/'):
-            if file.endswith('.txt'):
-                remove_duplicate_keys(f'secret/{file}')
-        exit()
-
-except IndexError:
-    print('List of available providers:')
-    for file_name in os.listdir(os.path.dirname(__file__)):
+async def main():
+    try:
+        provider_name = sys.argv[1]
+    except IndexError:
+        print('List of available providers:')
+
+        for file_name in os.listdir(os.path.dirname(__file__)):
@@ -31,22 +28,30 @@ except IndexError:
         sys.exit(0)
 
     try:
-        provider = __import__(provider_name)
+        provider = importlib.import_module(f'.{provider_name}', 'providers')
     except ModuleNotFoundError as exc:
-        print(f'Provider "{provider_name}" not found.')
-        print('Available providers:')
-        for file_name in os.listdir(os.path.dirname(__file__)):
-            if file_name.endswith('.py') and not file_name.startswith('_'):
-                print(file_name.split('.')[0])
+        print(exc)
         sys.exit(1)
 
     if len(sys.argv) > 2:
-        model = sys.argv[2]
+        model = sys.argv[2] # choose a specific model
     else:
-        model = provider.MODELS[-1]
+        model = provider.MODELS[-1] # choose best model
+
+    print(f'{provider_name} @ {model}')
+
+    req = await provider.chat_completion(model=model, messages=[{'role': 'user', 'content': '1+1='}])
+    print(req)
 
-    print(f'{provider_name} @ {model}')
-    comp = provider.chat_completion(model=model)
-    print(comp)
+    # launch aiohttp
+    async with aiohttp.ClientSession() as session:
+        async with session.request(
+            method=req['method'],
+            url=req['url'],
+            headers=req['headers'],
+            json=req['payload'],
+        ) as response:
+            res_json = await response.json()
+            print(response.status, res_json)
+
+asyncio.run(main())
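
With the runner wrapped in an async main() that dispatches the provider's request dict through aiohttp, a provider can be smoke-tested from the command line. The importlib call above implies the package is importable as "providers", so the invocation would presumably be run from the api/ directory (assumption about the repo layout), along the lines of:

python -m providers azure gpt-4-azure
python -m providers closed

When no model argument is given, the last entry of the provider's MODELS list is used.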

api/providers/azure.py (new file, +32)
View file

@@ -0,0 +1,32 @@
+from .helpers import utils
+
+AUTH = True
+ORGANIC = False
+CONTEXT = True
+STREAMING = True
+MODERATIONS = False
+
+ENDPOINT = 'https://nova-00001.openai.azure.com'
+
+MODELS = [
+    'gpt-3.5-turbo',
+    'gpt-3.5-turbo-16k',
+    'gpt-4',
+    'gpt-4-32k'
+]
+MODELS = [f'{model}-azure' for model in MODELS]
+
+AZURE_API = '2023-07-01-preview'
+
+async def chat_completion(**payload):
+    key = await utils.random_secret_for('azure-nva1')
+    deployment = payload['model'].replace('.', '').replace('-azure', '')
+
+    return {
+        'method': 'POST',
+        'url': f'{ENDPOINT}/openai/deployments/{deployment}/chat/completions?api-version={AZURE_API}',
+        'payload': payload,
+        'headers': {
+            'api-key': key
+        },
+        'provider_auth': f'azure-nva1>{key}'
+    }
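
A quick illustration of the deployment naming in the new Azure provider: the '-azure' suffix exists only on the API's side, and the Azure deployment is assumed to be named after the upstream model with dots stripped. Illustrative snippet, not part of the diff:

model = 'gpt-3.5-turbo-azure'  # what a client would request
deployment = model.replace('.', '').replace('-azure', '')
print(deployment)  # gpt-35-turbo
# -> POST {ENDPOINT}/openai/deployments/gpt-35-turbo/chat/completions?api-version=2023-07-01-preview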

View file

@@ -1,4 +1,7 @@
-from db import providerkeys
+try:
+    from db import providerkeys
+except ModuleNotFoundError:
+    from ...db import providerkeys
 
 GPT_3 = [
     'gpt-3.5-turbo',

api/providers/mandrill.py (new file, +21)
View file

@@ -0,0 +1,21 @@
+from .helpers import utils
+
+AUTH = True
+ORGANIC = False
+CONTEXT = True
+STREAMING = True
+
+MODELS = ['llama-2-7b-chat']
+
+async def chat_completion(**kwargs):
+    payload = kwargs
+    key = await utils.random_secret_for('mandrill')
+    return {
+        'method': 'POST',
+        'url': f'https://api.mandrillai.tech/v1/chat/completions',
+        'payload': payload,
+        'headers': {
+            'Authorization': f'Bearer {key}'
+        },
+        'provider_auth': f'mandrill>{key}'
+    }

View file

@@ -7,6 +7,7 @@ import aiohttp
 import asyncio
 import starlette
 
+from typing import Set
 from rich import print
 from dotenv import load_dotenv
@@ -23,6 +24,8 @@ CRITICAL_API_ERRORS = ['invalid_api_key', 'account_deactivated']
 keymanager = providerkeys.manager
 
+background_tasks: Set[asyncio.Task] = set()
+
 async def respond(
     path: str='/v1/chat/completions',
     user: dict=None,
@@ -49,7 +52,8 @@ async def respond(
         'Content-Type': 'application/json'
     }
 
-    for _ in range(20):
+    for i in range(20):
+        print(i)
         # Load balancing: randomly selecting a suitable provider
         try:
             if is_chat:
@@ -62,6 +66,7 @@ async def respond(
                     'headers': headers,
                     'cookies': incoming_request.cookies
                 })
+
         except ValueError:
             yield await errors.yield_error(500, f'Sorry, the API has no active API keys for {model}.', 'Please use a different model.')
             return
@@ -73,6 +78,7 @@ async def respond(
             provider_key = provider_auth.split('>')[1]
 
             if provider_key == '--NO_KEY--':
+                print(f'No key for {provider_name}')
                 yield await errors.yield_error(500,
                     'Sorry, our API seems to have issues connecting to our provider(s).',
                     'This most likely isn\'t your fault. Please try again later.'
@@ -101,16 +107,26 @@ async def respond(
                 ) as response:
                     is_stream = response.content_type == 'text/event-stream'
 
+                    if response.status == 429:
+                        print('[!] rate limit')
+                        # await keymanager.rate_limit_key(provider_name, provider_key)
+                        continue
+
                     if response.content_type == 'application/json':
                         client_json_response = await response.json()
 
-                        if 'method_not_supported' in str(client_json_response):
-                            await errors.error(500, 'Sorry, this endpoint does not support this method.', data['error']['message'])
+                        try:
+                            error_code = client_json_response['error']['code']
+                        except KeyError:
+                            error_code = ''
+
+                        if error_code == 'method_not_supported':
+                            yield await errors.yield_error(400, 'Sorry, this endpoint does not support this method.', 'Please use a different method.')
+
+                        if error_code == 'insufficient_quota':
+                            print('[!] insufficient quota')
+                            await keymanager.rate_limit_key(provider_name, provider_key, 86400)
+                            continue
+
+                        if error_code == 'billing_not_active':
+                            print('[!] billing not active')
+                            await keymanager.deactivate_key(provider_name, provider_key, 'billing_not_active')
+                            continue
 
                         critical_error = False
                         for error in CRITICAL_API_ERRORS:
@@ -126,7 +142,6 @@ async def respond(
                             server_json_response = client_json_response
 
                         else:
-                            print('[!] non-ok response', client_json_response)
                             continue
 
                     if is_stream:
@@ -154,7 +169,7 @@ async def respond(
     if (not is_stream) and server_json_response:
         yield json.dumps(server_json_response)
 
-    asyncio.create_task(
+    task = asyncio.create_task(
         after_request.after_request(
             incoming_request=incoming_request,
             target_request=target_request,
@@ -166,3 +181,5 @@ async def respond(
             model=model,
         )
     )
+    background_tasks.add(task)
+    task.add_done_callback(background_tasks.discard)
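
This last change is the actual "Fix dangling asyncio tasks" commit: the event loop only keeps weak references to tasks, so a bare asyncio.create_task() for the after_request logging can be garbage-collected before it finishes. Keeping a strong reference in a module-level set and discarding it on completion is the pattern the asyncio documentation recommends. A minimal self-contained sketch of the same pattern, where log_request merely stands in for after_request.after_request:

import asyncio
from typing import Set

background_tasks: Set[asyncio.Task] = set()

async def log_request() -> None:
    # Stand-in for the real after_request.after_request(...) coroutine.
    await asyncio.sleep(0.1)
    print('request logged')

async def main() -> None:
    task = asyncio.create_task(log_request())
    background_tasks.add(task)                        # strong reference: not garbage-collected early
    task.add_done_callback(background_tasks.discard)  # drop the reference once the task finishes
    await asyncio.sleep(0.2)                          # keep the loop alive long enough for the demo

if __name__ == '__main__':
    asyncio.run(main())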

View file

@@ -25,6 +25,7 @@ MESSAGES = [
 ]
 
 api_endpoint = os.getenv('CHECKS_ENDPOINT', 'http://localhost:2332/v1')
+# api_endpoint = 'http://localhost:2333/v1'
 
 async def _response_base_check(response: httpx.Response) -> None:
     try: