mirror of
https://github.com/NovaOSS/nova-api.git
synced 2024-11-25 16:23:57 +01:00
Added load balancer, MongoDB, improved proxy streaming support, more error messages, entire new world order
This commit is contained in:
parent
eeea634da0
commit
08d31d7ad1
9
.gitignore
vendored
9
.gitignore
vendored
|
@ -1,3 +1,12 @@
|
|||
providers/*
|
||||
providers/
|
||||
chat_providers/*
|
||||
chat_providers/
|
||||
|
||||
secret/*
|
||||
secret/
|
||||
/secret
|
||||
|
||||
*.db
|
||||
*.sqlite3
|
||||
*.sql
|
||||
|
|
17
Dockerfile
17
Dockerfile
|
@ -1,17 +0,0 @@
|
|||
# Base image: official Python 3.10.
FROM python:3.10

# All following relative paths resolve against /code.
WORKDIR /code

# Copy only the dependency manifest before installing.
COPY ./requirements.txt /code/requirements.txt

# Install the listed dependencies without keeping pip's download cache.
# (The previous line read "pip install . pip install ...", which asked pip to
# install the packages ".", "pip" and "install".)
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

# Copy the application source.
COPY ./app /code/app

# NOTE(review): the image copies ./app but the command imports api.main - confirm the module path.
CMD ["uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "2333"]
|
1
api/__main__.py
Normal file
1
api/__main__.py
Normal file
|
@ -0,0 +1 @@
|
|||
import providers.__main__
|
30
api/chat_balancing.py
Normal file
30
api/chat_balancing.py
Normal file
|
@ -0,0 +1,30 @@
|
|||
import random
|
||||
import asyncio
|
||||
|
||||
import chat_providers
|
||||
|
||||
provider_modules = [
|
||||
# chat_providers.twa,
|
||||
chat_providers.quantum,
|
||||
# chat_providers.churchless,
|
||||
chat_providers.closed
|
||||
]
|
||||
|
||||
async def balance(payload: dict) -> dict:
    """Pick a random provider able to serve *payload* and build its request.

    A provider module is eligible when ``payload['model']`` is listed in its
    ``MODELS`` and, if streaming is requested, its ``STREAMING`` flag is truthy.

    Args:
        payload: Chat completion parameters. ``model`` is required; a missing
            ``stream`` key is treated as a non-streaming request.

    Returns:
        Whatever the chosen provider's ``chat_completion(**payload)`` returns.

    Raises:
        KeyError: If *payload* has no ``model`` entry.
        IndexError: If no configured provider supports the request.
    """
    model = payload['model']
    wants_stream = bool(payload.get('stream'))

    providers_available = [
        module for module in provider_modules
        if (module.STREAMING or not wants_stream) and model in module.MODELS
    ]

    if not providers_available:
        # random.choice() on an empty list raises an IndexError with no hint
        # about the cause; keep the exception type but say what went wrong.
        raise IndexError(
            f'No provider available for model "{model}" (stream={wants_stream}).'
        )

    provider = random.choice(providers_available)
    return provider.chat_completion(**payload)
|
||||
|
||||
if __name__ == '__main__':
|
||||
req = asyncio.run(balance(payload={'model': 'gpt-3.5-turbo', 'stream': True}))
|
||||
print(req['url'])
|
10
api/core.py
10
api/core.py
|
@ -4,7 +4,7 @@ import os
|
|||
import json
|
||||
import fastapi
|
||||
|
||||
import users
|
||||
from db import users
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
@ -24,10 +24,11 @@ async def get_users(discord_id: int, incoming_request: fastapi.Request):
|
|||
if auth_error:
|
||||
return auth_error
|
||||
|
||||
user = await users.get_user(by_discord_id=discord_id)
|
||||
user = await users.by_discord_id(discord_id)
|
||||
|
||||
if not user:
|
||||
return fastapi.Response(status_code=404, content='User not found.')
|
||||
|
||||
return user
|
||||
|
||||
@router.post('/users')
|
||||
|
@ -42,7 +43,6 @@ async def create_user(incoming_request: fastapi.Request):
|
|||
discord_id = payload.get('discord_id')
|
||||
except (json.decoder.JSONDecodeError, AttributeError):
|
||||
return fastapi.Response(status_code=400, content='Invalid or no payload received.')
|
||||
|
||||
user = await users.add_user(discord_id=discord_id)
|
||||
|
||||
|
||||
user = await users.create(discord_id)
|
||||
return user
|
||||
|
|
57
api/db/logs.py
Normal file
57
api/db/logs.py
Normal file
|
@ -0,0 +1,57 @@
|
|||
import os
|
||||
import bson
|
||||
import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
|
||||
load_dotenv()
|
||||
|
||||
def _get_mongo(collection_name: str):
    """Return the named collection of the ``nova-core`` database.

    A new client is built from the ``MONGO_URI`` environment variable on
    every call.
    """
    client = AsyncIOMotorClient(os.getenv('MONGO_URI'))
    database = client['nova-core']
    return database[collection_name]
|
||||
|
||||
async def log_api_request(user, request, target_url):
    """Store one API request in the ``logs`` collection.

    Args:
        user: User document; only its ``_id`` is read.
        request: Incoming request object. Its JSON body, method, URL path,
            client host and ``User-Agent`` header are recorded.
        target_url: Upstream URL stored under ``details.target_url``.

    Returns:
        The log document as read back from the database after insertion.
    """
    payload = await request.json()

    # Only the newest message is logged. A body that is not a JSON object, or
    # whose "messages" value is empty or malformed, is logged without a prompt
    # instead of failing the whole request.
    model = None
    last_prompt = None
    if isinstance(payload, dict):
        model = payload.get('model')
        messages = payload.get('messages')
        if isinstance(messages, list) and messages and isinstance(messages[-1], dict):
            last_prompt = messages[-1].get('content')

    # bson treats a naive datetime as UTC, so a naive local time would be
    # stored shifted by the server's UTC offset; pass an aware UTC time.
    now_utc = datetime.datetime.now(datetime.timezone.utc)

    # request.client can be None when the server has no peer address.
    client = request.client
    client_ip = client.host if client else None

    new_log_item = {
        'timestamp': bson.timestamp.Timestamp(now_utc, 0),
        'method': request.method,
        'path': request.url.path,
        'user_id': user['_id'],
        'security': {
            'ip': client_ip,
            'useragent': request.headers.get('User-Agent'),
        },
        'details': {
            'model': model,
            'last_prompt': last_prompt,
            'target_url': target_url,
        },
    }

    logs = _get_mongo('logs')
    inserted = await logs.insert_one(new_log_item)
    return await logs.find_one({'_id': inserted.inserted_id})
|
||||
|
||||
async def by_id(log_id: str):
    """Return the log document whose ``_id`` equals *log_id*, or ``None``."""
    logs = _get_mongo('logs')
    query = {'_id': log_id}
    return await logs.find_one(query)
|
||||
|
||||
async def by_user_id(user_id: str):
    """Return all log documents whose ``user_id`` equals *user_id*, as a list."""
    # find() returns a cursor synchronously and the cursor itself cannot be
    # awaited, so the documents are drained explicitly.
    cursor = _get_mongo('logs').find({'user_id': user_id})
    return await cursor.to_list(length=None)
|
||||
|
||||
async def delete_by_id(log_id: str):
    """Delete the log document whose ``_id`` equals *log_id*; return the driver's result."""
    logs = _get_mongo('logs')
    outcome = await logs.delete_one({'_id': log_id})
    return outcome
|
||||
|
||||
async def delete_by_user_id(user_id: str):
    """Delete every log document stored for *user_id*; return the driver's result."""
    logs = _get_mongo('logs')
    outcome = await logs.delete_many({'user_id': user_id})
    return outcome
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
67
api/db/users.py
Normal file
67
api/db/users.py
Normal file
|
@ -0,0 +1,67 @@
|
|||
import os
|
||||
import random
|
||||
import string
|
||||
import asyncio
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
|
||||
load_dotenv()
|
||||
|
||||
def _get_mongo(collection_name: str):
    """Return the named collection of the ``nova-core`` database.

    A new client is built from the ``MONGO_URI`` environment variable on
    every call.
    """
    client = AsyncIOMotorClient(os.getenv('MONGO_URI'))
    database = client['nova-core']
    return database[collection_name]
|
||||
|
||||
async def create(discord_id: int=0) -> dict:
    """Add a new user with a fresh API key to the ``users`` collection.

    The key has the form ``nv-<20 chars><KEYGEN_INFIX><20 chars>`` where the
    random parts are drawn from ASCII letters and digits.

    Args:
        discord_id: Value stored under ``auth.discord``.

    Returns:
        The user document as read back from the database by its API key.
    """
    # Local import so the block does not depend on the module's import list.
    import secrets

    chars = string.ascii_letters + string.digits

    # An unset KEYGEN_INFIX would otherwise put the text "None" into the key.
    infix = os.getenv('KEYGEN_INFIX') or ''

    # API keys are credentials: draw them from the OS CSPRNG, not from the
    # predictable Mersenne Twister behind the random module.
    prefix = ''.join(secrets.choice(chars) for _ in range(20))
    suffix = ''.join(secrets.choice(chars) for _ in range(20))

    new_api_key = f'nv-{prefix}{infix}{suffix}'

    new_user = {
        'api_key': new_api_key,
        'credits': 1000,
        'role': '',
        'status': {
            'active': True,
            'ban_reason': '',
        },
        'auth': {
            'discord': discord_id,
            'github': None,
        },
    }

    users = _get_mongo('users')
    await users.insert_one(new_user)
    return await users.find_one({'api_key': new_api_key})
|
||||
|
||||
async def by_id(user_id: str):
    """Return the user document whose ``_id`` equals *user_id*, or ``None``."""
    users = _get_mongo('users')
    query = {'_id': user_id}
    return await users.find_one(query)
|
||||
|
||||
async def by_discord_id(discord_id: str):
    """Return the user document whose ``auth.discord`` equals *discord_id*, or ``None``."""
    users = _get_mongo('users')
    query = {'auth.discord': discord_id}
    return await users.find_one(query)
|
||||
|
||||
async def by_api_key(key: str):
    """Return the user document whose ``api_key`` equals *key*, or ``None``."""
    users = _get_mongo('users')
    query = {'api_key': key}
    return await users.find_one(query)
|
||||
|
||||
async def update_by_id(user_id: str, update):
    """Apply *update* to the user whose ``_id`` equals *user_id*; return the driver's result."""
    users = _get_mongo('users')
    outcome = await users.update_one({'_id': user_id}, update)
    return outcome
|
||||
|
||||
async def update_by_filter(obj_filter, update):
    """Apply *update* to the first user matching *obj_filter*; return the driver's result."""
    users = _get_mongo('users')
    outcome = await users.update_one(obj_filter, update)
    return outcome
|
||||
|
||||
async def delete(user_id: str):
    """Delete the user whose ``_id`` equals *user_id*. Returns nothing."""
    users = _get_mongo('users')
    query = {'_id': user_id}
    await users.delete_one(query)
|
||||
|
||||
|
||||
async def demo():
    """Create a sample user with Discord ID 69420 and print the stored document."""
    print(await create(69420))
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(demo())
|
2
api/helpers/exceptions.py
Normal file
2
api/helpers/exceptions.py
Normal file
|
@ -0,0 +1,2 @@
|
|||
class Retry(Exception):
    """Signal that the server should attempt the request again."""
|
|
@ -1,35 +0,0 @@
|
|||
"""Manages web requests."""
|
||||
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from typing import Union, Optional
|
||||
|
||||
load_dotenv()
|
||||
|
||||
EXCLUDED_HEADERS = [
|
||||
'content-encoding',
|
||||
'content-length',
|
||||
'transfer-encoding',
|
||||
'connection'
|
||||
]
|
||||
|
||||
class Request:
    """Description of an outgoing HTTP request.

    Attributes are plain data: ``method`` (upper-cased), ``url`` (with a
    doubled ``/v1/v1`` collapsed to ``/v1``), ``payload``, ``headers`` and
    ``timeout`` (seconds, from ``TRANSFER_TIMEOUT``, default 120).
    """

    def __init__(self,
        url: str,
        method: str='GET',
        payload: Optional[Union[dict, list]]=None,
        headers: Optional[dict]=None
    ):
        """Store the request parameters.

        Args:
            url: Target URL.
            method: HTTP method, case-insensitive.
            payload: JSON-serialisable body, if any.
            headers: Request headers. When omitted, a new
                ``{'Content-Type': 'application/json'}`` dict is created for
                this instance.
        """
        # A dict literal as the default would be one object shared by every
        # instance, so a header added to one request would leak into all
        # later ones. Build the default per instance instead.
        if headers is None:
            headers = {'Content-Type': 'application/json'}

        self.method = method.upper()
        self.url = url.replace('/v1/v1', '/v1')
        self.payload = payload
        self.headers = headers
        self.timeout = int(os.getenv('TRANSFER_TIMEOUT', '120'))
|
||||
|
||||
class HTTPXRequest(Request):
    """A ``Request`` whose URL carries an extra ``httpx=1`` query parameter."""

    def __init__(self, url: str, *args, **kwargs):
        """Initialise like ``Request``, then append ``httpx=1`` to the URL."""
        super().__init__(url, *args, **kwargs)
        # A second "?" would corrupt a URL that already has a query string,
        # so join with "&" in that case.
        separator = '&' if '?' in self.url else '?'
        self.url += f'{separator}httpx=1'
|
|
@ -7,9 +7,10 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||
from dotenv import load_dotenv
|
||||
|
||||
import core
|
||||
import users
|
||||
import transfer
|
||||
|
||||
from db import users
|
||||
|
||||
load_dotenv()
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
|
@ -26,7 +27,9 @@ app.include_router(core.router)
|
|||
|
||||
@app.on_event('startup')
|
||||
async def startup_event():
|
||||
await users.prepare()
|
||||
# DATABASE FIX https://stackoverflow.com/questions/65970988/python-mongodb-motor-objectid-object-is-not-iterable-error-while-trying-to-f
|
||||
import pydantic, bson
|
||||
pydantic.json.ENCODERS_BY_TYPE[bson.objectid.ObjectId]=str
|
||||
|
||||
@app.get('/')
|
||||
async def root():
|
||||
|
@ -39,4 +42,4 @@ async def root():
|
|||
'github': 'https://github.com/novaoss/nova-api'
|
||||
}
|
||||
|
||||
app.add_route('/{path:path}', transfer.handle_api_request, ['GET', 'POST', 'PUT', 'DELETE', 'PATCH'])
|
||||
app.add_route('/{path:path}', transfer.handle, ['GET', 'POST', 'PUT', 'DELETE', 'PATCH'])
|
||||
|
|
|
@ -1,34 +1,44 @@
|
|||
import os
|
||||
import httpx
|
||||
import requests
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import proxies
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from helpers.requesting import Request
|
||||
from helpers import exceptions
|
||||
|
||||
load_dotenv()
|
||||
|
||||
async def stream_closedai_request(request: Request):
|
||||
async with httpx.AsyncClient(
|
||||
# proxies=proxies.default_proxy.urls_httpx,
|
||||
timeout=httpx.Timeout(request.timeout)
|
||||
) as client:
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {os.getenv("CLOSEDAI_KEY")}'
|
||||
}
|
||||
response = await client.request(
|
||||
method=request.method,
|
||||
url=request.url,
|
||||
json=request.payload,
|
||||
headers=headers
|
||||
async def stream(request: dict):
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
for k, v in request.get('headers', {}).items():
|
||||
headers[k] = v
|
||||
|
||||
for _ in range(3):
|
||||
response = requests.request(
|
||||
method=request.get('method', 'POST'),
|
||||
url=request['url'],
|
||||
json=request.get('payload', {}),
|
||||
headers=headers,
|
||||
timeout=int(os.getenv('TRANSFER_TIMEOUT', '120')),
|
||||
proxies=proxies.default_proxy.urls,
|
||||
stream=True
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except Exception as exc:
|
||||
if str(exc) == '429 Client Error: Too Many Requests for url: https://api.openai.com/v1/chat/completions':
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
async for chunk in response.aiter_bytes():
|
||||
chunk = f'{chunk.decode("utf8")}\n\n'
|
||||
yield chunk
|
||||
for chunk in response.iter_lines():
|
||||
chunk = f'{chunk.decode("utf8")}\n\n'
|
||||
yield chunk
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
||||
|
|
|
@ -86,10 +86,33 @@ def test_httpx():
|
|||
|
||||
print(default_proxy.proxies)
|
||||
|
||||
with httpx.Client(
|
||||
# proxies=default_proxy.proxies
|
||||
) as client:
|
||||
return client.get('https://checkip.amazonaws.com').text.strip()
|
||||
|
||||
def test_httpx_workaround():
    """Request checkip.amazonaws.com by IP address through the default proxy.

    The host name is resolved locally and sent as an explicit ``Host`` header.
    Per the original note, this approach fails for Cloudflare protected
    websites.

    Returns:
        The stripped response body.
    """
    import httpx

    print(default_proxy.proxies)

    host_header = {'Host': 'checkip.amazonaws.com'}

    with httpx.Client(proxies=default_proxy.proxies, headers=host_header) as client:
        resolved_ip = socket.gethostbyname("checkip.amazonaws.com")
        response = client.get(f'http://{resolved_ip}/')
        return response.text.strip()
|
||||
|
||||
def test_requests():
    """Request checkip.amazonaws.com via ``requests`` using ``default_proxy.urls``.

    Returns:
        The stripped response body.
    """
    import requests

    print(default_proxy.proxies)

    response = requests.get(
        url='https://checkip.amazonaws.com/',
        proxies=default_proxy.urls,
        timeout=5,
    )
    return response.text.strip()
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(test_httpx())
|
||||
# print(test_requests())
|
||||
|
|
|
@ -2,12 +2,16 @@
|
|||
|
||||
import os
|
||||
import json
|
||||
import tokens
|
||||
import logging
|
||||
import starlette
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from helpers import requesting, tokens, errors
|
||||
|
||||
import netclient
|
||||
import chat_balancing
|
||||
|
||||
from db import logs, users
|
||||
from helpers import tokens, errors, exceptions
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
@ -20,12 +24,10 @@ logging.basicConfig(
|
|||
|
||||
logging.info('API started')
|
||||
|
||||
DEFAULT_ENDPOINT = os.getenv('CLOSEDAI_ENDPOINT')
|
||||
|
||||
async def handle_api_request(incoming_request, target_endpoint: str=DEFAULT_ENDPOINT):
|
||||
async def handle(incoming_request):
|
||||
"""Transfer a streaming response from the incoming request to the target endpoint"""
|
||||
|
||||
target_url = f'{target_endpoint}{incoming_request.url.path}'
|
||||
path = incoming_request.url.path
|
||||
|
||||
if incoming_request.method not in ['GET', 'POST', 'PUT', 'DELETE', 'PATCH']:
|
||||
return errors.error(405, f'Method "{incoming_request.method}" is not allowed.', 'Change the request method to the correct one.')
|
||||
|
@ -37,35 +39,54 @@ async def handle_api_request(incoming_request, target_endpoint: str=DEFAULT_ENDP
|
|||
|
||||
try:
|
||||
input_tokens = tokens.count_for_messages(payload['messages'])
|
||||
except:
|
||||
except (KeyError, TypeError):
|
||||
input_tokens = 0
|
||||
|
||||
auth_header = incoming_request.headers.get('Authorization')
|
||||
|
||||
if not auth_header:
|
||||
return errors.error(401, 'No NovaOSS API key given!', 'Add "Authorization: Bearer nv-..." to your request headers.')
|
||||
return errors.error(401, 'No NovaAI API key given!', 'Add "Authorization: Bearer nv-..." to your request headers.')
|
||||
|
||||
received_key = auth_header
|
||||
|
||||
if auth_header.startswith('Bearer '):
|
||||
received_key = auth_header.split('Bearer ')[1]
|
||||
|
||||
user = users.get_user(by_api_key=received_key)
|
||||
user = await users.by_api_key(received_key)
|
||||
|
||||
if not user:
|
||||
return errors.error(401, 'Invalid NovaOSS API key!', 'Create a new NovaOSS API key.')
|
||||
return errors.error(401, 'Invalid NovaAI API key!', 'Create a new NovaOSS API key.')
|
||||
|
||||
if not user['active']:
|
||||
return errors.error(403, 'Your account is not active.', 'Activate your account.')
|
||||
ban_reason = user['status']['ban_reason']
|
||||
if ban_reason:
|
||||
return errors.error(403, f'Your NovaAI account has been banned. Reason: "{ban_reason}".', 'Contact the staff for an appeal.')
|
||||
|
||||
logging.info(f'[%s] %s -> %s', incoming_request.method, incoming_request.url.path, target_url)
|
||||
if not user['status']['active']:
|
||||
return errors.error(418, 'Your NovaAI account is not active (paused).', 'Simply re-activate your account using a Discord command or the web panel.')
|
||||
|
||||
request = requesting.Request(
|
||||
url=target_url,
|
||||
payload=payload,
|
||||
method=incoming_request.method,
|
||||
)
|
||||
payload['user'] = str(user['_id'])
|
||||
|
||||
return starlette.responses.StreamingResponse(
|
||||
content=netclient.stream_closedai_request(request)
|
||||
)
|
||||
cost = 1
|
||||
|
||||
if '/chat/completions' in path:
|
||||
cost = 5
|
||||
|
||||
if 'gpt-4' in payload['model']:
|
||||
cost = 10
|
||||
|
||||
else:
|
||||
return errors.error(404, f'Sorry, we don\'t support "{path}" yet. We\'re working on it.', 'Contact our team.')
|
||||
|
||||
if not payload.get('stream') is True:
|
||||
payload['stream'] = False
|
||||
|
||||
if user['credits'] < cost:
|
||||
return errors.error(429, 'Not enough credits.', 'You do not have enough credits to complete this request.')
|
||||
|
||||
await users.update_by_id(user['_id'], {'$inc': {'credits': -cost}})
|
||||
|
||||
target_request = await chat_balancing.balance(payload)
|
||||
|
||||
print(target_request['url'])
|
||||
|
||||
return starlette.responses.StreamingResponse(netclient.stream(target_request))
|
||||
|
|
117
api/users.py
117
api/users.py
|
@ -1,117 +0,0 @@
|
|||
import os
|
||||
import uuid
|
||||
import time
|
||||
import string
|
||||
import random
|
||||
import asyncio
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
|
||||
from rich import print
|
||||
|
||||
load_dotenv()
|
||||
|
||||
MONGO_URI = os.getenv('MONGO_URI')
|
||||
MONGO_DB_NAME = 'users'
|
||||
|
||||
def get_mongo(collection_name):
    """Return *collection_name* from the ``MONGO_DB_NAME`` database at ``MONGO_URI``."""
    return AsyncIOMotorClient(MONGO_URI)[MONGO_DB_NAME][collection_name]
|
||||
|
||||
async def prepare() -> None:
    """Create unique indexes on ``id``, ``discord_id`` and ``api_key`` of the users collection."""
    users = get_mongo('users')

    for field in ('id', 'discord_id', 'api_key'):
        await users.create_index(field, unique=True)
|
||||
|
||||
async def add_user(discord_id: int = 0, tags: list = None) -> dict:
    """Add a new user with a fresh API key to the users collection.

    Args:
        discord_id: Value stored under ``discord_id``.
        tags: Optional tag strings; stored joined with ``/``.

    Returns:
        The dict that was passed to ``insert_one``.
    """
    # Local import so the block does not depend on the module's import list.
    import secrets

    chars = string.ascii_letters + string.digits

    # An unset KEYGEN_INFIX would otherwise put the text "None" into the key.
    infix = os.getenv('KEYGEN_INFIX') or ''

    # API keys are credentials: use the OS CSPRNG rather than the random module.
    prefix = ''.join(secrets.choice(chars) for _ in range(20))
    suffix = ''.join(secrets.choice(chars) for _ in range(20))

    key = f'nv-{prefix}{infix}{suffix}'

    new_user = {
        'id': str(uuid.uuid4()),
        'api_key': key,
        'created_at': int(time.time()),
        'ban_reason': '',
        'active': True,
        'discord_id': discord_id,
        'credit': 0,
        'tags': '/'.join(tags or []),
        'usage': {
            'events': [],
            'num_tokens': 0,
        },
    }

    await get_mongo('users').insert_one(new_user)

    return new_user
|
||||
|
||||
async def get_user(by_id: str = '', by_discord_id: int = 0, by_api_key: str = ''):
    """Retrieve a user matching any of the supplied criteria.

    Only criteria with a truthy value take part in the lookup. Previously the
    defaults (``''`` / ``0``) were always part of the ``$or`` query, so a
    lookup by API key could return an unrelated user whose ``discord_id``
    was ``0``.

    Args:
        by_id: Value of the user's ``id`` field.
        by_discord_id: Value of the user's ``discord_id`` field.
        by_api_key: Value of the user's ``api_key`` field.

    Returns:
        The first matching user document, or ``None`` when nothing matches or
        no criterion was given.
    """
    criteria = {
        'id': by_id,
        'discord_id': by_discord_id,
        'api_key': by_api_key,
    }
    clauses = [{field: value} for field, value in criteria.items() if value]

    if not clauses:
        return None

    return await get_mongo('users').find_one({'$or': clauses})
|
||||
|
||||
async def get_all_users():
    """Retrieve all users from the MongoDB collection as a list of documents."""
    # find() returns a cursor that can neither be awaited nor passed to
    # list(); drain it through the cursor's own coroutine.
    cursor = get_mongo('users').find()
    return await cursor.to_list(length=None)
|
||||
|
||||
async def user_used_api(user_id: str, num_tokens: int = 0, model='', ip_address: str = '', user_agent: str = '') -> None:
    """Record one API usage event for a user.

    Appends an event to ``usage.events`` and adds *num_tokens* to
    ``usage.num_tokens`` in a single update, so concurrent requests cannot
    overwrite each other's events.

    Args:
        user_id: Value of the user's ``id`` field.
        num_tokens: Tokens consumed by the request.
        model: Model name to store with the event.
        ip_address: Client IP to store with the event.
        user_agent: Client user agent to store with the event.

    Raises:
        ValueError: If no user with this ``id`` exists.
    """
    event = {
        'timestamp': time.time(),
        'ip_address': ip_address,
        'user_agent': user_agent,
        'model': model,
        'num_tokens': num_tokens,
    }

    result = await get_mongo('users').update_one(
        {'id': user_id},
        {
            '$push': {'usage.events': event},
            '$inc': {'usage.num_tokens': num_tokens},
        },
    )

    if result.matched_count == 0:
        raise ValueError('User not found.')
|
||||
|
||||
async def demo():
    """Look up a sample user by Discord ID and record one usage event for it."""
    await prepare()

    example_id = 133769420
    user = await get_user(by_discord_id=example_id)
    print(user)

    if user is None:
        # Nothing to record usage against.
        return

    # user['id'] is a plain value from the document, not an awaitable.
    uid = user['id']

    await user_used_api(uid, model='gpt-5', num_tokens=42, ip_address='9.9.9.9', user_agent='Mozilla/5.0')
|
||||
# print(user)
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(demo())
|
|
@ -10,15 +10,21 @@ from dotenv import load_dotenv
|
|||
load_dotenv()
|
||||
|
||||
MODEL = 'gpt-3.5-turbo'
|
||||
# MESSAGES = [
|
||||
# {
|
||||
# 'role': 'system',
|
||||
# 'content': 'Always answer with "3", no matter what the user asks for. No exceptions. Just answer with the number "3". Nothing else. Just "3". No punctuation.'
|
||||
# },
|
||||
# {
|
||||
# 'role': 'user',
|
||||
# 'content': '1+1=',
|
||||
# },
|
||||
# ]
|
||||
MESSAGES = [
|
||||
{
|
||||
'role': 'system',
|
||||
'content': 'Always answer with "3", no matter what the user asks for. No exceptions. Just answer with the number "3". Nothing else. Just "3". No punctuation.'
|
||||
},
|
||||
{
|
||||
'role': 'user',
|
||||
'content': '1+1=',
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
api_endpoint = 'http://localhost:2332'
|
||||
|
@ -36,7 +42,7 @@ def test_api(model: str=MODEL, messages: List[dict]=None) -> dict:
|
|||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': 'Bearer ' + os.getenv('DEMO_AUTH', 'nv-API-TEST'),
|
||||
'Authorization': 'Bearer ' + api_key
|
||||
}
|
||||
|
||||
json_data = {
|
||||
|
@ -59,23 +65,28 @@ def test_library():
|
|||
"""Tests if the api_endpoint is working with the Python library."""
|
||||
|
||||
closedai.api_base = api_endpoint
|
||||
closedai.api_key = os.getenv('DEMO_AUTH', 'nv-LIB-TEST')
|
||||
closedai.api_key = api_key
|
||||
|
||||
completion = closedai.ChatCompletion.create(
|
||||
model=MODEL,
|
||||
messages=MESSAGES,
|
||||
stream=True,
|
||||
stream=True
|
||||
)
|
||||
|
||||
return completion.choices[0]
|
||||
for event in completion:
|
||||
try:
|
||||
print(event['choices'][0]['delta']['content'])
|
||||
except:
|
||||
print('-')
|
||||
|
||||
def test_all():
|
||||
"""Runs all tests."""
|
||||
|
||||
# print(test_server())
|
||||
print(test_api())
|
||||
# print(test_library())
|
||||
# print(test_api())
|
||||
print(test_library())
|
||||
|
||||
if __name__ == '__main__':
|
||||
# api_endpoint = 'https://api.nova-oss.com'
|
||||
api_endpoint = 'https://api.nova-oss.com'
|
||||
api_key = os.getenv('TEST_NOVA_KEY')
|
||||
test_all()
|
||||
|
|
Loading…
Reference in a new issue