mirror of
https://github.com/NovaOSS/nova-api.git
synced 2024-11-25 14:43:58 +01:00
moderation is done yay
This commit is contained in:
parent
558ea89722
commit
83d57307cc
22
api/core.py
22
api/core.py
|
@ -32,7 +32,7 @@ async def get_users(discord_id: int, incoming_request: fastapi.Request):
|
|||
|
||||
return user
|
||||
|
||||
def new_user_webhook(user: dict) -> None:
|
||||
async def new_user_webhook(user: dict) -> None:
|
||||
dhook = Webhook(os.getenv('DISCORD_WEBHOOK__USER_CREATED'))
|
||||
|
||||
embed = Embed(
|
||||
|
@ -40,7 +40,7 @@ def new_user_webhook(user: dict) -> None:
|
|||
color=0x90ee90,
|
||||
)
|
||||
|
||||
embed.add_field(name='ID', value=user['_id'], inline=False)
|
||||
embed.add_field(name='ID', value=str(user['_id']), inline=False)
|
||||
embed.add_field(name='Discord', value=user['auth']['discord'])
|
||||
embed.add_field(name='Github', value=user['auth']['github'])
|
||||
|
||||
|
@ -60,15 +60,17 @@ async def create_user(incoming_request: fastapi.Request):
|
|||
return fastapi.Response(status_code=400, content='Invalid or no payload received.')
|
||||
|
||||
user = await users.create(discord_id)
|
||||
new_user_webhook(user)
|
||||
await new_user_webhook(user)
|
||||
|
||||
return user
|
||||
|
||||
if __name__ == '__main__':
|
||||
new_user_webhook({
|
||||
'_id': 'JUST_A_TEST_IGNORE_ME',
|
||||
'auth': {
|
||||
'discord': 123,
|
||||
'github': 'abc'
|
||||
}
|
||||
})
|
||||
# new_user_webhook({
|
||||
# '_id': 'JUST_A_TEST_IGNORE_ME',
|
||||
# 'auth': {
|
||||
# 'discord': 123,
|
||||
# 'github': 'abc'
|
||||
# }
|
||||
# })
|
||||
|
||||
pass
|
||||
|
|
|
@ -8,10 +8,24 @@ from helpers import network
|
|||
|
||||
load_dotenv()
|
||||
|
||||
def _get_mongo(collection_name: str):
|
||||
UA_SIMPLIFY = {
|
||||
'Windows NT': 'W',
|
||||
'Mozilla/5.0': 'M',
|
||||
'Win64; x64': '64',
|
||||
'Safari/537.36': 'S',
|
||||
'AppleWebKit/537.36 (KHTML, like Gecko)': 'K',
|
||||
}
|
||||
|
||||
async def _get_mongo(collection_name: str):
|
||||
return AsyncIOMotorClient(os.getenv('MONGO_URI'))['nova-core'][collection_name]
|
||||
|
||||
async def replacer(text: str, dict_: dict) -> str:
|
||||
for k, v in dict_.items():
|
||||
text = text.replace(k, v)
|
||||
return text
|
||||
|
||||
async def log_api_request(user: dict, incoming_request, target_url: str):
|
||||
db = await _get_mongo('logs')
|
||||
payload = {}
|
||||
|
||||
try:
|
||||
|
@ -22,19 +36,23 @@ async def log_api_request(user: dict, incoming_request, target_url: str):
|
|||
|
||||
last_prompt = None
|
||||
if 'messages' in payload:
|
||||
last_prompt = payload['messages'][-1]['content']
|
||||
last_prompt = payload['messages'][-1]['content'][:50]
|
||||
|
||||
if len(last_prompt) == 50:
|
||||
last_prompt += '...'
|
||||
|
||||
model = payload.get('model')
|
||||
ip_address = await network.get_ip(incoming_request)
|
||||
useragent = await replacer(incoming_request.headers.get('User-Agent'), UA_SIMPLIFY)
|
||||
|
||||
new_log_item = {
|
||||
'timestamp': time.time(),
|
||||
'method': incoming_request.method,
|
||||
'path': incoming_request.url.path,
|
||||
'user_id': user['_id'],
|
||||
'user_id': str(user['_id']),
|
||||
'security': {
|
||||
'ip': ip_address,
|
||||
'useragent': incoming_request.headers.get('User-Agent')
|
||||
'useragent': useragent,
|
||||
},
|
||||
'details': {
|
||||
'model': model,
|
||||
|
@ -43,21 +61,25 @@ async def log_api_request(user: dict, incoming_request, target_url: str):
|
|||
}
|
||||
}
|
||||
|
||||
inserted = await _get_mongo('logs').insert_one(new_log_item)
|
||||
log_item = await _get_mongo('logs').find_one({'_id': inserted.inserted_id})
|
||||
inserted = await db.insert_one(new_log_item)
|
||||
log_item = await db.find_one({'_id': inserted.inserted_id})
|
||||
return log_item
|
||||
|
||||
async def by_id(log_id: str):
|
||||
return await _get_mongo('logs').find_one({'_id': log_id})
|
||||
db = await _get_mongo('logs')
|
||||
return await db.find_one({'_id': log_id})
|
||||
|
||||
async def by_user_id(user_id: str):
|
||||
return await _get_mongo('logs').find({'user_id': user_id})
|
||||
db = await _get_mongo('logs')
|
||||
return await db.find({'user_id': user_id})
|
||||
|
||||
async def delete_by_id(log_id: str):
|
||||
return await _get_mongo('logs').delete_one({'_id': log_id})
|
||||
db = await _get_mongo('logs')
|
||||
return await db.delete_one({'_id': log_id})
|
||||
|
||||
async def delete_by_user_id(user_id: str):
|
||||
return await _get_mongo('logs').delete_many({'user_id': user_id})
|
||||
db = await _get_mongo('logs')
|
||||
return await db.delete_many({'user_id': user_id})
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
||||
|
|
|
@ -8,34 +8,41 @@ from motor.motor_asyncio import AsyncIOMotorClient
|
|||
|
||||
load_dotenv()
|
||||
|
||||
def _get_mongo(collection_name: str):
|
||||
async def _get_mongo(collection_name: str):
|
||||
return AsyncIOMotorClient(os.getenv('MONGO_URI'))['nova-core'][collection_name]
|
||||
|
||||
async def add_date():
|
||||
date = datetime.datetime.now(pytz.timezone('GMT')).strftime('%Y.%m.%d')
|
||||
year, month, day = date.split('.')
|
||||
|
||||
await _get_mongo('stats').update_one({}, {'$inc': {f'dates.{year}.{month}.{day}': 1}}, upsert=True)
|
||||
db = await _get_mongo('stats')
|
||||
await db.update_one({}, {'$inc': {f'dates.{year}.{month}.{day}': 1}}, upsert=True)
|
||||
|
||||
async def add_ip_address(ip_address: str):
|
||||
ip_address = ip_address.replace('.', '_')
|
||||
await _get_mongo('stats').update_one({}, {'$inc': {f'ips.{ip_address}': 1}}, upsert=True)
|
||||
db = await _get_mongo('stats')
|
||||
await db.update_one({}, {'$inc': {f'ips.{ip_address}': 1}}, upsert=True)
|
||||
|
||||
async def add_target(url: str):
|
||||
await _get_mongo('stats').update_one({}, {'$inc': {f'targets.{url}': 1}}, upsert=True)
|
||||
db = await _get_mongo('stats')
|
||||
await db.update_one({}, {'$inc': {f'targets.{url}': 1}}, upsert=True)
|
||||
|
||||
async def add_tokens(tokens: int, model: str):
|
||||
await _get_mongo('stats').update_one({}, {'$inc': {f'tokens.{model}': tokens}}, upsert=True)
|
||||
db = await _get_mongo('stats')
|
||||
await db.update_one({}, {'$inc': {f'tokens.{model}': tokens}}, upsert=True)
|
||||
|
||||
async def add_model(model: str):
|
||||
await _get_mongo('stats').update_one({}, {'$inc': {f'models.{model}': 1}}, upsert=True)
|
||||
db = await _get_mongo('stats')
|
||||
await db.update_one({}, {'$inc': {f'models.{model}': 1}}, upsert=True)
|
||||
|
||||
async def add_path(path: str):
|
||||
path = path.replace('/', '_')
|
||||
await _get_mongo('stats').update_one({}, {'$inc': {f'paths.{path}': 1}}, upsert=True)
|
||||
db = await _get_mongo('stats')
|
||||
await db.update_one({}, {'$inc': {f'paths.{path}': 1}}, upsert=True)
|
||||
|
||||
async def get_value(obj_filter):
|
||||
return await _get_mongo('stats').find_one({obj_filter})
|
||||
db = await _get_mongo('stats')
|
||||
return await db.find_one({obj_filter})
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(add_date())
|
||||
|
|
|
@ -12,7 +12,7 @@ load_dotenv()
|
|||
with open('config/credits.yml', encoding='utf8') as f:
|
||||
credits_config = yaml.safe_load(f)
|
||||
|
||||
def _get_mongo(collection_name: str):
|
||||
async def _get_mongo(collection_name: str):
|
||||
return AsyncIOMotorClient(os.getenv('MONGO_URI'))['nova-core'][collection_name]
|
||||
|
||||
async def create(discord_id: int=0) -> dict:
|
||||
|
@ -46,23 +46,28 @@ async def create(discord_id: int=0) -> dict:
|
|||
return user
|
||||
|
||||
async def by_id(user_id: str):
|
||||
return await _get_mongo('users').find_one({'_id': user_id})
|
||||
db = await _get_mongo('users')
|
||||
return await db.find_one({'_id': user_id})
|
||||
|
||||
async def by_discord_id(discord_id: str):
|
||||
return await _get_mongo('users').find_one({'auth.discord': discord_id})
|
||||
db = await _get_mongo('users')
|
||||
return await db.find_one({'auth.discord': discord_id})
|
||||
|
||||
async def by_api_key(key: str):
|
||||
return await _get_mongo('users').find_one({'api_key': key})
|
||||
db = await _get_mongo('users')
|
||||
return await db.find_one({'api_key': key})
|
||||
|
||||
async def update_by_id(user_id: str, update):
|
||||
return await _get_mongo('users').update_one({'_id': user_id}, update)
|
||||
db = await _get_mongo('users')
|
||||
return await db.update_one({'_id': user_id}, update)
|
||||
|
||||
async def update_by_filter(obj_filter, update):
|
||||
return await _get_mongo('users').update_one(obj_filter, update)
|
||||
db = await _get_mongo('users')
|
||||
return await db.update_one(obj_filter, update)
|
||||
|
||||
async def delete(user_id: str):
|
||||
await _get_mongo('users').delete_one({'_id': user_id})
|
||||
|
||||
db = await _get_mongo('users')
|
||||
await db.delete_one({'_id': user_id})
|
||||
|
||||
async def demo():
|
||||
user = await create(69420)
|
||||
|
|
|
@ -17,7 +17,7 @@ async def create_chat_id() -> str:
|
|||
|
||||
return f'chatcmpl-{chat_id}'
|
||||
|
||||
def create_chat_chunk(chat_id: str, model: str, content=None) -> dict:
|
||||
async def create_chat_chunk(chat_id: str, model: str, content=None) -> dict:
|
||||
content = content or {}
|
||||
|
||||
delta = {}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
import starlette
|
||||
|
||||
def error(code: int, message: str, tip: str) -> starlette.responses.Response:
|
||||
async def error(code: int, message: str, tip: str) -> starlette.responses.Response:
|
||||
info = {'error': {
|
||||
'code': code,
|
||||
'message': message,
|
||||
|
@ -12,7 +12,7 @@ def error(code: int, message: str, tip: str) -> starlette.responses.Response:
|
|||
|
||||
return starlette.responses.Response(status_code=code, content=json.dumps(info))
|
||||
|
||||
def yield_error(code: int, message: str, tip: str) -> str:
|
||||
async def yield_error(code: int, message: str, tip: str) -> str:
|
||||
return json.dumps({
|
||||
'code': code,
|
||||
'message': message,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import tiktoken
|
||||
|
||||
def count_for_messages(messages: list, model: str='gpt-3.5-turbo-0613') -> int:
|
||||
async def count_for_messages(messages: list, model: str='gpt-3.5-turbo-0613') -> int:
|
||||
"""Return the number of tokens used by a list of messages."""
|
||||
|
||||
try:
|
||||
|
|
|
@ -11,13 +11,15 @@ provider_modules = [
|
|||
providers.closed4
|
||||
]
|
||||
|
||||
def _get_module_name(module) -> str:
|
||||
async def _get_module_name(module) -> str:
|
||||
name = module.__name__
|
||||
if '.' in name:
|
||||
return name.split('.')[-1]
|
||||
return name
|
||||
|
||||
async def balance_chat_request(payload: dict) -> dict:
|
||||
"""Load balance the chat completion request between chat providers."""
|
||||
|
||||
providers_available = []
|
||||
|
||||
for provider_module in provider_modules:
|
||||
|
@ -34,20 +36,37 @@ async def balance_chat_request(payload: dict) -> dict:
|
|||
|
||||
provider = random.choice(providers_available)
|
||||
target = provider.chat_completion(**payload)
|
||||
target['module'] = _get_module_name(provider)
|
||||
|
||||
module_name = await _get_module_name(provider)
|
||||
target['module'] = module_name
|
||||
|
||||
return target
|
||||
|
||||
async def balance_organic_request(request: dict) -> dict:
|
||||
"""Load balnace to non-chat completion request between other "organic" providers which respond in the desired format already."""
|
||||
|
||||
providers_available = []
|
||||
|
||||
if not request.get('headers'):
|
||||
request['headers'] = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
for provider_module in provider_modules:
|
||||
if provider_module.ORGANIC:
|
||||
providers_available.append(provider_module)
|
||||
if not provider_module.ORGANIC:
|
||||
continue
|
||||
|
||||
if '/moderations' in request['path']:
|
||||
if not provider_module.MODERATIONS:
|
||||
continue
|
||||
|
||||
providers_available.append(provider_module)
|
||||
|
||||
provider = random.choice(providers_available)
|
||||
target = provider.organify(request)
|
||||
target['module'] = _get_module_name(provider)
|
||||
|
||||
module_name = await _get_module_name(provider)
|
||||
target['module'] = module_name
|
||||
|
||||
return target
|
||||
|
||||
|
|
|
@ -1,18 +1,49 @@
|
|||
import os
|
||||
import asyncio
|
||||
import openai as closedai
|
||||
|
||||
from typing import Union
|
||||
from dotenv import load_dotenv
|
||||
import aiohttp
|
||||
import proxies
|
||||
import load_balancing
|
||||
|
||||
load_dotenv()
|
||||
async def is_safe(inp) -> bool:
|
||||
text = inp
|
||||
|
||||
closedai.api_key = os.getenv('LEGIT_CLOSEDAI_KEY')
|
||||
if isinstance(inp, list):
|
||||
text = ''
|
||||
if isinstance(inp[0], dict):
|
||||
for msg in inp:
|
||||
text += msg['content'] + '\n'
|
||||
|
||||
async def is_safe(text: Union[str, list]) -> bool:
|
||||
return closedai.Moderation.create(
|
||||
input=text,
|
||||
)['results'][0]['flagged']
|
||||
else:
|
||||
text = '\n'.join(inp)
|
||||
|
||||
for _ in range(3):
|
||||
req = await load_balancing.balance_organic_request(
|
||||
{
|
||||
'path': '/v1/moderations',
|
||||
'payload': {'input': text}
|
||||
}
|
||||
)
|
||||
|
||||
async with aiohttp.ClientSession(connector=proxies.default_proxy.connector) as session:
|
||||
try:
|
||||
async with session.request(
|
||||
method=req.get('method', 'POST'),
|
||||
url=req['url'],
|
||||
data=req.get('data'),
|
||||
json=req.get('payload'),
|
||||
headers=req.get('headers'),
|
||||
cookies=req.get('cookies'),
|
||||
ssl=False,
|
||||
timeout=aiohttp.ClientTimeout(total=5),
|
||||
) as res:
|
||||
res.raise_for_status()
|
||||
|
||||
json_response = await res.json()
|
||||
|
||||
return not json_response['results'][0]['flagged']
|
||||
except Exception as exc:
|
||||
print('[!] moderation error:', type(exc), exc)
|
||||
continue
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(is_safe('Hello'))
|
||||
print(asyncio.run(is_safe('I wanna kill myself')))
|
||||
|
|
102
api/streaming.py
102
api/streaming.py
|
@ -49,19 +49,26 @@ async def stream(
|
|||
if is_chat and is_stream:
|
||||
chat_id = await chat.create_chat_id()
|
||||
|
||||
yield chat.create_chat_chunk(
|
||||
chunk = await chat.create_chat_chunk(
|
||||
chat_id=chat_id,
|
||||
model=model,
|
||||
content=chat.CompletionStart
|
||||
)
|
||||
yield chunk
|
||||
|
||||
yield chat.create_chat_chunk(
|
||||
chunk = await chat.create_chat_chunk(
|
||||
chat_id=chat_id,
|
||||
model=model,
|
||||
content=None
|
||||
)
|
||||
|
||||
for _ in range(3):
|
||||
yield chunk
|
||||
|
||||
json_response = {
|
||||
'error': 'No JSON response could be received'
|
||||
}
|
||||
|
||||
for _ in range(5):
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
@ -81,18 +88,17 @@ async def stream(
|
|||
webhook = dhooks.Webhook(os.getenv('DISCORD_WEBHOOK__API_ISSUE'))
|
||||
|
||||
webhook.send(content=f'API Issue: **`{exc}`**\nhttps://i.imgflip.com/7uv122.jpg')
|
||||
yield errors.yield_error(
|
||||
error = errors.yield_error(
|
||||
500,
|
||||
'Sorry, the API has no working keys anymore.',
|
||||
'The admins have been messaged automatically.'
|
||||
)
|
||||
yield error
|
||||
return
|
||||
|
||||
for k, v in target_request.get('headers', {}).items():
|
||||
headers[k] = v
|
||||
|
||||
json.dump(target_request, open('api.log.json', 'w'), indent=4)
|
||||
|
||||
async with aiohttp.ClientSession(connector=proxies.default_proxy.connector) as session:
|
||||
try:
|
||||
async with session.request(
|
||||
|
@ -109,52 +115,38 @@ async def stream(
|
|||
|
||||
timeout=aiohttp.ClientTimeout(total=float(os.getenv('TRANSFER_TIMEOUT', '120'))),
|
||||
) as response:
|
||||
print(5)
|
||||
|
||||
if not is_stream:
|
||||
json_response = await response.json()
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except Exception as exc:
|
||||
if 'Too Many Requests' in str(exc):
|
||||
print(429)
|
||||
continue
|
||||
|
||||
if user and incoming_request:
|
||||
await logs.log_api_request(
|
||||
user=user,
|
||||
incoming_request=incoming_request,
|
||||
target_url=target_request['url']
|
||||
)
|
||||
|
||||
if credits_cost and user:
|
||||
await users.update_by_id(user['_id'], {
|
||||
'$inc': {'credits': -credits_cost}
|
||||
})
|
||||
|
||||
print(6)
|
||||
|
||||
if is_stream:
|
||||
try:
|
||||
async for chunk in response.content.iter_any():
|
||||
send = False
|
||||
chunk = f'{chunk.decode("utf8")}\n\n'
|
||||
chunk = chunk.replace(os.getenv('MAGIC_WORD', 'novaOSScheckKeyword'), payload['model'])
|
||||
# chunk = chunk.replace(os.getenv('MAGIC_USER_WORD', 'novaOSSuserKeyword'), user['_id'])
|
||||
print(chunk)
|
||||
chunk = chunk.replace(os.getenv('MAGIC_USER_WORD', 'novaOSSuserKeyword'), str(user['_id']))
|
||||
|
||||
if not chunk.strip():
|
||||
send = False
|
||||
|
||||
if is_chat and '{' in chunk:
|
||||
data = json.loads(chunk.split('data: ')[1])
|
||||
chunk = chunk.replace(data['id'], chat_id)
|
||||
send = True
|
||||
|
||||
if target_request['module'] == 'twa' and data.get('text'):
|
||||
chunk = chat.create_chat_chunk(
|
||||
chunk = await chat.create_chat_chunk(
|
||||
chat_id=chat_id,
|
||||
model=model,
|
||||
content=['text']
|
||||
)
|
||||
|
||||
if not data['choices'][0]['delta']:
|
||||
send = False
|
||||
|
||||
|
@ -162,24 +154,18 @@ async def stream(
|
|||
send = False
|
||||
|
||||
if send:
|
||||
yield chunk
|
||||
final_chunk = chunk.strip().replace('data: [DONE]', '') + '\n\n'
|
||||
yield final_chunk
|
||||
|
||||
except Exception as exc:
|
||||
if 'Connection closed' in str(exc):
|
||||
print('connection closed: ', exc)
|
||||
continue
|
||||
|
||||
if not demo_mode:
|
||||
ip_address = await network.get_ip(incoming_request)
|
||||
|
||||
await stats.add_date()
|
||||
await stats.add_ip_address(ip_address)
|
||||
await stats.add_path(path)
|
||||
await stats.add_target(target_request['url'])
|
||||
|
||||
if is_chat:
|
||||
await stats.add_model(model)
|
||||
await stats.add_tokens(input_tokens, model)
|
||||
error = errors.yield_error(
|
||||
500,
|
||||
'Sorry, there was an issue with the connection.',
|
||||
'Please first check if the issue on your end. If this error repeats, please don\'t heistate to contact the staff!.'
|
||||
)
|
||||
yield error
|
||||
return
|
||||
|
||||
break
|
||||
|
||||
|
@ -187,21 +173,43 @@ async def stream(
|
|||
print('proxy error')
|
||||
continue
|
||||
|
||||
print(3)
|
||||
|
||||
if is_chat and is_stream:
|
||||
chat_chunk = chat.create_chat_chunk(
|
||||
chunk = await chat.create_chat_chunk(
|
||||
chat_id=chat_id,
|
||||
model=model,
|
||||
content=chat.CompletionStop
|
||||
)
|
||||
data = json.dumps(chat_chunk)
|
||||
yield chunk
|
||||
|
||||
yield 'data: [DONE]\n\n'
|
||||
|
||||
if not is_stream:
|
||||
json_response = await response.json()
|
||||
yield json_response.encode('utf8')
|
||||
yield json.dumps(json_response)
|
||||
|
||||
# DONE =========================================================
|
||||
|
||||
if user and incoming_request:
|
||||
await logs.log_api_request(
|
||||
user=user,
|
||||
incoming_request=incoming_request,
|
||||
target_url=target_request['url']
|
||||
)
|
||||
|
||||
if credits_cost and user:
|
||||
await users.update_by_id(user['_id'], {
|
||||
'$inc': {'credits': -credits_cost}
|
||||
})
|
||||
|
||||
ip_address = await network.get_ip(incoming_request)
|
||||
|
||||
await stats.add_date()
|
||||
await stats.add_ip_address(ip_address)
|
||||
await stats.add_path(path)
|
||||
await stats.add_target(target_request['url'])
|
||||
|
||||
if is_chat:
|
||||
await stats.add_model(model)
|
||||
await stats.add_tokens(input_tokens, model)
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(stream())
|
||||
|
|
|
@ -4,11 +4,13 @@ import os
|
|||
import json
|
||||
import yaml
|
||||
import logging
|
||||
import fastapi
|
||||
import starlette
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import streaming
|
||||
import moderation
|
||||
|
||||
from db import logs, users
|
||||
from helpers import tokens, errors, exceptions
|
||||
|
@ -32,7 +34,8 @@ async def handle(incoming_request):
|
|||
|
||||
# METHOD
|
||||
if incoming_request.method not in ['GET', 'POST', 'PUT', 'DELETE', 'PATCH']:
|
||||
return errors.error(405, f'Method "{incoming_request.method}" is not allowed.', 'Change the request method to the correct one.')
|
||||
error = await errors.error(405, f'Method "{incoming_request.method}" is not allowed.', 'Change the request method to the correct one.')
|
||||
return error
|
||||
|
||||
# PAYLOAD
|
||||
try:
|
||||
|
@ -42,7 +45,7 @@ async def handle(incoming_request):
|
|||
|
||||
# TOKENS
|
||||
try:
|
||||
input_tokens = tokens.count_for_messages(payload['messages'])
|
||||
input_tokens = await tokens.count_for_messages(payload['messages'])
|
||||
except (KeyError, TypeError):
|
||||
input_tokens = 0
|
||||
|
||||
|
@ -50,7 +53,8 @@ async def handle(incoming_request):
|
|||
received_key = incoming_request.headers.get('Authorization')
|
||||
|
||||
if not received_key:
|
||||
return errors.error(401, 'No NovaAI API key given!', 'Add "Authorization: Bearer nv-..." to your request headers.')
|
||||
error = await errors.error(401, 'No NovaAI API key given!', 'Add "Authorization: Bearer nv-..." to your request headers.')
|
||||
return error
|
||||
|
||||
if received_key.startswith('Bearer '):
|
||||
received_key = received_key.split('Bearer ')[1]
|
||||
|
@ -59,38 +63,60 @@ async def handle(incoming_request):
|
|||
user = await users.by_api_key(received_key.strip())
|
||||
|
||||
if not user:
|
||||
return errors.error(401, 'Invalid NovaAI API key!', 'Create a new NovaOSS API key.')
|
||||
error = await errors.error(401, 'Invalid NovaAI API key!', 'Create a new NovaOSS API key.')
|
||||
return error
|
||||
|
||||
ban_reason = user['status']['ban_reason']
|
||||
if ban_reason:
|
||||
return errors.error(403, f'Your NovaAI account has been banned. Reason: "{ban_reason}".', 'Contact the staff for an appeal.')
|
||||
error = await errors.error(403, f'Your NovaAI account has been banned. Reason: "{ban_reason}".', 'Contact the staff for an appeal.')
|
||||
return error
|
||||
|
||||
if not user['status']['active']:
|
||||
return errors.error(418, 'Your NovaAI account is not active (paused).', 'Simply re-activate your account using a Discord command or the web panel.')
|
||||
error = await errors.error(418, 'Your NovaAI account is not active (paused).', 'Simply re-activate your account using a Discord command or the web panel.')
|
||||
return error
|
||||
|
||||
# COST
|
||||
costs = credits_config['costs']
|
||||
cost = costs['other']
|
||||
|
||||
is_safe = True
|
||||
|
||||
if 'chat/completions' in path:
|
||||
for model_name, model_cost in costs['chat-models'].items():
|
||||
if model_name in payload['model']:
|
||||
cost = model_cost
|
||||
|
||||
is_safe = await moderation.is_safe(payload['messages'])
|
||||
|
||||
else:
|
||||
inp = payload.get('input', payload.get('prompt'))
|
||||
|
||||
if inp and not '/moderations' in path:
|
||||
is_safe = await moderation.is_safe(inp)
|
||||
|
||||
if not is_safe:
|
||||
error = await errors.error(400, 'The request contains content which violates this model\'s policies.', 'We currently don\'t support any NSFW models.')
|
||||
return error
|
||||
|
||||
return
|
||||
|
||||
role_cost_multiplier = credits_config['bonuses'].get(user['role'], 1)
|
||||
cost = round(cost * role_cost_multiplier)
|
||||
|
||||
if user['credits'] < cost:
|
||||
return errors.error(429, 'Not enough credits.', 'Wait or earn more credits. Learn more on our website or Discord server.')
|
||||
error = await errors.error(429, 'Not enough credits.', 'Wait or earn more credits. Learn more on our website or Discord server.')
|
||||
return error
|
||||
|
||||
# READY
|
||||
|
||||
payload['user'] = str(user['_id'])
|
||||
# payload['user'] = str(user['_id'])
|
||||
|
||||
if 'chat/completions' in path and not payload.get('stream') is True:
|
||||
payload['stream'] = False
|
||||
|
||||
return starlette.responses.StreamingResponse(
|
||||
media_type = 'text/event-stream' if payload.get('stream', False) else 'application/json'
|
||||
|
||||
return fastapi.responses.StreamingResponse(
|
||||
content=streaming.stream(
|
||||
user=user,
|
||||
path=path,
|
||||
|
@ -99,5 +125,5 @@ async def handle(incoming_request):
|
|||
input_tokens=input_tokens,
|
||||
incoming_request=incoming_request,
|
||||
),
|
||||
media_type='text/event-stream' if payload.get('stream', False) else 'application/json'
|
||||
media_type=media_type
|
||||
)
|
||||
|
|
|
@ -2,11 +2,11 @@ async def get_all_users(client):
|
|||
users = client['nova-core']['users']
|
||||
return users
|
||||
|
||||
async def update_credits(users, settings = None):
|
||||
async def update_credits(users, settings=None):
|
||||
if not settings:
|
||||
users.update_many({}, {"$inc": {"credits": 250}})
|
||||
|
||||
else:
|
||||
for key, value in settings.items():
|
||||
users.update_many({'role': key}, {"$inc": {"credits": int(value)}})
|
||||
print(f"Updated {key} to {value}")
|
||||
print(f'Updated {key} to {value}')
|
|
@ -1,11 +1,13 @@
|
|||
import asyncio
|
||||
from settings import roles
|
||||
import autocredits
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
import pymongo
|
||||
|
||||
from settings import roles
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
CONNECTION_STRING = os.getenv("CONNECTION_STRING")
|
||||
|
@ -23,10 +25,9 @@ async def update_roles(users):
|
|||
async with session.get('http://localhost:50000/get_roles') as response:
|
||||
data = await response.json()
|
||||
except aiohttp.ClientError as e:
|
||||
print(f"Error: {e}")
|
||||
return
|
||||
|
||||
lvlroles = [f"lvl{lvl}" for lvl in range(10, 110, 10)] + ['']
|
||||
raise ValueError('Could not get roles') from exc
|
||||
|
||||
lvlroles = [f'lvl{lvl}' for lvl in range(10, 110, 10)] + ['']
|
||||
discord_users = data
|
||||
users = await autocredits.get_all_users(pymongo_client)
|
||||
|
||||
|
@ -41,11 +42,12 @@ async def update_roles(users):
|
|||
for role in lvlroles:
|
||||
if role in roles:
|
||||
bulk_updates.append(pymongo.UpdateOne({'auth.discord': int(discord)}, {'$set': {'role': role}}))
|
||||
print(f"Updated {id_} to {role}")
|
||||
print(f'Updated {id_} to {role}')
|
||||
break
|
||||
|
||||
if bulk_updates:
|
||||
with pymongo_client:
|
||||
users.bulk_write(bulk_updates)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
asyncio.run(main())
|
|
@ -23,7 +23,7 @@ MODEL = 'gpt-3.5-turbo'
|
|||
MESSAGES = [
|
||||
{
|
||||
'role': 'user',
|
||||
'content': '1+1=',
|
||||
'content': 'fuck you',
|
||||
}
|
||||
]
|
||||
|
||||
|
@ -66,15 +66,10 @@ def test_library():
|
|||
|
||||
completion = closedai.ChatCompletion.create(
|
||||
model=MODEL,
|
||||
messages=MESSAGES,
|
||||
stream=True
|
||||
messages=MESSAGES
|
||||
)
|
||||
|
||||
for event in completion:
|
||||
try:
|
||||
print(event['choices'][0]['delta']['content'])
|
||||
except:
|
||||
print('-')
|
||||
return completion['choices'][0]['message']['content']
|
||||
|
||||
def test_library_moderation():
|
||||
return closedai.Moderation.create("I wanna kill myself, I wanna kill myself; It's all I hear right now, it's all I hear right now")
|
||||
|
@ -83,8 +78,8 @@ def test_all():
|
|||
"""Runs all tests."""
|
||||
|
||||
# print(test_server())
|
||||
print(test_api())
|
||||
# print(test_library())
|
||||
# print(test_api())
|
||||
print(test_library())
|
||||
# print(test_library_moderation())
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue