Compare commits

...

2 commits

Author SHA1 Message Date
nsde ce24c3a5a2 Fixed some several issues with moderation, models etc. 2023-08-08 01:04:35 +02:00
nsde 1a3e275a1c Added auto-rewards and invalid key system 2023-08-07 23:28:24 +02:00
10 changed files with 108 additions and 88 deletions

View file

@ -3,14 +3,6 @@ import asyncio
import providers import providers
provider_modules = [
# providers.twa,
# providers.quantum,
providers.churchless,
providers.closed,
providers.closed4
]
async def _get_module_name(module) -> str: async def _get_module_name(module) -> str:
name = module.__name__ name = module.__name__
if '.' in name: if '.' in name:
@ -22,7 +14,7 @@ async def balance_chat_request(payload: dict) -> dict:
providers_available = [] providers_available = []
for provider_module in provider_modules: for provider_module in providers.MODULES:
if payload['stream'] and not provider_module.STREAMING: if payload['stream'] and not provider_module.STREAMING:
continue continue
@ -32,7 +24,7 @@ async def balance_chat_request(payload: dict) -> dict:
providers_available.append(provider_module) providers_available.append(provider_module)
if not providers_available: if not providers_available:
raise NotImplementedError('This model does not exist.') raise NotImplementedError(f'The model "{payload["model"]}" is not available. MODEl_UNAVAILABLE')
provider = random.choice(providers_available) provider = random.choice(providers_available)
target = provider.chat_completion(**payload) target = provider.chat_completion(**payload)
@ -52,7 +44,7 @@ async def balance_organic_request(request: dict) -> dict:
'Content-Type': 'application/json' 'Content-Type': 'application/json'
} }
for provider_module in provider_modules: for provider_module in providers.MODULES:
if not provider_module.ORGANIC: if not provider_module.ORGANIC:
continue continue

View file

@ -1,10 +1,11 @@
import asyncio import asyncio
import aiohttp import aiohttp
import proxies import proxies
import provider_auth
import load_balancing import load_balancing
async def is_safe(inp) -> bool: async def is_policy_violated(inp) -> bool:
text = inp text = inp
if isinstance(inp, list): if isinstance(inp, list):
@ -34,16 +35,21 @@ async def is_safe(inp) -> bool:
headers=req.get('headers'), headers=req.get('headers'),
cookies=req.get('cookies'), cookies=req.get('cookies'),
ssl=False, ssl=False,
timeout=aiohttp.ClientTimeout(total=5), timeout=aiohttp.ClientTimeout(total=2),
) as res: ) as res:
res.raise_for_status() res.raise_for_status()
json_response = await res.json() json_response = await res.json()
categories = json_response['results'][0]['category_scores']
if json_response['results'][0]['flagged']:
return max(categories, key=categories.get)
return False
return not json_response['results'][0]['flagged']
except Exception as exc: except Exception as exc:
# await provider_auth.invalidate_key(req.get('provider_auth'))
print('[!] moderation error:', type(exc), exc) print('[!] moderation error:', type(exc), exc)
continue continue
if __name__ == '__main__': if __name__ == '__main__':
print(asyncio.run(is_safe('I wanna kill myself'))) print(asyncio.run(is_policy_violated('I wanna kill myself')))

21
api/provider_auth.py Normal file
View file

@ -0,0 +1,21 @@
import asyncio
async def invalidate_key(provider_and_key):
if not provider_and_key:
return
provider = provider_and_key.split('>')[0]
provider_file = f'secret/{provider}.txt'
key = provider_and_key.split('>')[1]
with open(provider_file, encoding='utf8') as f_in:
text = f_in.read()
with open(provider_file, 'w', encoding='utf8') as f_out:
f_out.write(text.replace(key, ''))
with open(f'secret/{provider}.invalid.txt', 'a', encoding='utf8') as f:
f.write(key + '\n')
if __name__ == '__main__':
asyncio.run(invalidate_key('closed>sk-...'))

View file

