Merge pull request #7 from RayBytes/main

General changes
This commit is contained in:
Game_Time 2023-08-12 21:25:23 +05:00 committed by GitHub
commit bea0cecdd3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 85 additions and 67 deletions

View file

@ -16,8 +16,12 @@ UA_SIMPLIFY = {
'AppleWebKit/537.36 (KHTML, like Gecko)': 'K', 'AppleWebKit/537.36 (KHTML, like Gecko)': 'K',
} }
async def _get_mongo(collection_name: str): ## MONGODB Setup
return AsyncIOMotorClient(os.getenv('MONGO_URI'))['nova-core'][collection_name]
conn = AsyncIOMotorClient(os.getenv('MONGO_URI'))
async def _get_collection(collection_name: str):
return conn['nova-core'][collection_name]
async def replacer(text: str, dict_: dict) -> str: async def replacer(text: str, dict_: dict) -> str:
for k, v in dict_.items(): for k, v in dict_.items():
@ -25,7 +29,7 @@ async def replacer(text: str, dict_: dict) -> str:
return text return text
async def log_api_request(user: dict, incoming_request, target_url: str): async def log_api_request(user: dict, incoming_request, target_url: str):
db = await _get_mongo('logs') db = await _get_collection('logs')
payload = {} payload = {}
try: try:
@ -58,19 +62,19 @@ async def log_api_request(user: dict, incoming_request, target_url: str):
return log_item return log_item
async def by_id(log_id: str): async def by_id(log_id: str):
db = await _get_mongo('logs') db = await _get_collection('logs')
return await db.find_one({'_id': log_id}) return await db.find_one({'_id': log_id})
async def by_user_id(user_id: str): async def by_user_id(user_id: str):
db = await _get_mongo('logs') db = await _get_collection('logs')
return await db.find({'user_id': user_id}) return await db.find({'user_id': user_id})
async def delete_by_id(log_id: str): async def delete_by_id(log_id: str):
db = await _get_mongo('logs') db = await _get_collection('logs')
return await db.delete_one({'_id': log_id}) return await db.delete_one({'_id': log_id})
async def delete_by_user_id(user_id: str): async def delete_by_user_id(user_id: str):
db = await _get_mongo('logs') db = await _get_collection('logs')
return await db.delete_many({'user_id': user_id}) return await db.delete_many({'user_id': user_id})
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -8,40 +8,46 @@ from motor.motor_asyncio import AsyncIOMotorClient
load_dotenv() load_dotenv()
async def _get_mongo(collection_name: str): ## MONGODB Setup
return AsyncIOMotorClient(os.getenv('MONGO_URI'))['nova-core'][collection_name]
conn = AsyncIOMotorClient(os.getenv('MONGO_URI'))
async def _get_collection(collection_name: str):
return conn['nova-core'][collection_name]
## Statistics
async def add_date(): async def add_date():
date = datetime.datetime.now(pytz.timezone('GMT')).strftime('%Y.%m.%d') date = datetime.datetime.now(pytz.timezone('GMT')).strftime('%Y.%m.%d')
year, month, day = date.split('.') year, month, day = date.split('.')
db = await _get_mongo('stats') db = await _get_collection('stats')
await db.update_one({}, {'$inc': {f'dates.{year}.{month}.{day}': 1}}, upsert=True) await db.update_one({}, {'$inc': {f'dates.{year}.{month}.{day}': 1}}, upsert=True)
async def add_ip_address(ip_address: str): async def add_ip_address(ip_address: str):
ip_address = ip_address.replace('.', '_') ip_address = ip_address.replace('.', '_')
db = await _get_mongo('stats') db = await _get_collection('stats')
await db.update_one({}, {'$inc': {f'ips.{ip_address}': 1}}, upsert=True) await db.update_one({}, {'$inc': {f'ips.{ip_address}': 1}}, upsert=True)
async def add_target(url: str): async def add_target(url: str):
db = await _get_mongo('stats') db = await _get_collection('stats')
await db.update_one({}, {'$inc': {f'targets.{url}': 1}}, upsert=True) await db.update_one({}, {'$inc': {f'targets.{url}': 1}}, upsert=True)
async def add_tokens(tokens: int, model: str): async def add_tokens(tokens: int, model: str):
db = await _get_mongo('stats') db = await _get_collection('stats')
await db.update_one({}, {'$inc': {f'tokens.{model}': tokens}}, upsert=True) await db.update_one({}, {'$inc': {f'tokens.{model}': tokens}}, upsert=True)
async def add_model(model: str): async def add_model(model: str):
db = await _get_mongo('stats') db = await _get_collection('stats')
await db.update_one({}, {'$inc': {f'models.{model}': 1}}, upsert=True) await db.update_one({}, {'$inc': {f'models.{model}': 1}}, upsert=True)
async def add_path(path: str): async def add_path(path: str):
path = path.replace('/', '_') path = path.replace('/', '_')
db = await _get_mongo('stats') db = await _get_collection('stats')
await db.update_one({}, {'$inc': {f'paths.{path}': 1}}, upsert=True) await db.update_one({}, {'$inc': {f'paths.{path}': 1}}, upsert=True)
async def get_value(obj_filter): async def get_value(obj_filter):
db = await _get_mongo('stats') db = await _get_collection('stats')
return await db.find_one({obj_filter}) return await db.find_one({obj_filter})
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -12,8 +12,12 @@ load_dotenv()
with open('config/credits.yml', encoding='utf8') as f: with open('config/credits.yml', encoding='utf8') as f:
credits_config = yaml.safe_load(f) credits_config = yaml.safe_load(f)
async def _get_mongo(collection_name: str): ## MONGODB Setup
return AsyncIOMotorClient(os.getenv('MONGO_URI'))['nova-core'][collection_name]
conn = AsyncIOMotorClient(os.getenv('MONGO_URI'))
async def _get_collection(collection_name: str):
return conn['nova-core'][collection_name]
async def create(discord_id: str='') -> dict: async def create(discord_id: str='') -> dict:
"""Adds a new user to the MongoDB collection.""" """Adds a new user to the MongoDB collection."""
@ -41,33 +45,33 @@ async def create(discord_id: str='') -> dict:
} }
} }
db = await _get_mongo('users') db = await _get_collection('users')
await db.insert_one(new_user) await db.insert_one(new_user)
user = await db.find_one({'api_key': new_api_key}) user = await db.find_one({'api_key': new_api_key})
return user return user
async def by_id(user_id: str): async def by_id(user_id: str):
db = await _get_mongo('users') db = await _get_collection('users')
return await db.find_one({'_id': user_id}) return await db.find_one({'_id': user_id})
async def by_discord_id(discord_id: str): async def by_discord_id(discord_id: str):
db = await _get_mongo('users') db = await _get_collection('users')
return await db.find_one({'auth.discord': str(int(discord_id))}) return await db.find_one({'auth.discord': str(int(discord_id))})
async def by_api_key(key: str): async def by_api_key(key: str):
db = await _get_mongo('users') db = await _get_collection('users')
return await db.find_one({'api_key': key}) return await db.find_one({'api_key': key})
async def update_by_id(user_id: str, update): async def update_by_id(user_id: str, update):
db = await _get_mongo('users') db = await _get_collection('users')
return await db.update_one({'_id': user_id}, update) return await db.update_one({'_id': user_id}, update)
async def update_by_filter(obj_filter, update): async def update_by_filter(obj_filter, update):
db = await _get_mongo('users') db = await _get_collection('users')
return await db.update_one(obj_filter, update) return await db.update_one(obj_filter, update)
async def delete(user_id: str): async def delete(user_id: str):
db = await _get_mongo('users') db = await _get_collection('users')
await db.delete_one({'_id': user_id}) await db.delete_one({'_id': user_id})
async def demo(): async def demo():

