mirror of
https://github.com/NovaOSS/nova-api.git
synced 2024-11-25 20:33:58 +01:00
Works for BetterGPT
This commit is contained in:
parent
3cfdddab7c
commit
83a3c13abf
17
Dockerfile
Normal file
17
Dockerfile
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
#
|
||||||
|
FROM python:3.10
|
||||||
|
|
||||||
|
#
|
||||||
|
WORKDIR /code
|
||||||
|
|
||||||
|
#
|
||||||
|
COPY ./requirements.txt /code/requirements.txt
|
||||||
|
|
||||||
|
#
|
||||||
|
RUN pip install . pip install --no-cache-dir --upgrade -r /code/requirements.txt
|
||||||
|
|
||||||
|
#
|
||||||
|
COPY ./app /code/app
|
||||||
|
|
||||||
|
#
|
||||||
|
CMD ["uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "2333"]
|
104
api/keys.py
104
api/keys.py
|
@ -1,104 +0,0 @@
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import random
|
|
||||||
import time
|
|
||||||
from pymongo import MongoClient
|
|
||||||
|
|
||||||
with open('./config.json', 'r') as file:
|
|
||||||
config = json.load(file)
|
|
||||||
|
|
||||||
class Keys:
|
|
||||||
# --- START OF CONFIG ---
|
|
||||||
MONGO_URI = os.getenv('MONGO_URI') or config.get('MONGO_URI')
|
|
||||||
# --- END OF CONFIG ---
|
|
||||||
|
|
||||||
locked_keys = set()
|
|
||||||
cache = {}
|
|
||||||
|
|
||||||
# Initialize MongoDB
|
|
||||||
client = MongoClient(MONGO_URI)
|
|
||||||
db = client.get_database('keys_db')
|
|
||||||
collection = db['keys']
|
|
||||||
|
|
||||||
def __init__(self, key: str, model: str, provider_name: str, ratelimit: int, url: str):
|
|
||||||
self.key = key
|
|
||||||
self.model = model
|
|
||||||
self.provider_name = provider_name
|
|
||||||
self.ratelimit = ratelimit
|
|
||||||
self.url = url
|
|
||||||
if not Keys.cache:
|
|
||||||
self._load_keys()
|
|
||||||
|
|
||||||
def _load_keys(self) -> None:
|
|
||||||
cursor = Keys.collection.find({}, {'_id': 0, 'key_value': 1, 'model': 1, 'provider_name': 1, 'ratelimit': 1, 'url': 1, 'last_used': 1})
|
|
||||||
for doc in cursor:
|
|
||||||
key_value = doc['key_value']
|
|
||||||
model = doc['model']
|
|
||||||
provider_name = doc['provider_name']
|
|
||||||
ratelimit = doc['ratelimit']
|
|
||||||
url = doc['url']
|
|
||||||
last_used = doc.get('last_used', 0)
|
|
||||||
key_data = {'provider_name': provider_name, 'ratelimit': ratelimit, 'url': url, 'last_used': last_used}
|
|
||||||
Keys.cache.setdefault(model, {}).setdefault(key_value, key_data)
|
|
||||||
|
|
||||||
def lock(self) -> None:
|
|
||||||
self.locked_keys.add(self.key)
|
|
||||||
|
|
||||||
def unlock(self) -> None:
|
|
||||||
self.locked_keys.remove(self.key)
|
|
||||||
|
|
||||||
def is_locked(self) -> bool:
|
|
||||||
return self.key in self.locked_keys
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get(model: str) -> str:
|
|
||||||
key_candidates = list(Keys.cache.get(model, {}).keys())
|
|
||||||
random.shuffle(key_candidates)
|
|
||||||
|
|
||||||
current_time = time.time()
|
|
||||||
|
|
||||||
for key_candidate in key_candidates:
|
|
||||||
key_data = Keys.cache[model][key_candidate]
|
|
||||||
key = Keys(key_candidate, model, key_data['provider_name'], key_data['ratelimit'], key_data['url'])
|
|
||||||
time_since_last_used = current_time - key_data.get('last_used', 0)
|
|
||||||
|
|
||||||
if not key.is_locked() and time_since_last_used >= key.ratelimit:
|
|
||||||
key.lock()
|
|
||||||
key_data['last_used'] = current_time # Update last_used in the cache
|
|
||||||
Keys.collection.update_one(
|
|
||||||
{'key_value': key.key, 'model': key.model},
|
|
||||||
{'$set': {'last_used': current_time}} # Update last_used in the database
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
'url': key.url,
|
|
||||||
'key_value': key.key
|
|
||||||
}
|
|
||||||
|
|
||||||
print(f"[WARN] No unlocked keys found for model '{model}' in get keys request!")
|
|
||||||
|
|
||||||
def delete(self) -> None:
|
|
||||||
Keys.collection.delete_one({'key_value': self.key, 'model': self.model})
|
|
||||||
# Update cache
|
|
||||||
try:
|
|
||||||
del Keys.cache[self.model][self.key]
|
|
||||||
except KeyError:
|
|
||||||
print(f"[WARN] Tried to remove a key from cache which was not present: {self.key}")
|
|
||||||
|
|
||||||
def save(self) -> None:
|
|
||||||
key_data = {
|
|
||||||
'provider_name': self.provider_name,
|
|
||||||
'ratelimit': self.ratelimit,
|
|
||||||
'url': self.url,
|
|
||||||
'last_used': 0 # Initialize last_used to 0 when saving a new key
|
|
||||||
}
|
|
||||||
Keys.collection.insert_one({'key_value': self.key, 'model': self.model, **key_data})
|
|
||||||
# Update cache
|
|
||||||
Keys.cache.setdefault(self.model, {}).setdefault(self.key, key_data)
|
|
||||||
|
|
||||||
# Usage example:
|
|
||||||
# os.environ['MONGO_URI'] = "mongodb://localhost:27017"
|
|
||||||
# key_instance = Keys("example_key", "gpt-4", "openai", "10", "https://whatever-openai-thing-is/chat/completions/")
|
|
||||||
# key_instance.save()
|
|
||||||
# key_value = Keys.get("gpt-4")
|
|
||||||
|
|
||||||
|
|
23
api/main.py
23
api/main.py
|
@ -1,8 +1,8 @@
|
||||||
import os
|
"""FastAPI setup."""
|
||||||
import fastapi
|
|
||||||
|
import fastapi
|
||||||
|
import asyncio
|
||||||
|
|
||||||
from starlette.responses import StreamingResponse
|
|
||||||
from starlette.requests import Request
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
@ -25,9 +25,7 @@ app.add_middleware(
|
||||||
@app.on_event('startup')
|
@app.on_event('startup')
|
||||||
async def startup_event():
|
async def startup_event():
|
||||||
"""Read up the API server."""
|
"""Read up the API server."""
|
||||||
|
# await security.ip_protection_check()
|
||||||
# security.enable_proxy()
|
|
||||||
# security.ip_protection_check()
|
|
||||||
|
|
||||||
@app.get('/')
|
@app.get('/')
|
||||||
async def root():
|
async def root():
|
||||||
|
@ -35,8 +33,15 @@ async def root():
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'status': 'ok',
|
'status': 'ok',
|
||||||
'discord': 'https://discord.gg/85gdcd57Xr',
|
'readme': 'https://nova-oss.com'
|
||||||
'github': 'https://github.com/Luna-OSS'
|
}
|
||||||
|
|
||||||
|
@app.route('/v1')
|
||||||
|
async def api_root():
|
||||||
|
"""Returns the API root endpoint."""
|
||||||
|
|
||||||
|
return {
|
||||||
|
'status': 'ok',
|
||||||
}
|
}
|
||||||
|
|
||||||
app.add_route('/{path:path}', transfer.transfer_streaming_response, ['GET', 'POST', 'PUT', 'DELETE', 'PATCH'])
|
app.add_route('/{path:path}', transfer.transfer_streaming_response, ['GET', 'POST', 'PUT', 'DELETE', 'PATCH'])
|
||||||
|
|
|
@ -1,12 +1,9 @@
|
||||||
"""This module contains the Proxy class, which represents a proxy."""
|
"""This module makes it easy to implement proxies by providing a class.."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import httpx
|
|
||||||
import socket
|
import socket
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from sockslib import socks
|
|
||||||
from rich import print
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
@ -28,43 +25,6 @@ class Proxy:
|
||||||
self.username = username
|
self.username = username
|
||||||
self.password = password
|
self.password = password
|
||||||
|
|
||||||
@property
|
|
||||||
def auth(self):
|
|
||||||
"""Returns the authentication part of the proxy URL, if the proxy has a username and password."""
|
|
||||||
return f'{self.username}:{self.password}@' if all([self.username, self.password]) else ''
|
|
||||||
|
|
||||||
@property
|
|
||||||
def protocol(self):
|
|
||||||
"""Makes sure the hostnames are resolved correctly using the proxy.
|
|
||||||
See https://stackoverflow.com/a/43266186
|
|
||||||
"""
|
|
||||||
return self.proxy_type# + 'h' if self.proxy_type.startswith('socks') else self.proxy_type
|
|
||||||
|
|
||||||
@property
|
|
||||||
def proxies(self):
|
|
||||||
"""Returns a dictionary of proxies, ready to be used with the requests library or httpx.
|
|
||||||
"""
|
|
||||||
|
|
||||||
url = f'{self.protocol}://{self.auth}{self.host}:{self.port}'
|
|
||||||
|
|
||||||
proxies_dict = {
|
|
||||||
'http://': url.replace('https', 'http') if self.proxy_type == 'https' else url,
|
|
||||||
'https://': url.replace('http', 'https') if self.proxy_type == 'http' else url
|
|
||||||
}
|
|
||||||
|
|
||||||
return proxies_dict
|
|
||||||
|
|
||||||
@property
|
|
||||||
def url(self):
|
|
||||||
"""Returns the proxy URL."""
|
|
||||||
return f'{self.protocol}://{self.auth}{self.host}:{self.port}'
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return f'{self.proxy_type}://{len(self.auth) * "*"}{self.host}:{self.port}'
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f'<Proxy type={self.proxy_type} host={self.host} port={self.port} username={self.username} password={len(self.password) * "*"}>'
|
|
||||||
|
|
||||||
active_proxy = Proxy(
|
active_proxy = Proxy(
|
||||||
proxy_type=os.getenv('PROXY_TYPE', 'http'),
|
proxy_type=os.getenv('PROXY_TYPE', 'http'),
|
||||||
host=os.getenv('PROXY_HOST', '127.0.0.1'),
|
host=os.getenv('PROXY_HOST', '127.0.0.1'),
|
||||||
|
@ -72,54 +32,3 @@ active_proxy = Proxy(
|
||||||
username=os.getenv('PROXY_USER'),
|
username=os.getenv('PROXY_USER'),
|
||||||
password=os.getenv('PROXY_PASS')
|
password=os.getenv('PROXY_PASS')
|
||||||
)
|
)
|
||||||
|
|
||||||
def activate_proxy() -> None:
|
|
||||||
socks.set_default_proxy(
|
|
||||||
proxy_type=socks.PROXY_TYPES[active_proxy.proxy_type.upper()],
|
|
||||||
addr=active_proxy.host,
|
|
||||||
port=active_proxy.port,
|
|
||||||
username=active_proxy.username,
|
|
||||||
password=active_proxy.password
|
|
||||||
)
|
|
||||||
socket.socket = socks.socksocket
|
|
||||||
|
|
||||||
def check_proxy():
|
|
||||||
"""Checks if the proxy is working."""
|
|
||||||
|
|
||||||
resp = httpx.get(
|
|
||||||
url='https://echo.hoppscotch.io/',
|
|
||||||
timeout=20,
|
|
||||||
proxies=active_proxy.proxies
|
|
||||||
)
|
|
||||||
resp.raise_for_status()
|
|
||||||
|
|
||||||
return resp.ok
|
|
||||||
|
|
||||||
async def check_api():
|
|
||||||
model = 'gpt-3.5-turbo'
|
|
||||||
messages = [
|
|
||||||
{
|
|
||||||
'role': 'user',
|
|
||||||
'content': 'Explain what a wormhole is.'
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
headers = {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
'Authorization': 'Bearer ' + os.getenv('CLOSEDAI_KEY')
|
|
||||||
}
|
|
||||||
|
|
||||||
json_data = {
|
|
||||||
'model': model,
|
|
||||||
'messages': messages,
|
|
||||||
'stream': True
|
|
||||||
}
|
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=20) as client:
|
|
||||||
async with client.stream("POST", 'https://api.openai.com/v1/chat/completions', headers=headers, json=json_data) as response:
|
|
||||||
response.raise_for_status()
|
|
||||||
async for chunk in response.aiter_text():
|
|
||||||
print(chunk.strip())
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
asyncio.run(check_api())
|
|
||||||
|
|
|
@ -1,47 +1,45 @@
|
||||||
"""Security checks for the API. Checks if the IP is masked etc."""
|
"""Security checks for the API. Checks if the IP is masked etc."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import httpx
|
import asyncio
|
||||||
|
|
||||||
from rich import print
|
from rich import print
|
||||||
|
from tests import check_proxy
|
||||||
import proxies
|
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
is_proxy_enabled = False
|
is_proxy_enabled = bool(os.getenv('PROXY_HOST', None))
|
||||||
|
|
||||||
class InsecureIPError(Exception):
|
class InsecureIPError(Exception):
|
||||||
"""Raised when the IP address of the server is not secure."""
|
"""Raised when the IP address of the server is not secure."""
|
||||||
|
|
||||||
def ip_protection_check():
|
async def ip_protection_check():
|
||||||
"""Makes sure that the actual server IP address is not exposed to the public."""
|
"""Makes sure that the actual server IP address is not exposed to the public."""
|
||||||
|
|
||||||
|
if not is_proxy_enabled:
|
||||||
|
print('[yellow]WARN: The proxy is not enabled. \
|
||||||
|
Skipping IP protection check.[/yellow]')
|
||||||
|
return True
|
||||||
|
|
||||||
actual_ips = os.getenv('ACTUAL_IPS', '').split()
|
actual_ips = os.getenv('ACTUAL_IPS', '').split()
|
||||||
|
|
||||||
if actual_ips:
|
if actual_ips:
|
||||||
echo_response = httpx.get(
|
# run the async function check_proxy() and get its result
|
||||||
url='https://echo.hoppscotch.io/',
|
response_ip = await check_proxy()
|
||||||
timeout=15
|
|
||||||
)
|
|
||||||
|
|
||||||
response_data = echo_response.json()
|
|
||||||
response_ip = response_data['headers']['x-forwarded-for']
|
|
||||||
|
|
||||||
for actual_ip in actual_ips:
|
for actual_ip in actual_ips:
|
||||||
if actual_ip in response_data:
|
if actual_ip in response_ip:
|
||||||
raise InsecureIPError(f'IP pattern "{actual_ip}" is in the values of ACTUAL_IPS of the\
|
raise InsecureIPError(f'IP pattern "{actual_ip}" is in the values of ACTUAL_IPS of the\
|
||||||
.env file. Enable a VPN or proxy to continue.')
|
.env file. Enable a VPN or proxy to continue.')
|
||||||
|
|
||||||
if is_proxy_enabled:
|
if is_proxy_enabled:
|
||||||
print(f'[green]SUCCESS: The IP "{response_ip}" was detected, which seems to be a proxy. Great![/green]')
|
print(f'[green]GOOD: The IP "{response_ip}" was detected, which seems to be a proxy. Great![/green]')
|
||||||
|
return True
|
||||||
|
|
||||||
else:
|
else:
|
||||||
print('[yellow]WARNING: ACTUAL_IPS is not set in the .env file or empty.\
|
print('[yellow]WARN: ACTUAL_IPS is not set in the .env file or empty.\
|
||||||
This means that the real IP of the server could be exposed. If you\'re using something\
|
This means that the real IP of the server could be exposed. If you\'re using something\
|
||||||
like Cloudflare or Repl.it, you can ignore this warning.[/yellow]')
|
like Cloudflare or Repl.it, you can ignore this warning.[/yellow]')
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
enable_proxy()
|
|
||||||
ip_protection_check()
|
ip_protection_check()
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
"""Module for transferring requests to ClosedAI API"""
|
"""Module for transferring requests to ClosedAI API"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import aiohttp_socks
|
import aiohttp_socks
|
||||||
|
|
||||||
|
@ -8,11 +10,21 @@ import proxies
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
import starlette
|
||||||
from starlette.responses import StreamingResponse
|
from starlette.responses import StreamingResponse
|
||||||
from starlette.background import BackgroundTask
|
from starlette.background import BackgroundTask
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
# log to "api.log" file
|
||||||
|
logging.basicConfig(
|
||||||
|
filename='api.log',
|
||||||
|
level=logging.DEBUG,
|
||||||
|
format='%(asctime)s %(levelname)s %(name)s %(message)s'
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info('API started')
|
||||||
|
|
||||||
EXCLUDED_HEADERS = [
|
EXCLUDED_HEADERS = [
|
||||||
'content-encoding',
|
'content-encoding',
|
||||||
'content-length',
|
'content-length',
|
||||||
|
@ -33,7 +45,16 @@ proxy_connector = aiohttp_socks.ProxyConnector(
|
||||||
async def transfer_streaming_response(incoming_request, target_endpoint: str='https://api.openai.com/v1'):
|
async def transfer_streaming_response(incoming_request, target_endpoint: str='https://api.openai.com/v1'):
|
||||||
"""Transfer a streaming response from the incoming request to the target endpoint"""
|
"""Transfer a streaming response from the incoming request to the target endpoint"""
|
||||||
|
|
||||||
|
if incoming_request.headers.get('Authorization') != f'Bearer {os.getenv("DEMO_AUTH")}':
|
||||||
|
return starlette.responses.Response(
|
||||||
|
status_code=403,
|
||||||
|
content='Invalid API key'
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
incoming_json_payload = await incoming_request.json()
|
incoming_json_payload = await incoming_request.json()
|
||||||
|
except json.decoder.JSONDecodeError:
|
||||||
|
incoming_json_payload = {}
|
||||||
|
|
||||||
async def receive_target_stream():
|
async def receive_target_stream():
|
||||||
connector = aiohttp_socks.ProxyConnector(
|
connector = aiohttp_socks.ProxyConnector(
|
||||||
|
@ -46,12 +67,15 @@ async def transfer_streaming_response(incoming_request, target_endpoint: str='ht
|
||||||
)
|
)
|
||||||
async with aiohttp.ClientSession(
|
async with aiohttp.ClientSession(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
timeout=aiohttp.ClientTimeout(total=120),
|
timeout=aiohttp.ClientTimeout(total=int(os.getenv('TRANSFER_TIMEOUT', '120'))),
|
||||||
raise_for_status=True
|
raise_for_status=True
|
||||||
) as session:
|
) as session:
|
||||||
|
target_url = f'{target_endpoint}{incoming_request.url.path}'.replace('/v1/v1', '/v1')
|
||||||
|
logging.info('TRANSFER %s -> %s', incoming_request.url.path, target_url)
|
||||||
|
|
||||||
async with session.request(
|
async with session.request(
|
||||||
method=incoming_request.method,
|
method=incoming_request.method,
|
||||||
url=f'{target_endpoint}/{incoming_request.url.path}',
|
url=target_url,
|
||||||
json=incoming_json_payload,
|
json=incoming_json_payload,
|
||||||
headers={
|
headers={
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
|
@ -60,6 +84,7 @@ async def transfer_streaming_response(incoming_request, target_endpoint: str='ht
|
||||||
) as response:
|
) as response:
|
||||||
async for chunk in response.content.iter_any():
|
async for chunk in response.content.iter_any():
|
||||||
chunk = f'{chunk.decode("utf8")}\n\n'
|
chunk = f'{chunk.decode("utf8")}\n\n'
|
||||||
|
logging.debug(chunk)
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
|
|
||||||
port = sys.argv[1] if len(sys.argv) > 1 else 8000
|
port = sys.argv[1] if len(sys.argv) > 1 else 2333
|
||||||
os.system(f'cd api && uvicorn main:app --reload --host 0.0.0.0 --port {port}')
|
os.system(f'cd api && uvicorn main:app --reload --host 0.0.0.0 --port {port}')
|
||||||
|
|
|
@ -1,3 +0,0 @@
|
||||||
{
|
|
||||||
"MONGO_URI": ""
|
|
||||||
}
|
|
6
setup.py
6
setup.py
|
@ -1,13 +1,13 @@
|
||||||
import setuptools
|
import setuptools
|
||||||
|
|
||||||
with open('README.md', 'r') as fh:
|
with open('README.md', 'r', encoding='utf8') as fh:
|
||||||
long_description = fh.read()
|
long_description = fh.read()
|
||||||
|
|
||||||
setuptools.setup(
|
setuptools.setup(
|
||||||
name='nova-api',
|
name='nova-api',
|
||||||
version='0.0.1',
|
version='0.0.1',
|
||||||
author='Luna OSS',
|
author='NovaOSS Contributors',
|
||||||
author_email='nsde@dmc.chat',
|
author_email='owner@nova-oss.com',
|
||||||
description='Nova API Server',
|
description='Nova API Server',
|
||||||
long_description=long_description,
|
long_description=long_description,
|
||||||
long_description_content_type='text/markdown',
|
long_description_content_type='text/markdown',
|
||||||
|
|
|
@ -1,35 +1,42 @@
|
||||||
"""Tests the API."""
|
"""Tests the API."""
|
||||||
|
|
||||||
from typing import List
|
import os
|
||||||
|
|
||||||
import openai as closedai
|
import openai as closedai
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
PORT = 8000
|
from typing import List
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
MODEL = 'gpt-3.5-turbo'
|
MODEL = 'gpt-3.5-turbo'
|
||||||
MESSAGES = [
|
MESSAGES = [
|
||||||
|
{
|
||||||
|
'role': 'system',
|
||||||
|
'content': 'Always answer with "3", no matter what the user asks for. No exceptions. Just answer with the number "3". Nothing else. Just "3". No punctuation.'
|
||||||
|
},
|
||||||
{
|
{
|
||||||
'role': 'user',
|
'role': 'user',
|
||||||
'content': 'Hello!',
|
'content': '1+1=',
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
ENDPOINT = f'http://localhost:{PORT}'
|
|
||||||
|
api_endpoint = 'http://localhost:2333'
|
||||||
|
|
||||||
def test_server():
|
def test_server():
|
||||||
"""Tests if the API is running."""
|
"""Tests if the API server is running."""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return httpx.get(f'{ENDPOINT}').json()['status'] == 'ok'
|
return httpx.get(f'{api_endpoint}').json()['status'] == 'ok'
|
||||||
except httpx.ConnectError as exc:
|
except httpx.ConnectError as exc:
|
||||||
raise ConnectionError(f'API is not running on port {PORT}.') from exc
|
raise ConnectionError(f'API is not running on port {api_endpoint}.') from exc
|
||||||
|
|
||||||
def test_api(model: str=MODEL, messages: List[dict]=None) -> dict:
|
def test_api(model: str=MODEL, messages: List[dict]=None) -> dict:
|
||||||
"""Tests an API endpoint."""
|
"""Tests an API api_endpoint."""
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
'Authorization': 'nv-API-TEST',
|
'Authorization': 'Bearer ' + os.getenv('DEMO_AUTH', 'nv-API-TEST'),
|
||||||
}
|
}
|
||||||
|
|
||||||
json_data = {
|
json_data = {
|
||||||
|
@ -38,20 +45,26 @@ def test_api(model: str=MODEL, messages: List[dict]=None) -> dict:
|
||||||
'stream': True,
|
'stream': True,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = httpx.post(f'{ENDPOINT}/chat/completions', headers=headers, json=json_data, timeout=20)
|
response = httpx.post(
|
||||||
|
url=f'{api_endpoint}/chat/completions',
|
||||||
|
headers=headers,
|
||||||
|
json=json_data,
|
||||||
|
timeout=20
|
||||||
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def test_library():
|
def test_library():
|
||||||
"""Tests if the endpoint is working with the "Closed"AI library."""
|
"""Tests if the api_endpoint is working with the Python library."""
|
||||||
|
|
||||||
closedai.api_base = ENDPOINT
|
closedai.api_base = api_endpoint
|
||||||
closedai.api_key = 'nv-LIBRARY-TEST'
|
closedai.api_key = os.getenv('DEMO_AUTH', 'nv-LIB-TEST')
|
||||||
|
|
||||||
completion = closedai.ChatCompletion.create(
|
completion = closedai.ChatCompletion.create(
|
||||||
model=MODEL,
|
model=MODEL,
|
||||||
messages=MESSAGES,
|
messages=MESSAGES,
|
||||||
|
stream=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return completion.choices[0]
|
return completion.choices[0]
|
||||||
|
@ -59,9 +72,10 @@ def test_library():
|
||||||
def test_all():
|
def test_all():
|
||||||
"""Runs all tests."""
|
"""Runs all tests."""
|
||||||
|
|
||||||
print(test_server())
|
# print(test_server())
|
||||||
print(test_api())
|
print(test_api())
|
||||||
print(test_library())
|
# print(test_library())
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
api_endpoint = 'https://api.nova-oss.com'
|
||||||
test_all()
|
test_all()
|
||||||
|
|
Loading…
Reference in a new issue