@ -1,36 +1,26 @@
"""Module for transferring requests to ClosedAI API""" """Module for transferring requests to ClosedAI API"""
import os
import json import json
import yaml import yaml
import logging
import fastapi import fastapi
import starlette
from dotenv import load_dotenv from dotenv import load_dotenv
import streaming import streaming
import moderation import moderation
from db import logs, users from db import users
from helpers import tokens, errors, exceptions from helpers import tokens, errors
load_dotenv() load_dotenv()
# log to "api.log" file
logging.basicConfig(
filename='api.log',
level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(name)s %(message)s'
)
with open('config/credits.yml', encoding='utf8') as f: with open('config/credits.yml', encoding='utf8') as f:
credits_config = yaml.safe_load(f) credits_config = yaml.safe_load(f)
async def handle(incoming_request): async def handle(incoming_request):
"""Transfer a streaming response from the incoming request to the target endpoint""" """Transfer a streaming response from the incoming request to the target endpoint"""
path = incoming_request.url.path path = incoming_request.url.path.replace('v1/v1/', 'v1/')
# METHOD # METHOD
if incoming_request.method not in ['GET', 'POST', 'PUT', 'DELETE', 'PATCH']: if incoming_request.method not in ['GET', 'POST', 'PUT', 'DELETE', 'PATCH']:
@ -79,23 +69,27 @@ async def handle(incoming_request):
costs = credits_config['costs'] costs = credits_config['costs']
cost = costs['other'] cost = costs['other']
is_safe = True policy_violation = False
if 'chat/completions' in path: if 'chat/completions' in path:
for model_name, model_cost in costs['chat-models'].items(): for model_name, model_cost in costs['chat-models'].items():
if model_name in payload['model']: if model_name in payload['model']:
cost = model_cost cost = model_cost
is_safe = await moderation.is_safe(payload['messages']) policy_violation = await moderation.is_policy_violated(payload['messages'])
elif '/moderations' in path:
pass
else: else:
inp = payload.get('input', payload.get('prompt')) inp = payload.get('input', payload.get('prompt'))
if inp and not '/moderations' in path: if inp:
is_safe = await moderation.is_safe(inp) if len(inp) > 2 and not inp.isnumeric():
policy_violation = await moderation.is_policy_violated(inp)
if not is_safe: if policy_violation:
error = await errors.error(400, 'The request contains content which violates this model\'s policies.', 'We currently don\'t support any NSFW models.') error = await errors.error(400, f'The request contains content which violates this model\'s policies for "{policy_violation}".', 'We currently don\'t support any NSFW models.')
return error return error
role_cost_multiplier = credits_config['bonuses'].get(user['role'], 1) role_cost_multiplier = credits_config['bonuses'].get(user['role'], 1)

2
rewards/__main__.py Normal file
View file

@ -0,0 +1,2 @@
import main
main.launch()

View file

@ -7,9 +7,9 @@ async def update_credits(pymongo_client, settings=None):
users = await get_all_users(pymongo_client) users = await get_all_users(pymongo_client)
if not settings: if not settings:
users.update_many({}, {"$inc": {"credits": 2500}}) users.update_many({}, {'$inc': {'credits': 2500}})
else: else:
for key, value in settings.items(): for key, value in settings.items():
users.update_many( users.update_many(
{'level': key}, {"$inc": {"credits": int(value)}}) {'level': key}, {'$inc': {'credits': int(value)}})

53
rewards/main.py Normal file
View file

@ -0,0 +1,53 @@
import os
import time
import aiohttp
import pymongo
import asyncio
import autocredits
from settings import roles
from dotenv import load_dotenv
load_dotenv()
async def main():
mongo = pymongo.MongoClient(os.getenv('MONGO_URI'))
await update_roles(mongo)
await autocredits.update_credits(mongo, roles)
async def update_roles(mongo):
async with aiohttp.ClientSession() as session:
try:
async with session.get('http://0.0.0.0:3224/get_roles') as response:
discord_users = await response.json()
except aiohttp.ClientError as e:
print(f'Error: {e}')
return
level_role_names = [f'lvl{lvl}' for lvl in range(10, 110, 10)]
users = await autocredits.get_all_users(mongo)
for user in users.find():
discord = str(user['auth']['discord'])
for user_id, role_names in discord_users.items():
if user_id == discord:
for role in level_role_names:
if role in role_names:
users.update_one(
{'auth.discord': int(discord)},
{'$set': {'level': role}}
)
print(f'Updated {discord} to {role}')
return users
def launch():
asyncio.run(main())
with open('rewards/last_update.txt', 'w') as f:
f.write(str(time.time()))
if __name__ == '__main__':
launch()

View file

@ -1,48 +0,0 @@
import asyncio
from settings import roles
import autocredits
import aiohttp
from dotenv import load_dotenv
import os
import pymongo
load_dotenv()
CONNECTION_STRING = os.getenv("MONGO_URI")
async def main():
pymongo_client = pymongo.MongoClient(CONNECTION_STRING)
await update_roles(pymongo_client)
await autocredits.update_credits(pymongo_client, roles)
async def update_roles(pymongo_client):
async with aiohttp.ClientSession() as session:
try:
async with session.get('http://0.0.0.0:3224/get_roles') as response:
data = await response.json()
except aiohttp.ClientError as e:
print(f"Error: {e}")
return
lvlroles = [f"lvl{lvl}" for lvl in range(10, 110, 10)]
discord_users = data
users = await autocredits.get_all_users(pymongo_client)
for user in users.find():
discord = str(user['auth']['discord'])
for id_, roles in discord_users.items():
if id_ == discord:
for role in lvlroles:
if role in roles:
users.update_one({'auth.discord': int(discord)}, {
'$set': {'level': role}})
print(f"Updated {discord} to {role}")
return users
if __name__ == "__main__":
asyncio.run(main())

View file

@ -23,7 +23,7 @@ MODEL = 'gpt-3.5-turbo'
MESSAGES = [ MESSAGES = [
{ {
'role': 'user', 'role': 'user',
'content': 'fuck you', 'content': '1+1=',
} }
] ]