View file

@ -20,8 +20,16 @@ async def create_chat_id() -> str:
return f'chatcmpl-{chat_id}' return f'chatcmpl-{chat_id}'
async def create_chat_chunk(chat_id: str, model: str, content=None) -> dict: async def create_chat_chunk(chat_id: str, model: str, content=None) -> dict:
"""Creates a new chat chunk""" """Creates the chunk for streaming chat.
Args:
chat_id (str): _description_
model (str): _description_
content (_type_, optional): _description_. Defaults to None.
Returns:
dict: _description_
"""
content = content or {} content = content or {}
delta = {} delta = {}
@ -54,12 +62,3 @@ async def create_chat_chunk(chat_id: str, model: str, content=None) -> dict:
} }
return f'data: {json.dumps(chunk)}\n\n' return f'data: {json.dumps(chunk)}\n\n'
if __name__ == '__main__':
demo_chat_id = asyncio.run(create_chat_id())
print(demo_chat_id)
print(asyncio.run(create_chat_chunk(
model='gpt-4',
content='Hello',
chat_id=demo_chat_id,
)))

View file

@ -1,7 +1,18 @@
import tiktoken import tiktoken
async def count_for_messages(messages: list, model: str='gpt-3.5-turbo-0613') -> int: async def count_for_messages(messages: list, model: str='gpt-3.5-turbo-0613') -> int:
"""Return the number of tokens used by a list of messages.""" """Return the number of tokens used by a list of messages
Args:
messages (list): _description_
model (str, optional): _description_. Defaults to 'gpt-3.5-turbo-0613'.
Raises:
NotImplementedError: _description_
Returns:
int: _description_
"""
try: try:
encoding = tiktoken.encoding_for_model(model) encoding = tiktoken.encoding_for_model(model)

