Added load balancer, MongoDB, improved proxy streaming support, more error messages, entire new world order

This commit is contained in:
nsde 2023-08-03 01:46:49 +02:00
parent eeea634da0
commit 08d31d7ad1
15 changed files with 296 additions and 231 deletions

9
.gitignore vendored
View file

@ -1,3 +1,12 @@
providers/*
providers/
chat_providers/*
chat_providers/
secret/*
secret/
/secret
*.db
*.sqlite3
*.sql

View file

@ -1,17 +0,0 @@
#
FROM python:3.10
#
WORKDIR /code
#
COPY ./requirements.txt /code/requirements.txt
#
RUN pip install . pip install --no-cache-dir --upgrade -r /code/requirements.txt
#
COPY ./app /code/app
#
CMD ["uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "2333"]

1
api/__main__.py Normal file
View file

@ -0,0 +1 @@
import providers.__main__

30
api/chat_balancing.py Normal file
View file

@ -0,0 +1,30 @@
import random
import asyncio
import chat_providers
provider_modules = [
# chat_providers.twa,
chat_providers.quantum,
# chat_providers.churchless,
chat_providers.closed
]
async def balance(payload: dict) -> dict:
providers_available = []
for provider_module in provider_modules:
if payload['stream'] and not provider_module.STREAMING:
continue
if payload['model'] not in provider_module.MODELS:
continue
providers_available.append(provider_module)
provider = random.choice(providers_available)
return provider.chat_completion(**payload)
if __name__ == '__main__':
req = asyncio.run(balance(payload={'model': 'gpt-3.5-turbo', 'stream': True}))
print(req['url'])

View file

@ -4,7 +4,7 @@ import os
import json
import fastapi
import users
from db import users
from dotenv import load_dotenv
@ -24,10 +24,11 @@ async def get_users(discord_id: int, incoming_request: fastapi.Request):
if auth_error:
return auth_error
user = await users.get_user(by_discord_id=discord_id)
user = await users.by_discord_id(discord_id)
if not user:
return fastapi.Response(status_code=404, content='User not found.')
return user
@router.post('/users')
@ -43,6 +44,5 @@ async def create_user(incoming_request: fastapi.Request):
except (json.decoder.JSONDecodeError, AttributeError):
return fastapi.Response(status_code=400, content='Invalid or no payload received.')
user = await users.add_user(discord_id=discord_id)
user = await users.create(discord_id)
return user

57
api/db/logs.py Normal file
View file

@ -0,0 +1,57 @@
import os
import bson
import datetime
from dotenv import load_dotenv
from motor.motor_asyncio import AsyncIOMotorClient
load_dotenv()
def _get_mongo(collection_name: str):
return AsyncIOMotorClient(os.getenv('MONGO_URI'))['nova-core'][collection_name]
async def log_api_request(user, request, target_url):
payload = await request.json()
last_prompt = None
if 'messages' in payload:
last_prompt = payload['messages'][-1]['content']
model = None
if 'model' in payload:
model = payload['model']
new_log_item = {
'timestamp': bson.timestamp.Timestamp(datetime.datetime.now(), 0),
'method': request.method,
'path': request.url.path,
'user_id': user['_id'],
'security': {
'ip': request.client.host,
'useragent': request.headers.get('User-Agent')
},
'details': {
'model': model,
'last_prompt': last_prompt,
'target_url': target_url
}
}
inserted = await _get_mongo('logs').insert_one(new_log_item)
log_item = await _get_mongo('logs').find_one({'_id': inserted.inserted_id})
return log_item
async def by_id(log_id: str):
return await _get_mongo('logs').find_one({'_id': log_id})
async def by_user_id(user_id: str):
return await _get_mongo('logs').find({'user_id': user_id})
async def delete_by_id(log_id: str):
return await _get_mongo('logs').delete_one({'_id': log_id})
async def delete_by_user_id(user_id: str):
return await _get_mongo('logs').delete_many({'user_id': user_id})
if __name__ == '__main__':
pass

67
api/db/users.py Normal file
View file

@ -0,0 +1,67 @@
import os
import random
import string
import asyncio
from dotenv import load_dotenv
from motor.motor_asyncio import AsyncIOMotorClient
load_dotenv()
def _get_mongo(collection_name: str):
return AsyncIOMotorClient(os.getenv('MONGO_URI'))['nova-core'][collection_name]
async def create(discord_id: int=0) -> dict:
"""Adds a new user to the MongoDB collection."""
chars = string.ascii_letters + string.digits
infix = os.getenv('KEYGEN_INFIX')
suffix = ''.join(random.choices(chars, k=20))
prefix = ''.join(random.choices(chars, k=20))
new_api_key = f'nv-{prefix}{infix}{suffix}'
new_user = {
'api_key': new_api_key,
'credits': 1000,
'role': '',
'status': {
'active': True,
'ban_reason': '',
},
'auth': {
'discord': discord_id,
'github': None
}
}
await _get_mongo('users').insert_one(new_user)
user = await _get_mongo('users').find_one({'api_key': new_api_key})
return user
async def by_id(user_id: str):
return await _get_mongo('users').find_one({'_id': user_id})
async def by_discord_id(discord_id: str):
return await _get_mongo('users').find_one({'auth.discord': discord_id})
async def by_api_key(key: str):
return await _get_mongo('users').find_one({'api_key': key})
async def update_by_id(user_id: str, update):
return await _get_mongo('users').update_one({'_id': user_id}, update)
async def update_by_filter(obj_filter, update):
return await _get_mongo('users').update_one(obj_filter, update)
async def delete(user_id: str):
await _get_mongo('users').delete_one({'_id': user_id})
async def demo():
user = await create(69420)
print(user)
if __name__ == '__main__':
asyncio.run(demo())

View file

@ -0,0 +1,2 @@
class Retry(Exception):
"""The server should retry the request."""

View file

@ -1,35 +0,0 @@
"""Manages web requests."""
import os
from dotenv import load_dotenv
from typing import Union, Optional
load_dotenv()
EXCLUDED_HEADERS = [
'content-encoding',
'content-length',
'transfer-encoding',
'connection'
]
class Request:
def __init__(self,
url: str,
method: str='GET',
payload: Optional[Union[dict, list]]=None,
headers: dict={
'Content-Type': 'application/json'
}
):
self.method = method.upper()
self.url = url.replace('/v1/v1', '/v1')
self.payload = payload
self.headers = headers
self.timeout = int(os.getenv('TRANSFER_TIMEOUT', '120'))
class HTTPXRequest(Request):
def __init__(self, url: str, *args, **kwargs):
super().__init__(url, *args, **kwargs)
self.url += '?httpx=1'

View file

@ -7,9 +7,10 @@ from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv
import core
import users
import transfer
from db import users
load_dotenv()
app = fastapi.FastAPI()
@ -26,7 +27,9 @@ app.include_router(core.router)
@app.on_event('startup')
async def startup_event():
await users.prepare()
# DATABASE FIX https://stackoverflow.com/questions/65970988/python-mongodb-motor-objectid-object-is-not-iterable-error-while-trying-to-f
import pydantic, bson
pydantic.json.ENCODERS_BY_TYPE[bson.objectid.ObjectId]=str
@app.get('/')
async def root():
@ -39,4 +42,4 @@ async def root():
'github': 'https://github.com/novaoss/nova-api'
}
app.add_route('/{path:path}', transfer.handle_api_request, ['GET', 'POST', 'PUT', 'DELETE', 'PATCH'])
app.add_route('/{path:path}', transfer.handle, ['GET', 'POST', 'PUT', 'DELETE', 'PATCH'])

View file

@ -1,32 +1,42 @@
import os
import httpx
import requests
from dotenv import load_dotenv
import proxies
from dotenv import load_dotenv
from helpers.requesting import Request
from helpers import exceptions
load_dotenv()
async def stream_closedai_request(request: Request):
async with httpx.AsyncClient(
# proxies=proxies.default_proxy.urls_httpx,
timeout=httpx.Timeout(request.timeout)
) as client:
async def stream(request: dict):
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {os.getenv("CLOSEDAI_KEY")}'
'Content-Type': 'application/json'
}
response = await client.request(
method=request.method,
url=request.url,
json=request.payload,
headers=headers
for k, v in request.get('headers', {}).items():
headers[k] = v
for _ in range(3):
response = requests.request(
method=request.get('method', 'POST'),
url=request['url'],
json=request.get('payload', {}),
headers=headers,
timeout=int(os.getenv('TRANSFER_TIMEOUT', '120')),
proxies=proxies.default_proxy.urls,
stream=True
)
try:
response.raise_for_status()
except Exception as exc:
if str(exc) == '429 Client Error: Too Many Requests for url: https://api.openai.com/v1/chat/completions':
continue
else:
break
async for chunk in response.aiter_bytes():
for chunk in response.iter_lines():
chunk = f'{chunk.decode("utf8")}\n\n'
yield chunk

View file

@ -86,10 +86,33 @@ def test_httpx():
print(default_proxy.proxies)
with httpx.Client(
# proxies=default_proxy.proxies
) as client:
return client.get('https://checkip.amazonaws.com').text.strip()
def test_httpx_workaround():
import httpx
print(default_proxy.proxies)
# this workaround solves the RNDS issue, but fails for Cloudflare protected websites
with httpx.Client(proxies=default_proxy.proxies, headers={'Host': 'checkip.amazonaws.com'}) as client:
return client.get(
f'http://{socket.gethostbyname("checkip.amazonaws.com")}/',
).text.strip()
def test_requests():
import requests
print(default_proxy.proxies)
return requests.get(
timeout=5,
url='https://checkip.amazonaws.com/',
proxies=default_proxy.urls
).text.strip()
if __name__ == '__main__':
print(test_httpx())
# print(test_requests())

View file

@ -2,12 +2,16 @@
import os
import json
import tokens
import logging
import starlette
from dotenv import load_dotenv
from helpers import requesting, tokens, errors
import netclient
import chat_balancing
from db import logs, users
from helpers import tokens, errors, exceptions
load_dotenv()
@ -20,12 +24,10 @@ logging.basicConfig(
logging.info('API started')
DEFAULT_ENDPOINT = os.getenv('CLOSEDAI_ENDPOINT')
async def handle_api_request(incoming_request, target_endpoint: str=DEFAULT_ENDPOINT):
async def handle(incoming_request):
"""Transfer a streaming response from the incoming request to the target endpoint"""
target_url = f'{target_endpoint}{incoming_request.url.path}'
path = incoming_request.url.path
if incoming_request.method not in ['GET', 'POST', 'PUT', 'DELETE', 'PATCH']:
return errors.error(405, f'Method "{incoming_request.method}" is not allowed.', 'Change the request method to the correct one.')
@ -37,35 +39,54 @@ async def handle_api_request(incoming_request, target_endpoint: str=DEFAULT_ENDP
try:
input_tokens = tokens.count_for_messages(payload['messages'])
except:
except (KeyError, TypeError):
input_tokens = 0
auth_header = incoming_request.headers.get('Authorization')
if not auth_header:
return errors.error(401, 'No NovaOSS API key given!', 'Add "Authorization: Bearer nv-..." to your request headers.')
return errors.error(401, 'No NovaAI API key given!', 'Add "Authorization: Bearer nv-..." to your request headers.')
received_key = auth_header
if auth_header.startswith('Bearer '):
received_key = auth_header.split('Bearer ')[1]
user = users.get_user(by_api_key=received_key)
user = await users.by_api_key(received_key)
if not user:
return errors.error(401, 'Invalid NovaOSS API key!', 'Create a new NovaOSS API key.')
return errors.error(401, 'Invalid NovaAI API key!', 'Create a new NovaOSS API key.')
if not user['active']:
return errors.error(403, 'Your account is not active.', 'Activate your account.')
ban_reason = user['status']['ban_reason']
if ban_reason:
return errors.error(403, f'Your NovaAI account has been banned. Reason: "{ban_reason}".', 'Contact the staff for an appeal.')
logging.info(f'[%s] %s -> %s', incoming_request.method, incoming_request.url.path, target_url)
if not user['status']['active']:
return errors.error(418, 'Your NovaAI account is not active (paused).', 'Simply re-activate your account using a Discord command or the web panel.')
request = requesting.Request(
url=target_url,
payload=payload,
method=incoming_request.method,
)
payload['user'] = str(user['_id'])
return starlette.responses.StreamingResponse(
content=netclient.stream_closedai_request(request)
)
cost = 1
if '/chat/completions' in path:
cost = 5
if 'gpt-4' in payload['model']:
cost = 10
else:
return errors.error(404, f'Sorry, we don\'t support "{path}" yet. We\'re working on it.', 'Contact our team.')
if not payload.get('stream') is True:
payload['stream'] = False
if user['credits'] < cost:
return errors.error(429, 'Not enough credits.', 'You do not have enough credits to complete this request.')
await users.update_by_id(user['_id'], {'$inc': {'credits': -cost}})
target_request = await chat_balancing.balance(payload)
print(target_request['url'])
return starlette.responses.StreamingResponse(netclient.stream(target_request))

View file

@ -1,117 +0,0 @@
import os
import uuid
import time
import string
import random
import asyncio
from dotenv import load_dotenv
from motor.motor_asyncio import AsyncIOMotorClient
from rich import print
load_dotenv()
MONGO_URI = os.getenv('MONGO_URI')
MONGO_DB_NAME = 'users'
def get_mongo(collection_name):
client = AsyncIOMotorClient(MONGO_URI)
db = client[MONGO_DB_NAME]
return db[collection_name]
async def prepare() -> None:
"""Create the MongoDB collection."""
collection = get_mongo('users')
await collection.create_index('id', unique=True)
await collection.create_index('discord_id', unique=True)
await collection.create_index('api_key', unique=True)
async def add_user(discord_id: int = 0, tags: list = None) -> dict:
"""Adds a new user to the MongoDB collection."""
chars = string.ascii_letters + string.digits
infix = os.getenv('KEYGEN_INFIX')
suffix = ''.join(random.choices(chars, k=20))
prefix = ''.join(random.choices(chars, k=20))
key = f'nv-{prefix}{infix}{suffix}'
tags = tags or []
new_user = {
'id': str(uuid.uuid4()),
'api_key': key,
'created_at': int(time.time()),
'ban_reason': '',
'active': True,
'discord_id': discord_id,
'credit': 0,
'tags': '/'.join(tags),
'usage': {
'events': [],
'num_tokens': 0
}
}
collection = get_mongo('users')
await collection.insert_one(new_user)
return new_user
async def get_user(by_id: str = '', by_discord_id: int = 0, by_api_key: str = ''):
"""Retrieve a user from the MongoDB collection."""
collection = get_mongo('users')
query = {
'$or': [
{'id': by_id},
{'discord_id': by_discord_id},
{'api_key': by_api_key},
]
}
return await collection.find_one(query)
async def get_all_users():
"""Retrieve all users from the MongoDB collection."""
collection = get_mongo('users')
return list(await collection.find())
async def user_used_api(user_id: str, num_tokens: int = 0, model='', ip_address: str = '', user_agent: str = '') -> None:
"""Update the stats of a user."""
collection = get_mongo('users')
user = await get_user(by_id=user_id)
if not user:
raise ValueError('User not found.')
usage = user['usage']
usage['events'].append({
'timestamp': time.time(),
'ip_address': ip_address,
'user_agent': user_agent,
'model': model,
'num_tokens': num_tokens
})
usage['num_tokens'] += num_tokens
await collection.update_one({'id': user_id}, {'$set': {'usage': usage}})
async def demo():
await prepare()
example_id = 133769420
user = await get_user(by_discord_id=example_id)
print(user)
uid = await user['id']
await user_used_api(uid, model='gpt-5', num_tokens=42, ip_address='9.9.9.9', user_agent='Mozilla/5.0')
# print(user)
if __name__ == '__main__':
asyncio.run(demo())

View file

@ -10,15 +10,21 @@ from dotenv import load_dotenv
load_dotenv()
MODEL = 'gpt-3.5-turbo'
# MESSAGES = [
# {
# 'role': 'system',
# 'content': 'Always answer with "3", no matter what the user asks for. No exceptions. Just answer with the number "3". Nothing else. Just "3". No punctuation.'
# },
# {
# 'role': 'user',
# 'content': '1+1=',
# },
# ]
MESSAGES = [
{
'role': 'system',
'content': 'Always answer with "3", no matter what the user asks for. No exceptions. Just answer with the number "3". Nothing else. Just "3". No punctuation.'
},
{
'role': 'user',
'content': '1+1=',
},
}
]
api_endpoint = 'http://localhost:2332'
@ -36,7 +42,7 @@ def test_api(model: str=MODEL, messages: List[dict]=None) -> dict:
headers = {
'Content-Type': 'application/json',
'Authorization': 'Bearer ' + os.getenv('DEMO_AUTH', 'nv-API-TEST'),
'Authorization': 'Bearer ' + api_key
}
json_data = {
@ -59,23 +65,28 @@ def test_library():
"""Tests if the api_endpoint is working with the Python library."""
closedai.api_base = api_endpoint
closedai.api_key = os.getenv('DEMO_AUTH', 'nv-LIB-TEST')
closedai.api_key = api_key
completion = closedai.ChatCompletion.create(
model=MODEL,
messages=MESSAGES,
stream=True,
stream=True
)
return completion.choices[0]
for event in completion:
try:
print(event['choices'][0]['delta']['content'])
except:
print('-')
def test_all():
"""Runs all tests."""
# print(test_server())
print(test_api())
# print(test_library())
# print(test_api())
print(test_library())
if __name__ == '__main__':
# api_endpoint = 'https://api.nova-oss.com'
api_endpoint = 'https://api.nova-oss.com'
api_key = os.getenv('TEST_NOVA_KEY')
test_all()