moderation is done yay

This commit is contained in:
nsde 2023-08-06 21:42:07 +02:00
parent 558ea89722
commit 83d57307cc
16 changed files with 250 additions and 133 deletions

View file

@ -32,7 +32,7 @@ async def get_users(discord_id: int, incoming_request: fastapi.Request):
return user return user
def new_user_webhook(user: dict) -> None: async def new_user_webhook(user: dict) -> None:
dhook = Webhook(os.getenv('DISCORD_WEBHOOK__USER_CREATED')) dhook = Webhook(os.getenv('DISCORD_WEBHOOK__USER_CREATED'))
embed = Embed( embed = Embed(
@ -40,7 +40,7 @@ def new_user_webhook(user: dict) -> None:
color=0x90ee90, color=0x90ee90,
) )
embed.add_field(name='ID', value=user['_id'], inline=False) embed.add_field(name='ID', value=str(user['_id']), inline=False)
embed.add_field(name='Discord', value=user['auth']['discord']) embed.add_field(name='Discord', value=user['auth']['discord'])
embed.add_field(name='Github', value=user['auth']['github']) embed.add_field(name='Github', value=user['auth']['github'])
@ -60,15 +60,17 @@ async def create_user(incoming_request: fastapi.Request):
return fastapi.Response(status_code=400, content='Invalid or no payload received.') return fastapi.Response(status_code=400, content='Invalid or no payload received.')
user = await users.create(discord_id) user = await users.create(discord_id)
new_user_webhook(user) await new_user_webhook(user)
return user return user
if __name__ == '__main__': if __name__ == '__main__':
new_user_webhook({ # new_user_webhook({
'_id': 'JUST_A_TEST_IGNORE_ME', # '_id': 'JUST_A_TEST_IGNORE_ME',
'auth': { # 'auth': {
'discord': 123, # 'discord': 123,
'github': 'abc' # 'github': 'abc'
} # }
}) # })
pass

View file

@ -8,10 +8,24 @@ from helpers import network
load_dotenv() load_dotenv()
def _get_mongo(collection_name: str): UA_SIMPLIFY = {
'Windows NT': 'W',
'Mozilla/5.0': 'M',
'Win64; x64': '64',
'Safari/537.36': 'S',
'AppleWebKit/537.36 (KHTML, like Gecko)': 'K',
}
async def _get_mongo(collection_name: str):
return AsyncIOMotorClient(os.getenv('MONGO_URI'))['nova-core'][collection_name] return AsyncIOMotorClient(os.getenv('MONGO_URI'))['nova-core'][collection_name]
async def replacer(text: str, dict_: dict) -> str:
for k, v in dict_.items():
text = text.replace(k, v)
return text
async def log_api_request(user: dict, incoming_request, target_url: str): async def log_api_request(user: dict, incoming_request, target_url: str):
db = await _get_mongo('logs')
payload = {} payload = {}
try: try:
@ -22,19 +36,23 @@ async def log_api_request(user: dict, incoming_request, target_url: str):
last_prompt = None last_prompt = None
if 'messages' in payload: if 'messages' in payload:
last_prompt = payload['messages'][-1]['content'] last_prompt = payload['messages'][-1]['content'][:50]
if len(last_prompt) == 50:
last_prompt += '...'
model = payload.get('model') model = payload.get('model')
ip_address = await network.get_ip(incoming_request) ip_address = await network.get_ip(incoming_request)
useragent = await replacer(incoming_request.headers.get('User-Agent'), UA_SIMPLIFY)
new_log_item = { new_log_item = {
'timestamp': time.time(), 'timestamp': time.time(),
'method': incoming_request.method, 'method': incoming_request.method,
'path': incoming_request.url.path, 'path': incoming_request.url.path,
'user_id': user['_id'], 'user_id': str(user['_id']),
'security': { 'security': {
'ip': ip_address, 'ip': ip_address,
'useragent': incoming_request.headers.get('User-Agent') 'useragent': useragent,
}, },
'details': { 'details': {
'model': model, 'model': model,
@ -43,21 +61,25 @@ async def log_api_request(user: dict, incoming_request, target_url: str):
} }
} }
inserted = await _get_mongo('logs').insert_one(new_log_item) inserted = await db.insert_one(new_log_item)
log_item = await _get_mongo('logs').find_one({'_id': inserted.inserted_id}) log_item = await db.find_one({'_id': inserted.inserted_id})
return log_item return log_item
async def by_id(log_id: str): async def by_id(log_id: str):
return await _get_mongo('logs').find_one({'_id': log_id}) db = await _get_mongo('logs')
return await db.find_one({'_id': log_id})
async def by_user_id(user_id: str): async def by_user_id(user_id: str):
return await _get_mongo('logs').find({'user_id': user_id}) db = await _get_mongo('logs')
return await db.find({'user_id': user_id})
async def delete_by_id(log_id: str): async def delete_by_id(log_id: str):
return await _get_mongo('logs').delete_one({'_id': log_id}) db = await _get_mongo('logs')
return await db.delete_one({'_id': log_id})
async def delete_by_user_id(user_id: str): async def delete_by_user_id(user_id: str):
return await _get_mongo('logs').delete_many({'user_id': user_id}) db = await _get_mongo('logs')
return await db.delete_many({'user_id': user_id})
if __name__ == '__main__': if __name__ == '__main__':
pass pass