View file

@ -39,7 +39,6 @@ async def balance_organic_request(request: dict) -> dict:
"""Load balnace to non-chat completion request between other "organic" providers which respond in the desired format already. """Load balnace to non-chat completion request between other "organic" providers which respond in the desired format already.
Organic providers are used for non-chat completions, such as moderation and other paths. Organic providers are used for non-chat completions, such as moderation and other paths.
""" """
providers_available = [] providers_available = []
if not request.get('headers'): if not request.get('headers'):

View file

@ -2,12 +2,15 @@
import asyncio import asyncio
async def invalidate_key(provider_and_key: str) -> none: async def invalidate_key(provider_and_key: str) -> None:
"""Invalidates a key stored in the secret/ folder by storing it in the associated .invalid.txt file. """
The schmea in which <provider_and_key> should be passed is:
<provider_name><key>, e.g. Invalidates a key stored in the secret/ folder by storing it in the associated .invalid.txt file.
closed4>sk-... The schmea in which <provider_and_key> should be passed is:
""" <provider_name><key>, e.g.
closed4>sk-...
"""
if not provider_and_key: if not provider_and_key:
return return

View file

@ -96,12 +96,11 @@ async def stream(
webhook = dhooks.Webhook(os.getenv('DISCORD_WEBHOOK__API_ISSUE')) webhook = dhooks.Webhook(os.getenv('DISCORD_WEBHOOK__API_ISSUE'))
webhook.send(content=f'API Issue: **`{exc}`**\nhttps://i.imgflip.com/7uv122.jpg') webhook.send(content=f'API Issue: **`{exc}`**\nhttps://i.imgflip.com/7uv122.jpg')
error = await errors.yield_error( yield await errors.yield_error(
500, 500,
'Sorry, the API has no working keys anymore.', 'Sorry, the API has no working keys anymore.',
'The admins have been messaged automatically.' 'The admins have been messaged automatically.'
) )
yield error
return return
for k, v in target_request.get('headers', {}).items(): for k, v in target_request.get('headers', {}).items():
@ -180,12 +179,11 @@ async def stream(
except Exception as exc: except Exception as exc:
if 'Connection closed' in str(exc): if 'Connection closed' in str(exc):
error = await errors.yield_error( yield await errors.yield_error(
500, 500,
'Sorry, there was an issue with the connection.', 'Sorry, there was an issue with the connection.',
'Please first check if the issue on your end. If this error repeats, please don\'t heistate to contact the staff!.' 'Please first check if the issue on your end. If this error repeats, please don\'t heistate to contact the staff!.'
) )
yield error
return return
break break

View file

@ -26,8 +26,7 @@ async def handle(incoming_request):
# METHOD # METHOD
if incoming_request.method not in ['GET', 'POST', 'PUT', 'DELETE', 'PATCH']: if incoming_request.method not in ['GET', 'POST', 'PUT', 'DELETE', 'PATCH']:
error = await errors.error(405, f'Method "{incoming_request.method}" is not allowed.', 'Change the request method to the correct one.') return await errors.error(405, f'Method "{incoming_request.method}" is not allowed.', 'Change the request method to the correct one.')
return error
# PAYLOAD # PAYLOAD
try: try:
@ -35,42 +34,37 @@ async def handle(incoming_request):
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
payload = {} payload = {}
# TOKENS # Tokenise w/ tiktoken
try: try:
input_tokens = await tokens.count_for_messages(payload['messages']) input_tokens = await tokens.count_for_messages(payload['messages'])
except (KeyError, TypeError): except (KeyError, TypeError):
input_tokens = 0 input_tokens = 0
# AUTH # Check user auth
received_key = incoming_request.headers.get('Authorization') received_key = incoming_request.headers.get('Authorization')
if not received_key: if not received_key:
error = await errors.error(401, 'No NovaAI API key given!', 'Add "Authorization: Bearer nv-..." to your request headers.') return await errors.error(401, 'No NovaAI API key given!', 'Add "Authorization: Bearer nv-..." to your request headers.')
return error
if received_key.startswith('Bearer '): if received_key.startswith('Bearer '):
received_key = received_key.split('Bearer ')[1] received_key = received_key.split('Bearer ')[1]
# USER
user = await users.by_api_key(received_key.strip()) user = await users.by_api_key(received_key.strip())
if not user: if not user:
error = await errors.error(401, 'Invalid NovaAI API key!', 'Create a new NovaOSS API key.') return await errors.error(401, 'Invalid NovaAI API key!', 'Create a new NovaOSS API key.')
return error
ban_reason = user['status']['ban_reason'] ban_reason = user['status']['ban_reason']
if ban_reason: if ban_reason:
error = await errors.error(403, f'Your NovaAI account has been banned. Reason: "{ban_reason}".', 'Contact the staff for an appeal.') return await errors.error(403, f'Your NovaAI account has been banned. Reason: "{ban_reason}".', 'Contact the staff for an appeal.')
return error
if not user['status']['active']: if not user['status']['active']:
error = await errors.error(418, 'Your NovaAI account is not active (paused).', 'Simply re-activate your account using a Discord command or the web panel.') return await errors.error(418, 'Your NovaAI account is not active (paused).', 'Simply re-activate your account using a Discord command or the web panel.')
return error
if '/models' in path: if '/models' in path:
return fastapi.responses.JSONResponse(content=models_list) return fastapi.responses.JSONResponse(content=models_list)
# COST # Calculate cost of tokens & check for nsfw prompts
costs = credits_config['costs'] costs = credits_config['costs']
cost = costs['other'] cost = costs['other']
@ -94,17 +88,17 @@ async def handle(incoming_request):
policy_violation = await moderation.is_policy_violated(inp) policy_violation = await moderation.is_policy_violated(inp)
if policy_violation: if policy_violation:
error = await errors.error(400, f'The request contains content which violates this model\'s policies for "{policy_violation}".', 'We currently don\'t support any NSFW models.') return await errors.error(400, f'The request contains content which violates this model\'s policies for "{policy_violation}".', 'We currently don\'t support any NSFW models.')
return error
role_cost_multiplier = credits_config['bonuses'].get(user['role'], 1) role_cost_multiplier = credits_config['bonuses'].get(user['role'], 1)
cost = round(cost * role_cost_multiplier) cost = round(cost * role_cost_multiplier)
if user['credits'] < cost: if user['credits'] < cost:
error = await errors.error(429, 'Not enough credits.', 'Wait or earn more credits. Learn more on our website or Discord server.') return await errors.error(429, 'Not enough credits.', 'Wait or earn more credits. Learn more on our website or Discord server.')
return error
# READY
# Send the completion request
if 'chat/completions' in path and not payload.get('stream') is True: if 'chat/completions' in path and not payload.get('stream') is True:
payload['stream'] = False payload['stream'] = False