mirror of
https://github.com/NovaOSS/nova-api.git
synced 2024-11-25 23:53:57 +01:00
Compare commits
No commits in common. "c4137a9eab6c10f74c9c04cb845ab4ef2ffe117c" and "a7b2ce7aa5dc7c07211dd4dfdae8a7512da0017b" have entirely different histories.
c4137a9eab
...
a7b2ce7aa5
47
api/core.py
47
api/core.py
|
@ -3,8 +3,6 @@
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from helpers import errors
|
|
||||||
|
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
||||||
sys.path.append(project_root)
|
sys.path.append(project_root)
|
||||||
|
|
||||||
|
@ -26,17 +24,19 @@ load_dotenv()
|
||||||
router = fastapi.APIRouter(tags=['core'])
|
router = fastapi.APIRouter(tags=['core'])
|
||||||
|
|
||||||
async def check_core_auth(request):
|
async def check_core_auth(request):
|
||||||
"""Checks the core API key. Returns nothing if it's valid, otherwise returns an error.
|
"""
|
||||||
|
|
||||||
|
### Checks the request's auth
|
||||||
|
Auth is taken from environment variable `CORE_API_KEY`
|
||||||
|
|
||||||
"""
|
"""
|
||||||
received_auth = request.headers.get('Authorization')
|
received_auth = request.headers.get('Authorization')
|
||||||
|
|
||||||
correct_core_api = os.environ['CORE_API_KEY']
|
correct_core_api = os.environ['CORE_API_KEY']
|
||||||
|
|
||||||
# use hmac.compare_digest to prevent timing attacks
|
# use hmac.compare_digest to prevent timing attacks
|
||||||
if not (received_auth and hmac.compare_digest(received_auth, correct_core_api)):
|
if received_auth and hmac.compare_digest(received_auth, correct_core_api):
|
||||||
return await errors.error(401, 'The core API key you provided is invalid.', 'Check the `Authorization` header.')
|
return fastapi.Response(status_code=403, content='Invalid or no API key given.')
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
@router.get('/users')
|
@router.get('/users')
|
||||||
async def get_users(discord_id: int, incoming_request: fastapi.Request):
|
async def get_users(discord_id: int, incoming_request: fastapi.Request):
|
||||||
|
@ -50,7 +50,7 @@ async def get_users(discord_id: int, incoming_request: fastapi.Request):
|
||||||
manager = UserManager()
|
manager = UserManager()
|
||||||
user = await manager.user_by_discord_id(discord_id)
|
user = await manager.user_by_discord_id(discord_id)
|
||||||
if not user:
|
if not user:
|
||||||
return await errors.error(404, 'Discord user not found in the API database.', 'Check the `discord_id` parameter.')
|
return fastapi.Response(status_code=404, content='User not found.')
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
@ -83,7 +83,7 @@ async def create_user(incoming_request: fastapi.Request):
|
||||||
payload = await incoming_request.json()
|
payload = await incoming_request.json()
|
||||||
discord_id = payload.get('discord_id')
|
discord_id = payload.get('discord_id')
|
||||||
except (json.decoder.JSONDecodeError, AttributeError):
|
except (json.decoder.JSONDecodeError, AttributeError):
|
||||||
return await errors.error(400, 'Invalid or no payload received.', 'The payload must be a JSON object with a `discord_id` key.')
|
return fastapi.Response(status_code=400, content='Invalid or no payload received.')
|
||||||
|
|
||||||
# Create the user
|
# Create the user
|
||||||
manager = UserManager()
|
manager = UserManager()
|
||||||
|
@ -106,12 +106,9 @@ async def update_user(incoming_request: fastapi.Request):
|
||||||
discord_id = payload.get('discord_id')
|
discord_id = payload.get('discord_id')
|
||||||
updates = payload.get('updates')
|
updates = payload.get('updates')
|
||||||
except (json.decoder.JSONDecodeError, AttributeError):
|
except (json.decoder.JSONDecodeError, AttributeError):
|
||||||
return await errors.error(
|
return fastapi.Response(status_code=400, content='Invalid or no payload received.')
|
||||||
400, 'Invalid or no payload received.',
|
|
||||||
'The payload must be a JSON object with a `discord_id` key and an `updates` key.'
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update the user
|
# Update the user
|
||||||
manager = UserManager()
|
manager = UserManager()
|
||||||
user = await manager.update_by_discord_id(discord_id, updates)
|
user = await manager.update_by_discord_id(discord_id, updates)
|
||||||
|
|
||||||
|
@ -126,23 +123,9 @@ async def run_checks(incoming_request: fastapi.Request):
|
||||||
if auth_error:
|
if auth_error:
|
||||||
return auth_error
|
return auth_error
|
||||||
|
|
||||||
try:
|
|
||||||
chat = await checks.client.test_chat()
|
|
||||||
except Exception:
|
|
||||||
chat = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
moderation = await checks.client.test_api_moderation()
|
|
||||||
except Exception:
|
|
||||||
moderation = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
models = await checks.client.test_models()
|
|
||||||
except Exception:
|
|
||||||
models = None
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'chat/completions': chat,
|
'library': await checks.client.test_library(),
|
||||||
'models': models,
|
'library_moderation': await checks.client.test_library_moderation(),
|
||||||
'moderations': moderation,
|
'api_moderation': await checks.client.test_api_moderation(),
|
||||||
|
'models': await checks.client.test_models()
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,8 +36,7 @@ async def root():
|
||||||
'hi': 'Welcome to the Nova API!',
|
'hi': 'Welcome to the Nova API!',
|
||||||
'learn_more_here': 'https://nova-oss.com',
|
'learn_more_here': 'https://nova-oss.com',
|
||||||
'github': 'https://github.com/novaoss/nova-api',
|
'github': 'https://github.com/novaoss/nova-api',
|
||||||
'core_api_docs_for_nova_developers': '/docs',
|
'core_api_docs_for_nova_developers': '/docs'
|
||||||
'ping': 'pong'
|
|
||||||
}
|
}
|
||||||
|
|
||||||
app.add_route('/v1/{path:path}', transfer.handle, ['GET', 'POST', 'PUT', 'DELETE', 'PATCH'])
|
app.add_route('/v1/{path:path}', transfer.handle, ['GET', 'POST', 'PUT', 'DELETE', 'PATCH'])
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
"""This module contains functions for checking if a message violates the moderation policy."""
|
"""This module contains functions for checking if a message violates the moderation policy."""
|
||||||
|
|
||||||
import time
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
|
@ -30,8 +29,6 @@ async def is_policy_violated(inp: Union[str, list]) -> bool:
|
||||||
else:
|
else:
|
||||||
text = '\n'.join(inp)
|
text = '\n'.join(inp)
|
||||||
|
|
||||||
print(f'[i] checking moderation for {text}')
|
|
||||||
|
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
req = await load_balancing.balance_organic_request(
|
req = await load_balancing.balance_organic_request(
|
||||||
{
|
{
|
||||||
|
@ -39,11 +36,9 @@ async def is_policy_violated(inp: Union[str, list]) -> bool:
|
||||||
'payload': {'input': text}
|
'payload': {'input': text}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
print(f'[i] moderation request sent to {req["url"]}')
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession(connector=proxies.get_proxy().connector) as session:
|
async with aiohttp.ClientSession(connector=proxies.get_proxy().connector) as session:
|
||||||
try:
|
try:
|
||||||
start = time.perf_counter()
|
|
||||||
async with session.request(
|
async with session.request(
|
||||||
method=req.get('method', 'POST'),
|
method=req.get('method', 'POST'),
|
||||||
url=req['url'],
|
url=req['url'],
|
||||||
|
@ -56,19 +51,14 @@ async def is_policy_violated(inp: Union[str, list]) -> bool:
|
||||||
) as res:
|
) as res:
|
||||||
res.raise_for_status()
|
res.raise_for_status()
|
||||||
json_response = await res.json()
|
json_response = await res.json()
|
||||||
print(json_response)
|
|
||||||
|
|
||||||
categories = json_response['results'][0]['category_scores']
|
categories = json_response['results'][0]['category_scores']
|
||||||
|
|
||||||
print(f'[i] moderation check took {time.perf_counter() - start:.2f}s')
|
|
||||||
|
|
||||||
if json_response['results'][0]['flagged']:
|
if json_response['results'][0]['flagged']:
|
||||||
return max(categories, key=categories.get)
|
return max(categories, key=categories.get)
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
|
|
||||||
if '401' in str(exc):
|
if '401' in str(exc):
|
||||||
await provider_auth.invalidate_key(req.get('provider_auth'))
|
await provider_auth.invalidate_key(req.get('provider_auth'))
|
||||||
print('[!] moderation error:', type(exc), exc)
|
print('[!] moderation error:', type(exc), exc)
|
||||||
|
|
|
@ -106,22 +106,6 @@ async def stream(
|
||||||
# We haven't done any requests as of right now, everything until now was just preparation
|
# We haven't done any requests as of right now, everything until now was just preparation
|
||||||
# Here, we process the request
|
# Here, we process the request
|
||||||
async with aiohttp.ClientSession(connector=proxies.get_proxy().connector) as session:
|
async with aiohttp.ClientSession(connector=proxies.get_proxy().connector) as session:
|
||||||
try:
|
|
||||||
async with session.get(
|
|
||||||
url='https://checkip.amazonaws.com',
|
|
||||||
timeout=aiohttp.ClientTimeout(
|
|
||||||
connect=3,
|
|
||||||
total=float(os.getenv('TRANSFER_TIMEOUT', '5'))
|
|
||||||
)
|
|
||||||
) as response:
|
|
||||||
for actual_ip in os.getenv('ACTUAL_IPS', '').split(' '):
|
|
||||||
if actual_ip in await response.text():
|
|
||||||
raise ValueError(f'Proxy {response.text()} is transparent!')
|
|
||||||
|
|
||||||
except Exception as exc:
|
|
||||||
print(f'[!] proxy {proxies.get_proxy()} error - ({type(exc)} {exc})')
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with session.request(
|
async with session.request(
|
||||||
method=target_request.get('method', 'POST'),
|
method=target_request.get('method', 'POST'),
|
||||||
|
@ -136,9 +120,6 @@ async def stream(
|
||||||
total=float(os.getenv('TRANSFER_TIMEOUT', '120'))
|
total=float(os.getenv('TRANSFER_TIMEOUT', '120'))
|
||||||
),
|
),
|
||||||
) as response:
|
) as response:
|
||||||
if response.status == 429:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if response.content_type == 'application/json':
|
if response.content_type == 'application/json':
|
||||||
data = await response.json()
|
data = await response.json()
|
||||||
|
|
||||||
|
@ -168,11 +149,7 @@ async def stream(
|
||||||
break
|
break
|
||||||
|
|
||||||
except ProxyError as exc:
|
except ProxyError as exc:
|
||||||
print('[!] aiohttp came up with a dumb excuse to not work again ("pRoXy ErRor")')
|
print('[!] Proxy error:', exc)
|
||||||
continue
|
|
||||||
|
|
||||||
except ConnectionResetError as exc:
|
|
||||||
print('[!] aiohttp came up with a dumb excuse to not work again ("cOnNeCtIoN rEsEt")')
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if is_chat and is_stream:
|
if is_chat and is_stream:
|
||||||
|
|
|
@ -26,7 +26,7 @@ async def handle(incoming_request):
|
||||||
Checks method, token amount, auth and cost along with if request is NSFW.
|
Checks method, token amount, auth and cost along with if request is NSFW.
|
||||||
"""
|
"""
|
||||||
users = UserManager()
|
users = UserManager()
|
||||||
path = incoming_request.url.path.replace('v1/v1', 'v1').replace('//', '/')
|
path = incoming_request.url.path
|
||||||
|
|
||||||
if '/models' in path:
|
if '/models' in path:
|
||||||
return fastapi.responses.JSONResponse(content=models_list)
|
return fastapi.responses.JSONResponse(content=models_list)
|
||||||
|
@ -62,11 +62,10 @@ async def handle(incoming_request):
|
||||||
cost = costs['chat-models'].get(payload.get('model'), cost)
|
cost = costs['chat-models'].get(payload.get('model'), cost)
|
||||||
|
|
||||||
policy_violation = False
|
policy_violation = False
|
||||||
if '/moderations' not in path:
|
if 'chat/completions' in path or ('input' in payload or 'prompt' in payload):
|
||||||
if '/chat/completions' in path or ('input' in payload or 'prompt' in payload):
|
inp = payload.get('input', payload.get('prompt', ''))
|
||||||
inp = payload.get('input', payload.get('prompt', ''))
|
if inp and len(inp) > 2 and not inp.isnumeric():
|
||||||
if inp and len(inp) > 2 and not inp.isnumeric():
|
policy_violation = await moderation.is_policy_violated(inp)
|
||||||
policy_violation = await moderation.is_policy_violated(inp)
|
|
||||||
|
|
||||||
if policy_violation:
|
if policy_violation:
|
||||||
return await errors.error(400, f'The request contains content which violates this model\'s policies for "{policy_violation}".', 'We currently don\'t support any NSFW models.')
|
return await errors.error(400, f'The request contains content which violates this model\'s policies for "{policy_violation}".', 'We currently don\'t support any NSFW models.')
|
||||||
|
|
|
@ -1,4 +1,2 @@
|
||||||
import client
|
import client
|
||||||
import asyncio
|
client.demo()
|
||||||
|
|
||||||
asyncio.run(client.demo())
|
|
||||||
|
|
115
checks/client.py
115
checks/client.py
|
@ -22,103 +22,82 @@ MESSAGES = [
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
api_endpoint = 'http://localhost:2332/v1'
|
api_endpoint = 'http://localhost:2332'
|
||||||
|
|
||||||
async def test_server():
|
async def test_server():
|
||||||
"""Tests if the API server is running."""
|
"""Tests if the API server is running."""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
request_start = time.perf_counter()
|
return httpx.get(f'{api_endpoint.replace("/v1", "")}').json()['status'] == 'ok'
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
response = await client.get(
|
|
||||||
url=f'{api_endpoint.replace("/v1", "")}',
|
|
||||||
timeout=3
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
assert response.json()['ping'] == 'pong', 'The API did not return a correct response.'
|
|
||||||
except httpx.ConnectError as exc:
|
except httpx.ConnectError as exc:
|
||||||
raise ConnectionError(f'API is not running on port {api_endpoint}.') from exc
|
raise ConnectionError(f'API is not running on port {api_endpoint}.') from exc
|
||||||
|
|
||||||
else:
|
async def test_api(model: str=MODEL, messages: List[dict]=None) -> dict:
|
||||||
return time.perf_counter() - request_start
|
|
||||||
|
|
||||||
async def test_chat(model: str=MODEL, messages: List[dict]=None) -> dict:
|
|
||||||
"""Tests an API api_endpoint."""
|
"""Tests an API api_endpoint."""
|
||||||
|
|
||||||
json_data = {
|
json_data = {
|
||||||
'model': model,
|
'model': model,
|
||||||
'messages': messages or MESSAGES,
|
'messages': messages or MESSAGES,
|
||||||
'stream': False
|
'stream': True,
|
||||||
}
|
}
|
||||||
|
|
||||||
request_start = time.perf_counter()
|
response = httpx.post(
|
||||||
|
url=f'{api_endpoint}/chat/completions',
|
||||||
|
headers=HEADERS,
|
||||||
|
json=json_data,
|
||||||
|
timeout=20
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
return response.text
|
||||||
response = await client.post(
|
|
||||||
url=f'{api_endpoint}/chat/completions',
|
|
||||||
headers=HEADERS,
|
|
||||||
json=json_data,
|
|
||||||
timeout=10,
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
assert '2' in response.json()['choices'][0]['message']['content'], 'The API did not return a correct response.'
|
async def test_library():
|
||||||
return time.perf_counter() - request_start
|
|
||||||
|
|
||||||
async def test_library_chat():
|
|
||||||
"""Tests if the api_endpoint is working with the OpenAI Python library."""
|
"""Tests if the api_endpoint is working with the OpenAI Python library."""
|
||||||
|
|
||||||
request_start = time.perf_counter()
|
|
||||||
completion = openai.ChatCompletion.create(
|
completion = openai.ChatCompletion.create(
|
||||||
model=MODEL,
|
model=MODEL,
|
||||||
messages=MESSAGES
|
messages=MESSAGES
|
||||||
)
|
)
|
||||||
|
|
||||||
assert '2' in completion.choices[0]['message']['content'], 'The API did not return a correct response.'
|
print(completion)
|
||||||
return time.perf_counter() - request_start
|
|
||||||
|
return completion['choices'][0]['message']['content']
|
||||||
|
|
||||||
|
async def test_library_moderation():
|
||||||
|
try:
|
||||||
|
return openai.Moderation.create('I wanna kill myself, I wanna kill myself; It\'s all I hear right now, it\'s all I hear right now')
|
||||||
|
except openai.error.InvalidRequestError:
|
||||||
|
return True
|
||||||
|
|
||||||
async def test_models():
|
async def test_models():
|
||||||
"""Tests the models endpoint."""
|
response = httpx.get(
|
||||||
|
url=f'{api_endpoint}/models',
|
||||||
request_start = time.perf_counter()
|
headers=HEADERS,
|
||||||
async with httpx.AsyncClient() as client:
|
timeout=5
|
||||||
response = await client.get(
|
)
|
||||||
url=f'{api_endpoint}/models',
|
response.raise_for_status()
|
||||||
headers=HEADERS,
|
return response.json()
|
||||||
timeout=3
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
res = response.json()
|
|
||||||
|
|
||||||
all_models = [model['id'] for model in res['data']]
|
|
||||||
|
|
||||||
assert 'gpt-3.5-turbo' in all_models, 'The model gpt-3.5-turbo is not present in the models endpoint.'
|
|
||||||
return time.perf_counter() - request_start
|
|
||||||
|
|
||||||
async def test_api_moderation() -> dict:
|
async def test_api_moderation() -> dict:
|
||||||
"""Tests the moderation endpoint."""
|
"""Tests an API api_endpoint."""
|
||||||
|
|
||||||
request_start = time.perf_counter()
|
response = httpx.get(
|
||||||
async with httpx.AsyncClient() as client:
|
url=f'{api_endpoint}/moderations',
|
||||||
response = await client.post(
|
headers=HEADERS,
|
||||||
url=f'{api_endpoint}/moderations',
|
timeout=20
|
||||||
headers=HEADERS,
|
)
|
||||||
timeout=5,
|
response.raise_for_status()
|
||||||
json={'input': 'fuck you, die'}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.json()['results'][0]['flagged'] == True, 'Profanity not detected'
|
return response.text
|
||||||
return time.perf_counter() - request_start
|
|
||||||
|
|
||||||
# ==========================================================================================
|
# ==========================================================================================
|
||||||
|
|
||||||
async def demo():
|
def demo():
|
||||||
"""Runs all tests."""
|
"""Runs all tests."""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for _ in range(30):
|
for _ in range(30):
|
||||||
if await test_server():
|
if test_server():
|
||||||
break
|
break
|
||||||
|
|
||||||
print('Waiting until API Server is started up...')
|
print('Waiting until API Server is started up...')
|
||||||
|
@ -126,17 +105,17 @@ async def demo():
|
||||||
else:
|
else:
|
||||||
raise ConnectionError('API Server is not running.')
|
raise ConnectionError('API Server is not running.')
|
||||||
|
|
||||||
print('[lightblue]Checking if the API works...')
|
print('[lightblue]Running a api endpoint to see if requests can go through...')
|
||||||
print(await test_chat())
|
print(asyncio.run(test_api('gpt-3.5-turbo')))
|
||||||
|
|
||||||
print('[lightblue]Checking if the API works with the Python library...')
|
print('[lightblue]Checking if the API works with the python library...')
|
||||||
print(await test_library_chat())
|
print(asyncio.run(test_library()))
|
||||||
|
|
||||||
print('[lightblue]Checking if the moderation endpoint works...')
|
print('[lightblue]Checking if the moderation endpoint works...')
|
||||||
print(await test_api_moderation())
|
print(asyncio.run(test_library_moderation()))
|
||||||
|
|
||||||
print('[lightblue]Checking the models endpoint...')
|
print('[lightblue]Checking the /v1/models endpoint...')
|
||||||
print(await test_models())
|
print(asyncio.run(test_models()))
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
print('[red]Error: ' + str(exc))
|
print('[red]Error: ' + str(exc))
|
||||||
|
@ -152,4 +131,4 @@ HEADERS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
asyncio.run(demo())
|
demo()
|
||||||
|
|
|
@ -23,4 +23,4 @@ if 'prod' in sys.argv:
|
||||||
port = 2333
|
port = 2333
|
||||||
dev = False
|
dev = False
|
||||||
|
|
||||||
os.system(f'cd api && uvicorn main:app{" --reload" if dev else ""} --host 0.0.0.0 --port {port}')
|
os.system(f'cd api && uvicorn main:app{" --reload" if dev else ""} --host 0.0.0.0 --port {port} & python tests')
|
||||||
|
|
Loading…
Reference in a new issue