View file

@ -8,34 +8,41 @@ from motor.motor_asyncio import AsyncIOMotorClient
load_dotenv() load_dotenv()
def _get_mongo(collection_name: str): async def _get_mongo(collection_name: str):
return AsyncIOMotorClient(os.getenv('MONGO_URI'))['nova-core'][collection_name] return AsyncIOMotorClient(os.getenv('MONGO_URI'))['nova-core'][collection_name]
async def add_date(): async def add_date():
date = datetime.datetime.now(pytz.timezone('GMT')).strftime('%Y.%m.%d') date = datetime.datetime.now(pytz.timezone('GMT')).strftime('%Y.%m.%d')
year, month, day = date.split('.') year, month, day = date.split('.')
await _get_mongo('stats').update_one({}, {'$inc': {f'dates.{year}.{month}.{day}': 1}}, upsert=True) db = await _get_mongo('stats')
await db.update_one({}, {'$inc': {f'dates.{year}.{month}.{day}': 1}}, upsert=True)
async def add_ip_address(ip_address: str): async def add_ip_address(ip_address: str):
ip_address = ip_address.replace('.', '_') ip_address = ip_address.replace('.', '_')
await _get_mongo('stats').update_one({}, {'$inc': {f'ips.{ip_address}': 1}}, upsert=True) db = await _get_mongo('stats')
await db.update_one({}, {'$inc': {f'ips.{ip_address}': 1}}, upsert=True)
async def add_target(url: str): async def add_target(url: str):
await _get_mongo('stats').update_one({}, {'$inc': {f'targets.{url}': 1}}, upsert=True) db = await _get_mongo('stats')
await db.update_one({}, {'$inc': {f'targets.{url}': 1}}, upsert=True)
async def add_tokens(tokens: int, model: str): async def add_tokens(tokens: int, model: str):
await _get_mongo('stats').update_one({}, {'$inc': {f'tokens.{model}': tokens}}, upsert=True) db = await _get_mongo('stats')
await db.update_one({}, {'$inc': {f'tokens.{model}': tokens}}, upsert=True)
async def add_model(model: str): async def add_model(model: str):
await _get_mongo('stats').update_one({}, {'$inc': {f'models.{model}': 1}}, upsert=True) db = await _get_mongo('stats')
await db.update_one({}, {'$inc': {f'models.{model}': 1}}, upsert=True)
async def add_path(path: str): async def add_path(path: str):
path = path.replace('/', '_') path = path.replace('/', '_')
await _get_mongo('stats').update_one({}, {'$inc': {f'paths.{path}': 1}}, upsert=True) db = await _get_mongo('stats')
await db.update_one({}, {'$inc': {f'paths.{path}': 1}}, upsert=True)
async def get_value(obj_filter): async def get_value(obj_filter):
return await _get_mongo('stats').find_one({obj_filter}) db = await _get_mongo('stats')
return await db.find_one({obj_filter})
if __name__ == '__main__': if __name__ == '__main__':
asyncio.run(add_date()) asyncio.run(add_date())

View file

