mirror of
https://github.com/NovaOSS/nova-api.git
synced 2024-11-25 16:43:58 +01:00
Added load balancer, MongoDB, improved proxy streaming support, more error messages, entire new world order
This commit is contained in:
parent
eeea634da0
commit
08d31d7ad1
9
.gitignore
vendored
9
.gitignore
vendored
|
@ -1,3 +1,12 @@
|
||||||
|
providers/*
|
||||||
|
providers/
|
||||||
|
chat_providers/*
|
||||||
|
chat_providers/
|
||||||
|
|
||||||
|
secret/*
|
||||||
|
secret/
|
||||||
|
/secret
|
||||||
|
|
||||||
*.db
|
*.db
|
||||||
*.sqlite3
|
*.sqlite3
|
||||||
*.sql
|
*.sql
|
||||||
|
|
17
Dockerfile
17
Dockerfile
|
@ -1,17 +0,0 @@
|
||||||
#
|
|
||||||
FROM python:3.10
|
|
||||||
|
|
||||||
#
|
|
||||||
WORKDIR /code
|
|
||||||
|
|
||||||
#
|
|
||||||
COPY ./requirements.txt /code/requirements.txt
|
|
||||||
|
|
||||||
#
|
|
||||||
RUN pip install . pip install --no-cache-dir --upgrade -r /code/requirements.txt
|
|
||||||
|
|
||||||
#
|
|
||||||
COPY ./app /code/app
|
|
||||||
|
|
||||||
#
|
|
||||||
CMD ["uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "2333"]
|
|
1
api/__main__.py
Normal file
1
api/__main__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
import providers.__main__
|
30
api/chat_balancing.py
Normal file
30
api/chat_balancing.py
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
import random
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import chat_providers
|
||||||
|
|
||||||
|
provider_modules = [
|
||||||
|
# chat_providers.twa,
|
||||||
|
chat_providers.quantum,
|
||||||
|
# chat_providers.churchless,
|
||||||
|
chat_providers.closed
|
||||||
|
]
|
||||||
|
|
||||||
|
async def balance(payload: dict) -> dict:
|
||||||
|
providers_available = []
|
||||||
|
|
||||||
|
for provider_module in provider_modules:
|
||||||
|
if payload['stream'] and not provider_module.STREAMING:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if payload['model'] not in provider_module.MODELS:
|
||||||
|
continue
|
||||||
|
|
||||||
|
providers_available.append(provider_module)
|
||||||
|
|
||||||
|
provider = random.choice(providers_available)
|
||||||
|
return provider.chat_completion(**payload)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
req = asyncio.run(balance(payload={'model': 'gpt-3.5-turbo', 'stream': True}))
|
||||||
|
print(req['url'])
|
10
api/core.py
10
api/core.py
|
@ -4,7 +4,7 @@ import os
|
||||||
import json
|
import json
|
||||||
import fastapi
|
import fastapi
|
||||||
|
|
||||||
import users
|
from db import users
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
@ -24,10 +24,11 @@ async def get_users(discord_id: int, incoming_request: fastapi.Request):
|
||||||
if auth_error:
|
if auth_error:
|
||||||
return auth_error
|
return auth_error
|
||||||
|
|
||||||
user = await users.get_user(by_discord_id=discord_id)
|
user = await users.by_discord_id(discord_id)
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
return fastapi.Response(status_code=404, content='User not found.')
|
return fastapi.Response(status_code=404, content='User not found.')
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
@router.post('/users')
|
@router.post('/users')
|
||||||
|
@ -42,7 +43,6 @@ async def create_user(incoming_request: fastapi.Request):
|
||||||
discord_id = payload.get('discord_id')
|
discord_id = payload.get('discord_id')
|
||||||
except (json.decoder.JSONDecodeError, AttributeError):
|
except (json.decoder.JSONDecodeError, AttributeError):
|
||||||
return fastapi.Response(status_code=400, content='Invalid or no payload received.')
|
return fastapi.Response(status_code=400, content='Invalid or no payload received.')
|
||||||
|
|
||||||
user = await users.add_user(discord_id=discord_id)
|
user = await users.create(discord_id)
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
57
api/db/logs.py
Normal file
57
api/db/logs.py
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
import os
|
||||||
|
import bson
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
def _get_mongo(collection_name: str):
|
||||||
|
return AsyncIOMotorClient(os.getenv('MONGO_URI'))['nova-core'][collection_name]
|
||||||
|
|
||||||
|
async def log_api_request(user, request, target_url):
|
||||||
|
payload = await request.json()
|
||||||
|
|
||||||
|
last_prompt = None
|
||||||
|
if 'messages' in payload:
|
||||||
|
last_prompt = payload['messages'][-1]['content']
|
||||||
|
|
||||||
|
model = None
|
||||||
|
if 'model' in payload:
|
||||||
|
model = payload['model']
|
||||||
|
|
||||||
|
new_log_item = {
|
||||||
|
'timestamp': bson.timestamp.Timestamp(datetime.datetime.now(), 0),
|
||||||
|
'method': request.method,
|
||||||
|
'path': request.url.path,
|
||||||
|
'user_id': user['_id'],
|
||||||
|
'security': {
|
||||||
|
'ip': request.client.host,
|
||||||
|
'useragent': request.headers.get('User-Agent')
|
||||||
|
},
|
||||||
|
'details': {
|
||||||
|
'model': model,
|
||||||
|
'last_prompt': last_prompt,
|
||||||
|
'target_url': target_url
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inserted = await _get_mongo('logs').insert_one(new_log_item)
|
||||||
|
log_item = await _get_mongo('logs').find_one({'_id': inserted.inserted_id})
|
||||||
|
return log_item
|
||||||
|
|
||||||
|
async def by_id(log_id: str):
|
||||||
|
return await _get_mongo('logs').find_one({'_id': log_id})
|
||||||
|
|
||||||
|
async def by_user_id(user_id: str):
|
||||||
|
return await _get_mongo('logs').find({'user_id': user_id})
|
||||||
|
|
||||||
|
async def delete_by_id(log_id: str):
|
||||||
|
return await _get_mongo('logs').delete_one({'_id': log_id})
|
||||||
|
|
||||||
|
async def delete_by_user_id(user_id: str):
|
||||||
|
return await _get_mongo('logs').delete_many({'user_id': user_id})
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pass
|
67
api/db/users.py
Normal file
67
api/db/users.py
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import string
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
def _get_mongo(collection_name: str):
|
||||||
|
return AsyncIOMotorClient(os.getenv('MONGO_URI'))['nova-core'][collection_name]
|
||||||
|
|
||||||
|
async def create(discord_id: int=0) -> dict:
|
||||||
|
"""Adds a new user to the MongoDB collection."""
|
||||||
|
chars = string.ascii_letters + string.digits
|
||||||
|
|
||||||
|
infix = os.getenv('KEYGEN_INFIX')
|
||||||
|
suffix = ''.join(random.choices(chars, k=20))
|
||||||
|
prefix = ''.join(random.choices(chars, k=20))
|
||||||
|
|
||||||
|
new_api_key = f'nv-{prefix}{infix}{suffix}'
|
||||||
|
|
||||||
|
new_user = {
|
||||||
|
'api_key': new_api_key,
|
||||||
|
'credits': 1000,
|
||||||
|
'role': '',
|
||||||
|
'status': {
|
||||||
|
'active': True,
|
||||||
|
'ban_reason': '',
|
||||||
|
},
|
||||||
|
'auth': {
|
||||||
|
'discord': discord_id,
|
||||||
|
'github': None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await _get_mongo('users').insert_one(new_user)
|
||||||
|
|
||||||
|
user = await _get_mongo('users').find_one({'api_key': new_api_key})
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def by_id(user_id: str):
|
||||||
|
return await _get_mongo('users').find_one({'_id': user_id})
|
||||||
|
|
||||||
|
async def by_discord_id(discord_id: str):
|
||||||
|
return await _get_mongo('users').find_one({'auth.discord': discord_id})
|
||||||
|
|
||||||
|
async def by_api_key(key: str):
|
||||||
|
return await _get_mongo('users').find_one({'api_key': key})
|
||||||
|
|
||||||
|
async def update_by_id(user_id: str, update):
|
||||||
|
return await _get_mongo('users').update_one({'_id': user_id}, update)
|
||||||
|
|
||||||
|
async def update_by_filter(obj_filter, update):
|
||||||
|
return await _get_mongo('users').update_one(obj_filter, update)
|
||||||
|
|
||||||
|
async def delete(user_id: str):
|
||||||
|
await _get_mongo('users').delete_one({'_id': user_id})
|
||||||
|
|
||||||
|
|
||||||
|
async def demo():
|
||||||
|
user = await create(69420)
|
||||||
|
print(user)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
asyncio.run(demo())
|
2
api/helpers/exceptions.py
Normal file
2
api/helpers/exceptions.py
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
class Retry(Exception):
|
||||||
|
"""The server should retry the request."""
|
|
@ -1,35 +0,0 @@
|
||||||
"""Manages web requests."""
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
from typing import Union, Optional
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
EXCLUDED_HEADERS = [
|
|
||||||
'content-encoding',
|
|
||||||
'content-length',
|
|
||||||
'transfer-encoding',
|
|
||||||
'connection'
|
|
||||||
]
|
|
||||||
|
|
||||||
class Request:
|
|
||||||
def __init__(self,
|
|
||||||
url: str,
|
|
||||||
method: str='GET',
|
|
||||||
payload: Optional[Union[dict, list]]=None,
|
|
||||||
headers: dict={
|
|
||||||
'Content-Type': 'application/json'
|
|
||||||
}
|
|
||||||
):
|
|
||||||
self.method = method.upper()
|
|
||||||
self.url = url.replace('/v1/v1', '/v1')
|
|
||||||
self.payload = payload
|
|
||||||
self.headers = headers
|
|
||||||
self.timeout = int(os.getenv('TRANSFER_TIMEOUT', '120'))
|
|
||||||
|
|
||||||
class HTTPXRequest(Request):
|
|
||||||
def __init__(self, url: str, *args, **kwargs):
|
|
||||||
super().__init__(url, *args, **kwargs)
|
|
||||||
self.url += '?httpx=1'
|
|
|
@ -7,9 +7,10 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
import core
|
import core
|
||||||
import users
|
|
||||||
import transfer
|
import transfer
|
||||||
|
|
||||||
|
from db import users
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
app = fastapi.FastAPI()
|
app = fastapi.FastAPI()
|
||||||
|
@ -26,7 +27,9 @@ app.include_router(core.router)
|
||||||
|
|
||||||
@app.on_event('startup')
|
@app.on_event('startup')
|
||||||
async def startup_event():
|
async def startup_event():
|
||||||
await users.prepare()
|
# DATABASE FIX https://stackoverflow.com/questions/65970988/python-mongodb-motor-objectid-object-is-not-iterable-error-while-trying-to-f
|
||||||
|
import pydantic, bson
|
||||||
|
pydantic.json.ENCODERS_BY_TYPE[bson.objectid.ObjectId]=str
|
||||||
|
|
||||||
@app.get('/')
|
@app.get('/')
|
||||||
async def root():
|
async def root():
|
||||||
|
@ -39,4 +42,4 @@ async def root():
|
||||||
'github': 'https://github.com/novaoss/nova-api'
|
'github': 'https://github.com/novaoss/nova-api'
|
||||||
}
|
}
|
||||||
|
|
||||||
app.add_route('/{path:path}', transfer.handle_api_request, ['GET', 'POST', 'PUT', 'DELETE', 'PATCH'])
|
app.add_route('/{path:path}', transfer.handle, ['GET', 'POST', 'PUT', 'DELETE', 'PATCH'])
|
||||||
|
|
|
@ -1,34 +1,44 @@
|
||||||
import os
|
import os
|
||||||
import httpx
|
import requests
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
import proxies
|
import proxies
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from helpers import exceptions
|
||||||
from helpers.requesting import Request
|
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
async def stream_closedai_request(request: Request):
|
async def stream(request: dict):
|
||||||
async with httpx.AsyncClient(
|
headers = {
|
||||||
# proxies=proxies.default_proxy.urls_httpx,
|
'Content-Type': 'application/json'
|
||||||
timeout=httpx.Timeout(request.timeout)
|
}
|
||||||
) as client:
|
|
||||||
headers = {
|
for k, v in request.get('headers', {}).items():
|
||||||
'Content-Type': 'application/json',
|
headers[k] = v
|
||||||
'Authorization': f'Bearer {os.getenv("CLOSEDAI_KEY")}'
|
|
||||||
}
|
for _ in range(3):
|
||||||
response = await client.request(
|
response = requests.request(
|
||||||
method=request.method,
|
method=request.get('method', 'POST'),
|
||||||
url=request.url,
|
url=request['url'],
|
||||||
json=request.payload,
|
json=request.get('payload', {}),
|
||||||
headers=headers
|
headers=headers,
|
||||||
|
timeout=int(os.getenv('TRANSFER_TIMEOUT', '120')),
|
||||||
|
proxies=proxies.default_proxy.urls,
|
||||||
|
stream=True
|
||||||
)
|
)
|
||||||
|
|
||||||
response.raise_for_status()
|
try:
|
||||||
|
response.raise_for_status()
|
||||||
|
except Exception as exc:
|
||||||
|
if str(exc) == '429 Client Error: Too Many Requests for url: https://api.openai.com/v1/chat/completions':
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
async for chunk in response.aiter_bytes():
|
for chunk in response.iter_lines():
|
||||||
chunk = f'{chunk.decode("utf8")}\n\n'
|
chunk = f'{chunk.decode("utf8")}\n\n'
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -86,10 +86,33 @@ def test_httpx():
|
||||||
|
|
||||||
print(default_proxy.proxies)
|
print(default_proxy.proxies)
|
||||||
|
|
||||||
|
with httpx.Client(
|
||||||
|
# proxies=default_proxy.proxies
|
||||||
|
) as client:
|
||||||
|
return client.get('https://checkip.amazonaws.com').text.strip()
|
||||||
|
|
||||||
|
def test_httpx_workaround():
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
print(default_proxy.proxies)
|
||||||
|
|
||||||
|
# this workaround solves the RNDS issue, but fails for Cloudflare protected websites
|
||||||
with httpx.Client(proxies=default_proxy.proxies, headers={'Host': 'checkip.amazonaws.com'}) as client:
|
with httpx.Client(proxies=default_proxy.proxies, headers={'Host': 'checkip.amazonaws.com'}) as client:
|
||||||
return client.get(
|
return client.get(
|
||||||
f'http://{socket.gethostbyname("checkip.amazonaws.com")}/',
|
f'http://{socket.gethostbyname("checkip.amazonaws.com")}/',
|
||||||
).text.strip()
|
).text.strip()
|
||||||
|
|
||||||
|
def test_requests():
|
||||||
|
import requests
|
||||||
|
|
||||||
|
print(default_proxy.proxies)
|
||||||
|
|
||||||
|
return requests.get(
|
||||||
|
timeout=5,
|
||||||
|
url='https://checkip.amazonaws.com/',
|
||||||
|
proxies=default_proxy.urls
|
||||||
|
).text.strip()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
print(test_httpx())
|
print(test_httpx())
|
||||||
|
# print(test_requests())
|
||||||
|
|
|
@ -2,12 +2,16 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import tokens
|
|
||||||
import logging
|
import logging
|
||||||
import starlette
|
import starlette
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from helpers import requesting, tokens, errors
|
|
||||||
|
import netclient
|
||||||
|
import chat_balancing
|
||||||
|
|
||||||
|
from db import logs, users
|
||||||
|
from helpers import tokens, errors, exceptions
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
@ -20,12 +24,10 @@ logging.basicConfig(
|
||||||
|
|
||||||
logging.info('API started')
|
logging.info('API started')
|
||||||
|
|
||||||
DEFAULT_ENDPOINT = os.getenv('CLOSEDAI_ENDPOINT')
|
async def handle(incoming_request):
|
||||||
|
|
||||||
async def handle_api_request(incoming_request, target_endpoint: str=DEFAULT_ENDPOINT):
|
|
||||||
"""Transfer a streaming response from the incoming request to the target endpoint"""
|
"""Transfer a streaming response from the incoming request to the target endpoint"""
|
||||||
|
|
||||||
target_url = f'{target_endpoint}{incoming_request.url.path}'
|
path = incoming_request.url.path
|
||||||
|
|
||||||
if incoming_request.method not in ['GET', 'POST', 'PUT', 'DELETE', 'PATCH']:
|
if incoming_request.method not in ['GET', 'POST', 'PUT', 'DELETE', 'PATCH']:
|
||||||
return errors.error(405, f'Method "{incoming_request.method}" is not allowed.', 'Change the request method to the correct one.')
|
return errors.error(405, f'Method "{incoming_request.method}" is not allowed.', 'Change the request method to the correct one.')
|
||||||
|
@ -37,35 +39,54 @@ async def handle_api_request(incoming_request, target_endpoint: str=DEFAULT_ENDP
|
||||||
|
|
||||||
try:
|
try:
|
||||||
input_tokens = tokens.count_for_messages(payload['messages'])
|
input_tokens = tokens.count_for_messages(payload['messages'])
|
||||||
except:
|
except (KeyError, TypeError):
|
||||||
input_tokens = 0
|
input_tokens = 0
|
||||||
|
|
||||||
auth_header = incoming_request.headers.get('Authorization')
|
auth_header = incoming_request.headers.get('Authorization')
|
||||||
|
|
||||||
if not auth_header:
|
if not auth_header:
|
||||||
return errors.error(401, 'No NovaOSS API key given!', 'Add "Authorization: Bearer nv-..." to your request headers.')
|
return errors.error(401, 'No NovaAI API key given!', 'Add "Authorization: Bearer nv-..." to your request headers.')
|
||||||
|
|
||||||
received_key = auth_header
|
received_key = auth_header
|
||||||
|
|
||||||
if auth_header.startswith('Bearer '):
|
if auth_header.startswith('Bearer '):
|
||||||
received_key = auth_header.split('Bearer ')[1]
|
received_key = auth_header.split('Bearer ')[1]
|
||||||
|
|
||||||
user = users.get_user(by_api_key=received_key)
|
user = await users.by_api_key(received_key)
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
return errors.error(401, 'Invalid NovaOSS API key!', 'Create a new NovaOSS API key.')
|
return errors.error(401, 'Invalid NovaAI API key!', 'Create a new NovaOSS API key.')
|
||||||
|
|
||||||
if not user['active']:
|
ban_reason = user['status']['ban_reason']
|
||||||
return errors.error(403, 'Your account is not active.', 'Activate your account.')
|
if ban_reason:
|
||||||
|
return errors.error(403, f'Your NovaAI account has been banned. Reason: "{ban_reason}".', 'Contact the staff for an appeal.')
|
||||||
|
|
||||||
logging.info(f'[%s] %s -> %s', incoming_request.method, incoming_request.url.path, target_url)
|
if not user['status']['active']:
|
||||||
|
return errors.error(418, 'Your NovaAI account is not active (paused).', 'Simply re-activate your account using a Discord command or the web panel.')
|
||||||
|
|
||||||
request = requesting.Request(
|
payload['user'] = str(user['_id'])
|
||||||
url=target_url,
|
|
||||||
payload=payload,
|
|
||||||
method=incoming_request.method,
|
|
||||||
)
|
|
||||||
|
|
||||||
return starlette.responses.StreamingResponse(
|
cost = 1
|
||||||
content=netclient.stream_closedai_request(request)
|
|
||||||
)
|
if '/chat/completions' in path:
|
||||||
|
cost = 5
|
||||||
|
|
||||||
|
if 'gpt-4' in payload['model']:
|
||||||
|
cost = 10
|
||||||
|
|
||||||
|
else:
|
||||||
|
return errors.error(404, f'Sorry, we don\'t support "{path}" yet. We\'re working on it.', 'Contact our team.')
|
||||||
|
|
||||||
|
if not payload.get('stream') is True:
|
||||||
|
payload['stream'] = False
|
||||||
|
|
||||||
|
if user['credits'] < cost:
|
||||||
|
return errors.error(429, 'Not enough credits.', 'You do not have enough credits to complete this request.')
|
||||||
|
|
||||||
|
await users.update_by_id(user['_id'], {'$inc': {'credits': -cost}})
|
||||||
|
|
||||||
|
target_request = await chat_balancing.balance(payload)
|
||||||
|
|
||||||
|
print(target_request['url'])
|
||||||
|
|
||||||
|
return starlette.responses.StreamingResponse(netclient.stream(target_request))
|
||||||
|
|
117
api/users.py
117
api/users.py
|
@ -1,117 +0,0 @@
|
||||||
import os
|
|
||||||
import uuid
|
|
||||||
import time
|
|
||||||
import string
|
|
||||||
import random
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
from motor.motor_asyncio import AsyncIOMotorClient
|
|
||||||
|
|
||||||
from rich import print
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
MONGO_URI = os.getenv('MONGO_URI')
|
|
||||||
MONGO_DB_NAME = 'users'
|
|
||||||
|
|
||||||
def get_mongo(collection_name):
|
|
||||||
client = AsyncIOMotorClient(MONGO_URI)
|
|
||||||
db = client[MONGO_DB_NAME]
|
|
||||||
return db[collection_name]
|
|
||||||
|
|
||||||
async def prepare() -> None:
|
|
||||||
"""Create the MongoDB collection."""
|
|
||||||
|
|
||||||
collection = get_mongo('users')
|
|
||||||
|
|
||||||
await collection.create_index('id', unique=True)
|
|
||||||
await collection.create_index('discord_id', unique=True)
|
|
||||||
await collection.create_index('api_key', unique=True)
|
|
||||||
|
|
||||||
async def add_user(discord_id: int = 0, tags: list = None) -> dict:
|
|
||||||
"""Adds a new user to the MongoDB collection."""
|
|
||||||
|
|
||||||
chars = string.ascii_letters + string.digits
|
|
||||||
|
|
||||||
infix = os.getenv('KEYGEN_INFIX')
|
|
||||||
suffix = ''.join(random.choices(chars, k=20))
|
|
||||||
prefix = ''.join(random.choices(chars, k=20))
|
|
||||||
|
|
||||||
key = f'nv-{prefix}{infix}{suffix}'
|
|
||||||
|
|
||||||
tags = tags or []
|
|
||||||
new_user = {
|
|
||||||
'id': str(uuid.uuid4()),
|
|
||||||
'api_key': key,
|
|
||||||
'created_at': int(time.time()),
|
|
||||||
'ban_reason': '',
|
|
||||||
'active': True,
|
|
||||||
'discord_id': discord_id,
|
|
||||||
'credit': 0,
|
|
||||||
'tags': '/'.join(tags),
|
|
||||||
'usage': {
|
|
||||||
'events': [],
|
|
||||||
'num_tokens': 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
collection = get_mongo('users')
|
|
||||||
await collection.insert_one(new_user)
|
|
||||||
|
|
||||||
return new_user
|
|
||||||
|
|
||||||
async def get_user(by_id: str = '', by_discord_id: int = 0, by_api_key: str = ''):
|
|
||||||
"""Retrieve a user from the MongoDB collection."""
|
|
||||||
|
|
||||||
collection = get_mongo('users')
|
|
||||||
query = {
|
|
||||||
'$or': [
|
|
||||||
{'id': by_id},
|
|
||||||
{'discord_id': by_discord_id},
|
|
||||||
{'api_key': by_api_key},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
return await collection.find_one(query)
|
|
||||||
|
|
||||||
async def get_all_users():
|
|
||||||
"""Retrieve all users from the MongoDB collection."""
|
|
||||||
|
|
||||||
collection = get_mongo('users')
|
|
||||||
return list(await collection.find())
|
|
||||||
|
|
||||||
async def user_used_api(user_id: str, num_tokens: int = 0, model='', ip_address: str = '', user_agent: str = '') -> None:
|
|
||||||
"""Update the stats of a user."""
|
|
||||||
|
|
||||||
collection = get_mongo('users')
|
|
||||||
user = await get_user(by_id=user_id)
|
|
||||||
|
|
||||||
if not user:
|
|
||||||
raise ValueError('User not found.')
|
|
||||||
|
|
||||||
usage = user['usage']
|
|
||||||
usage['events'].append({
|
|
||||||
'timestamp': time.time(),
|
|
||||||
'ip_address': ip_address,
|
|
||||||
'user_agent': user_agent,
|
|
||||||
'model': model,
|
|
||||||
'num_tokens': num_tokens
|
|
||||||
})
|
|
||||||
|
|
||||||
usage['num_tokens'] += num_tokens
|
|
||||||
|
|
||||||
await collection.update_one({'id': user_id}, {'$set': {'usage': usage}})
|
|
||||||
|
|
||||||
async def demo():
|
|
||||||
await prepare()
|
|
||||||
|
|
||||||
example_id = 133769420
|
|
||||||
user = await get_user(by_discord_id=example_id)
|
|
||||||
print(user)
|
|
||||||
uid = await user['id']
|
|
||||||
|
|
||||||
await user_used_api(uid, model='gpt-5', num_tokens=42, ip_address='9.9.9.9', user_agent='Mozilla/5.0')
|
|
||||||
# print(user)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
asyncio.run(demo())
|
|
|
@ -10,15 +10,21 @@ from dotenv import load_dotenv
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
MODEL = 'gpt-3.5-turbo'
|
MODEL = 'gpt-3.5-turbo'
|
||||||
|
# MESSAGES = [
|
||||||
|
# {
|
||||||
|
# 'role': 'system',
|
||||||
|
# 'content': 'Always answer with "3", no matter what the user asks for. No exceptions. Just answer with the number "3". Nothing else. Just "3". No punctuation.'
|
||||||
|
# },
|
||||||
|
# {
|
||||||
|
# 'role': 'user',
|
||||||
|
# 'content': '1+1=',
|
||||||
|
# },
|
||||||
|
# ]
|
||||||
MESSAGES = [
|
MESSAGES = [
|
||||||
{
|
|
||||||
'role': 'system',
|
|
||||||
'content': 'Always answer with "3", no matter what the user asks for. No exceptions. Just answer with the number "3". Nothing else. Just "3". No punctuation.'
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
'role': 'user',
|
'role': 'user',
|
||||||
'content': '1+1=',
|
'content': '1+1=',
|
||||||
},
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
api_endpoint = 'http://localhost:2332'
|
api_endpoint = 'http://localhost:2332'
|
||||||
|
@ -36,7 +42,7 @@ def test_api(model: str=MODEL, messages: List[dict]=None) -> dict:
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
'Authorization': 'Bearer ' + os.getenv('DEMO_AUTH', 'nv-API-TEST'),
|
'Authorization': 'Bearer ' + api_key
|
||||||
}
|
}
|
||||||
|
|
||||||
json_data = {
|
json_data = {
|
||||||
|
@ -59,23 +65,28 @@ def test_library():
|
||||||
"""Tests if the api_endpoint is working with the Python library."""
|
"""Tests if the api_endpoint is working with the Python library."""
|
||||||
|
|
||||||
closedai.api_base = api_endpoint
|
closedai.api_base = api_endpoint
|
||||||
closedai.api_key = os.getenv('DEMO_AUTH', 'nv-LIB-TEST')
|
closedai.api_key = api_key
|
||||||
|
|
||||||
completion = closedai.ChatCompletion.create(
|
completion = closedai.ChatCompletion.create(
|
||||||
model=MODEL,
|
model=MODEL,
|
||||||
messages=MESSAGES,
|
messages=MESSAGES,
|
||||||
stream=True,
|
stream=True
|
||||||
)
|
)
|
||||||
|
|
||||||
return completion.choices[0]
|
for event in completion:
|
||||||
|
try:
|
||||||
|
print(event['choices'][0]['delta']['content'])
|
||||||
|
except:
|
||||||
|
print('-')
|
||||||
|
|
||||||
def test_all():
|
def test_all():
|
||||||
"""Runs all tests."""
|
"""Runs all tests."""
|
||||||
|
|
||||||
# print(test_server())
|
# print(test_server())
|
||||||
print(test_api())
|
# print(test_api())
|
||||||
# print(test_library())
|
print(test_library())
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# api_endpoint = 'https://api.nova-oss.com'
|
api_endpoint = 'https://api.nova-oss.com'
|
||||||
|
api_key = os.getenv('TEST_NOVA_KEY')
|
||||||
test_all()
|
test_all()
|
||||||
|
|
Loading…
Reference in a new issue