@ -12,7 +12,7 @@ load_dotenv()
with open('config/credits.yml', encoding='utf8') as f: with open('config/credits.yml', encoding='utf8') as f:
credits_config = yaml.safe_load(f) credits_config = yaml.safe_load(f)
def _get_mongo(collection_name: str): async def _get_mongo(collection_name: str):
return AsyncIOMotorClient(os.getenv('MONGO_URI'))['nova-core'][collection_name] return AsyncIOMotorClient(os.getenv('MONGO_URI'))['nova-core'][collection_name]
async def create(discord_id: int=0) -> dict: async def create(discord_id: int=0) -> dict:
@ -46,23 +46,28 @@ async def create(discord_id: int=0) -> dict:
return user return user
async def by_id(user_id: str): async def by_id(user_id: str):
return await _get_mongo('users').find_one({'_id': user_id}) db = await _get_mongo('users')
return await db.find_one({'_id': user_id})
async def by_discord_id(discord_id: str): async def by_discord_id(discord_id: str):
return await _get_mongo('users').find_one({'auth.discord': discord_id}) db = await _get_mongo('users')
return await db.find_one({'auth.discord': discord_id})
async def by_api_key(key: str): async def by_api_key(key: str):
return await _get_mongo('users').find_one({'api_key': key}) db = await _get_mongo('users')
return await db.find_one({'api_key': key})
async def update_by_id(user_id: str, update): async def update_by_id(user_id: str, update):
return await _get_mongo('users').update_one({'_id': user_id}, update) db = await _get_mongo('users')
return await db.update_one({'_id': user_id}, update)
async def update_by_filter(obj_filter, update): async def update_by_filter(obj_filter, update):
return await _get_mongo('users').update_one(obj_filter, update) db = await _get_mongo('users')
return await db.update_one(obj_filter, update)
async def delete(user_id: str): async def delete(user_id: str):
await _get_mongo('users').delete_one({'_id': user_id}) db = await _get_mongo('users')
await db.delete_one({'_id': user_id})
async def demo(): async def demo():
user = await create(69420) user = await create(69420)

View file

@ -17,7 +17,7 @@ async def create_chat_id() -> str:
return f'chatcmpl-{chat_id}' return f'chatcmpl-{chat_id}'
def create_chat_chunk(chat_id: str, model: str, content=None) -> dict: async def create_chat_chunk(chat_id: str, model: str, content=None) -> dict:
content = content or {} content = content or {}
delta = {} delta = {}

View file

@ -1,7 +1,7 @@
import json import json
import starlette import starlette
def error(code: int, message: str, tip: str) -> starlette.responses.Response: async def error(code: int, message: str, tip: str) -> starlette.responses.Response:
info = {'error': { info = {'error': {
'code': code, 'code': code,
'message': message, 'message': message,
@ -12,7 +12,7 @@ def error(code: int, message: str, tip: str) -> starlette.responses.Response:
return starlette.responses.Response(status_code=code, content=json.dumps(info)) return starlette.responses.Response(status_code=code, content=json.dumps(info))
def yield_error(code: int, message: str, tip: str) -> str: async def yield_error(code: int, message: str, tip: str) -> str:
return json.dumps({ return json.dumps({
'code': code, 'code': code,
'message': message, 'message': message,

View file

@ -1,6 +1,6 @@
import tiktoken import tiktoken
def count_for_messages(messages: list, model: str='gpt-3.5-turbo-0613') -> int: async def count_for_messages(messages: list, model: str='gpt-3.5-turbo-0613') -> int:
"""Return the number of tokens used by a list of messages.""" """Return the number of tokens used by a list of messages."""
try: try:

View file

@ -11,13 +11,15 @@ provider_modules = [
providers.closed4 providers.closed4
] ]
def _get_module_name(module) -> str: async def _get_module_name(module) -> str:
name = module.__name__ name = module.__name__
if '.' in name: if '.' in name:
return name.split('.')[-1] return name.split('.')[-1]
return name return name
async def balance_chat_request(payload: dict) -> dict: async def balance_chat_request(payload: dict) -> dict:
"""Load balance the chat completion request between chat providers."""
providers_available = [] providers_available = []
for provider_module in provider_modules: for provider_module in provider_modules:
@ -34,20 +36,37 @@ async def balance_chat_request(payload: dict) -> dict:
provider = random.choice(providers_available) provider = random.choice(providers_available)
target = provider.chat_completion(**payload) target = provider.chat_completion(**payload)
target['module'] = _get_module_name(provider)
module_name = await _get_module_name(provider)
target['module'] = module_name
return target return target
async def balance_organic_request(request: dict) -> dict: async def balance_organic_request(request: dict) -> dict:
"""Load balnace to non-chat completion request between other "organic" providers which respond in the desired format already."""
providers_available = [] providers_available = []
if not request.get('headers'):
request['headers'] = {
'Content-Type': 'application/json'
}
for provider_module in provider_modules: for provider_module in provider_modules:
if provider_module.ORGANIC: if not provider_module.ORGANIC:
continue
if '/moderations' in request['path']:
if not provider_module.MODERATIONS:
continue
providers_available.append(provider_module) providers_available.append(provider_module)
provider = random.choice(providers_available) provider = random.choice(providers_available)
target = provider.organify(request) target = provider.organify(request)
target['module'] = _get_module_name(provider)
module_name = await _get_module_name(provider)
target['module'] = module_name
return target return target

View file

@ -1,18 +1,49 @@
import os
import asyncio import asyncio
import openai as closedai
from typing import Union import aiohttp
from dotenv import load_dotenv import proxies
import load_balancing
load_dotenv() async def is_safe(inp) -> bool:
text = inp
closedai.api_key = os.getenv('LEGIT_CLOSEDAI_KEY') if isinstance(inp, list):
text = ''
if isinstance(inp[0], dict):
for msg in inp:
text += msg['content'] + '\n'
async def is_safe(text: Union[str, list]) -> bool: else:
return closedai.Moderation.create( text = '\n'.join(inp)
input=text,
)['results'][0]['flagged'] for _ in range(3):
req = await load_balancing.balance_organic_request(
{
'path': '/v1/moderations',
'payload': {'input': text}
}
)
async with aiohttp.ClientSession(connector=proxies.default_proxy.connector) as session:
try:
async with session.request(
method=req.get('method', 'POST'),
url=req['url'],
data=req.get('data'),
json=req.get('payload'),
headers=req.get('headers'),
cookies=req.get('cookies'),
ssl=False,
timeout=aiohttp.ClientTimeout(total=5),
) as res:
res.raise_for_status()
json_response = await res.json()
return not json_response['results'][0]['flagged']
except Exception as exc:
print('[!] moderation error:', type(exc), exc)
continue
if __name__ == '__main__': if __name__ == '__main__':
asyncio.run(is_safe('Hello')) print(asyncio.run(is_safe('I wanna kill myself')))

View file

@ -49,19 +49,26 @@ async def stream(
if is_chat and is_stream: if is_chat and is_stream:
chat_id = await chat.create_chat_id() chat_id = await chat.create_chat_id()
yield chat.create_chat_chunk( chunk = await chat.create_chat_chunk(
chat_id=chat_id, chat_id=chat_id,
model=model, model=model,
content=chat.CompletionStart content=chat.CompletionStart
) )
yield chunk
yield chat.create_chat_chunk( chunk = await chat.create_chat_chunk(
chat_id=chat_id, chat_id=chat_id,
model=model, model=model,
content=None content=None
) )
for _ in range(3): yield chunk
json_response = {
'error': 'No JSON response could be received'
}
for _ in range(5):
headers = { headers = {
'Content-Type': 'application/json' 'Content-Type': 'application/json'
} }
@ -81,18 +88,17 @@ async def stream(
webhook = dhooks.Webhook(os.getenv('DISCORD_WEBHOOK__API_ISSUE')) webhook = dhooks.Webhook(os.getenv('DISCORD_WEBHOOK__API_ISSUE'))
webhook.send(content=f'API Issue: **`{exc}`**\nhttps://i.imgflip.com/7uv122.jpg') webhook.send(content=f'API Issue: **`{exc}`**\nhttps://i.imgflip.com/7uv122.jpg')
yield errors.yield_error( error = errors.yield_error(
500, 500,
'Sorry, the API has no working keys anymore.', 'Sorry, the API has no working keys anymore.',
'The admins have been messaged automatically.' 'The admins have been messaged automatically.'
) )
yield error
return return
for k, v in target_request.get('headers', {}).items(): for k, v in target_request.get('headers', {}).items():
headers[k] = v headers[k] = v
json.dump(target_request, open('api.log.json', 'w'), indent=4)
async with aiohttp.ClientSession(connector=proxies.default_proxy.connector) as session: async with aiohttp.ClientSession(connector=proxies.default_proxy.connector) as session:
try: try:
async with session.request( async with session.request(
@ -109,15 +115,79 @@ async def stream(
timeout=aiohttp.ClientTimeout(total=float(os.getenv('TRANSFER_TIMEOUT', '120'))), timeout=aiohttp.ClientTimeout(total=float(os.getenv('TRANSFER_TIMEOUT', '120'))),
) as response: ) as response:
print(5)
if not is_stream:
json_response = await response.json()
try: try:
response.raise_for_status() response.raise_for_status()
except Exception as exc: except Exception as exc:
if 'Too Many Requests' in str(exc): if 'Too Many Requests' in str(exc):
print(429)
continue continue
if is_stream:
try:
async for chunk in response.content.iter_any():
send = False
chunk = f'{chunk.decode("utf8")}\n\n'
chunk = chunk.replace(os.getenv('MAGIC_WORD', 'novaOSScheckKeyword'), payload['model'])
chunk = chunk.replace(os.getenv('MAGIC_USER_WORD', 'novaOSSuserKeyword'), str(user['_id']))
if not chunk.strip():
send = False
if is_chat and '{' in chunk:
data = json.loads(chunk.split('data: ')[1])
chunk = chunk.replace(data['id'], chat_id)
send = True
if target_request['module'] == 'twa' and data.get('text'):
chunk = await chat.create_chat_chunk(
chat_id=chat_id,
model=model,
content=['text']
)
if not data['choices'][0]['delta']:
send = False
if data['choices'][0]['delta'] == {'role': 'assistant'}:
send = False
if send:
final_chunk = chunk.strip().replace('data: [DONE]', '') + '\n\n'
yield final_chunk
except Exception as exc:
if 'Connection closed' in str(exc):
error = errors.yield_error(
500,
'Sorry, there was an issue with the connection.',
'Please first check if the issue on your end. If this error repeats, please don\'t heistate to contact the staff!.'
)
yield error
return
break
except ProxyError as exc:
print('proxy error')
continue
if is_chat and is_stream:
chunk = await chat.create_chat_chunk(
chat_id=chat_id,
model=model,
content=chat.CompletionStop
)
yield chunk
yield 'data: [DONE]\n\n'
if not is_stream:
yield json.dumps(json_response)
# DONE =========================================================
if user and incoming_request: if user and incoming_request:
await logs.log_api_request( await logs.log_api_request(
user=user, user=user,
@ -130,46 +200,6 @@ async def stream(
'$inc': {'credits': -credits_cost} '$inc': {'credits': -credits_cost}
}) })
print(6)
if is_stream:
try:
async for chunk in response.content.iter_any():
send = False
chunk = f'{chunk.decode("utf8")}\n\n'
chunk = chunk.replace(os.getenv('MAGIC_WORD', 'novaOSScheckKeyword'), payload['model'])
# chunk = chunk.replace(os.getenv('MAGIC_USER_WORD', 'novaOSSuserKeyword'), user['_id'])
print(chunk)
if not chunk.strip():
send = False
if is_chat and '{' in chunk:
data = json.loads(chunk.split('data: ')[1])
send = True
if target_request['module'] == 'twa' and data.get('text'):
chunk = chat.create_chat_chunk(
chat_id=chat_id,
model=model,
content=['text']
)
if not data['choices'][0]['delta']:
send = False
if data['choices'][0]['delta'] == {'role': 'assistant'}:
send = False
if send:
yield chunk
except Exception as exc:
if 'Connection closed' in str(exc):
print('connection closed: ', exc)
continue
if not demo_mode:
ip_address = await network.get_ip(incoming_request) ip_address = await network.get_ip(incoming_request)
await stats.add_date() await stats.add_date()
@ -181,27 +211,5 @@ async def stream(
await stats.add_model(model) await stats.add_model(model)
await stats.add_tokens(input_tokens, model) await stats.add_tokens(input_tokens, model)
break
except ProxyError as exc:
print('proxy error')
continue
print(3)
if is_chat and is_stream:
chat_chunk = chat.create_chat_chunk(
chat_id=chat_id,
model=model,
content=chat.CompletionStop
)
data = json.dumps(chat_chunk)
yield 'data: [DONE]\n\n'
if not is_stream:
json_response = await response.json()
yield json_response.encode('utf8')
if __name__ == '__main__': if __name__ == '__main__':
asyncio.run(stream()) asyncio.run(stream())

View file

@ -4,11 +4,13 @@ import os
import json import json
import yaml import yaml
import logging import logging
import fastapi
import starlette import starlette
from dotenv import load_dotenv from dotenv import load_dotenv
import streaming import streaming
import moderation
from db import logs, users from db import logs, users
from helpers import tokens, errors, exceptions from helpers import tokens, errors, exceptions
@ -32,7 +34,8 @@ async def handle(incoming_request):
# METHOD # METHOD
if incoming_request.method not in ['GET', 'POST', 'PUT', 'DELETE', 'PATCH']: if incoming_request.method not in ['GET', 'POST', 'PUT', 'DELETE', 'PATCH']:
return errors.error(405, f'Method "{incoming_request.method}" is not allowed.', 'Change the request method to the correct one.') error = await errors.error(405, f'Method "{incoming_request.method}" is not allowed.', 'Change the request method to the correct one.')
return error
# PAYLOAD # PAYLOAD
try: try:
@ -42,7 +45,7 @@ async def handle(incoming_request):
# TOKENS # TOKENS
try: try:
input_tokens = tokens.count_for_messages(payload['messages']) input_tokens = await tokens.count_for_messages(payload['messages'])
except (KeyError, TypeError): except (KeyError, TypeError):
input_tokens = 0 input_tokens = 0
@ -50,7 +53,8 @@ async def handle(incoming_request):
received_key = incoming_request.headers.get('Authorization') received_key = incoming_request.headers.get('Authorization')
if not received_key: if not received_key:
return errors.error(401, 'No NovaAI API key given!', 'Add "Authorization: Bearer nv-..." to your request headers.') error = await errors.error(401, 'No NovaAI API key given!', 'Add "Authorization: Bearer nv-..." to your request headers.')
return error
if received_key.startswith('Bearer '): if received_key.startswith('Bearer '):
received_key = received_key.split('Bearer ')[1] received_key = received_key.split('Bearer ')[1]
@ -59,38 +63,60 @@ async def handle(incoming_request):
user = await users.by_api_key(received_key.strip()) user = await users.by_api_key(received_key.strip())
if not user: if not user:
return errors.error(401, 'Invalid NovaAI API key!', 'Create a new NovaOSS API key.') error = await errors.error(401, 'Invalid NovaAI API key!', 'Create a new NovaOSS API key.')
return error
ban_reason = user['status']['ban_reason'] ban_reason = user['status']['ban_reason']
if ban_reason: if ban_reason:
return errors.error(403, f'Your NovaAI account has been banned. Reason: "{ban_reason}".', 'Contact the staff for an appeal.') error = await errors.error(403, f'Your NovaAI account has been banned. Reason: "{ban_reason}".', 'Contact the staff for an appeal.')
return error
if not user['status']['active']: if not user['status']['active']:
return errors.error(418, 'Your NovaAI account is not active (paused).', 'Simply re-activate your account using a Discord command or the web panel.') error = await errors.error(418, 'Your NovaAI account is not active (paused).', 'Simply re-activate your account using a Discord command or the web panel.')
return error
# COST # COST
costs = credits_config['costs'] costs = credits_config['costs']
cost = costs['other'] cost = costs['other']
is_safe = True
if 'chat/completions' in path: if 'chat/completions' in path:
for model_name, model_cost in costs['chat-models'].items(): for model_name, model_cost in costs['chat-models'].items():
if model_name in payload['model']: if model_name in payload['model']:
cost = model_cost cost = model_cost
is_safe = await moderation.is_safe(payload['messages'])
else:
inp = payload.get('input', payload.get('prompt'))
if inp and not '/moderations' in path:
is_safe = await moderation.is_safe(inp)
if not is_safe:
error = await errors.error(400, 'The request contains content which violates this model\'s policies.', 'We currently don\'t support any NSFW models.')
return error
return
role_cost_multiplier = credits_config['bonuses'].get(user['role'], 1) role_cost_multiplier = credits_config['bonuses'].get(user['role'], 1)
cost = round(cost * role_cost_multiplier) cost = round(cost * role_cost_multiplier)
if user['credits'] < cost: if user['credits'] < cost:
return errors.error(429, 'Not enough credits.', 'Wait or earn more credits. Learn more on our website or Discord server.') error = await errors.error(429, 'Not enough credits.', 'Wait or earn more credits. Learn more on our website or Discord server.')
return error
# READY # READY
payload['user'] = str(user['_id']) # payload['user'] = str(user['_id'])
if 'chat/completions' in path and not payload.get('stream') is True: if 'chat/completions' in path and not payload.get('stream') is True:
payload['stream'] = False payload['stream'] = False
return starlette.responses.StreamingResponse( media_type = 'text/event-stream' if payload.get('stream', False) else 'application/json'
return fastapi.responses.StreamingResponse(
content=streaming.stream( content=streaming.stream(
user=user, user=user,
path=path, path=path,
@ -99,5 +125,5 @@ async def handle(incoming_request):
input_tokens=input_tokens, input_tokens=input_tokens,
incoming_request=incoming_request, incoming_request=incoming_request,
), ),
media_type='text/event-stream' if payload.get('stream', False) else 'application/json' media_type=media_type
) )

View file

@ -2,11 +2,11 @@ async def get_all_users(client):
users = client['nova-core']['users'] users = client['nova-core']['users']
return users return users
async def update_credits(users, settings = None): async def update_credits(users, settings=None):
if not settings: if not settings:
users.update_many({}, {"$inc": {"credits": 250}}) users.update_many({}, {"$inc": {"credits": 250}})
else: else:
for key, value in settings.items(): for key, value in settings.items():
users.update_many({'role': key}, {"$inc": {"credits": int(value)}}) users.update_many({'role': key}, {"$inc": {"credits": int(value)}})
print(f"Updated {key} to {value}") print(f'Updated {key} to {value}')

View file

@ -1,11 +1,13 @@
import asyncio import asyncio
from settings import roles
import autocredits import autocredits
import aiohttp import aiohttp
from dotenv import load_dotenv
import os import os
import pymongo import pymongo
from settings import roles
from dotenv import load_dotenv
load_dotenv() load_dotenv()
CONNECTION_STRING = os.getenv("CONNECTION_STRING") CONNECTION_STRING = os.getenv("CONNECTION_STRING")
@ -23,10 +25,9 @@ async def update_roles(users):
async with session.get('http://localhost:50000/get_roles') as response: async with session.get('http://localhost:50000/get_roles') as response:
data = await response.json() data = await response.json()
except aiohttp.ClientError as e: except aiohttp.ClientError as e:
print(f"Error: {e}") raise ValueError('Could not get roles') from exc
return
lvlroles = [f"lvl{lvl}" for lvl in range(10, 110, 10)] + [''] lvlroles = [f'lvl{lvl}' for lvl in range(10, 110, 10)] + ['']
discord_users = data discord_users = data
users = await autocredits.get_all_users(pymongo_client) users = await autocredits.get_all_users(pymongo_client)
@ -41,8 +42,9 @@ async def update_roles(users):
for role in lvlroles: for role in lvlroles:
if role in roles: if role in roles:
bulk_updates.append(pymongo.UpdateOne({'auth.discord': int(discord)}, {'$set': {'role': role}})) bulk_updates.append(pymongo.UpdateOne({'auth.discord': int(discord)}, {'$set': {'role': role}}))
print(f"Updated {id_} to {role}") print(f'Updated {id_} to {role}')
break break
if bulk_updates: if bulk_updates:
with pymongo_client: with pymongo_client:
users.bulk_write(bulk_updates) users.bulk_write(bulk_updates)

View file

@ -23,7 +23,7 @@ MODEL = 'gpt-3.5-turbo'
MESSAGES = [ MESSAGES = [
{ {
'role': 'user', 'role': 'user',
'content': '1+1=', 'content': 'fuck you',
} }
] ]
@ -66,15 +66,10 @@ def test_library():
completion = closedai.ChatCompletion.create( completion = closedai.ChatCompletion.create(
model=MODEL, model=MODEL,
messages=MESSAGES, messages=MESSAGES
stream=True
) )
for event in completion: return completion['choices'][0]['message']['content']
try:
print(event['choices'][0]['delta']['content'])
except:
print('-')
def test_library_moderation(): def test_library_moderation():
return closedai.Moderation.create("I wanna kill myself, I wanna kill myself; It's all I hear right now, it's all I hear right now") return closedai.Moderation.create("I wanna kill myself, I wanna kill myself; It's all I hear right now, it's all I hear right now")
@ -83,8 +78,8 @@ def test_all():
"""Runs all tests.""" """Runs all tests."""
# print(test_server()) # print(test_server())
print(test_api()) # print(test_api())
# print(test_library()) print(test_library())
# print(test_library_moderation()) # print(test_library_moderation())