Compare commits

...

2 commits

Author SHA1 Message Date
nsde ce24c3a5a2 Fixed some several issues with moderation, models etc. 2023-08-08 01:04:35 +02:00
nsde 1a3e275a1c Added auto-rewards and invalid key system 2023-08-07 23:28:24 +02:00
10 changed files with 108 additions and 88 deletions

View file

@ -3,14 +3,6 @@ import asyncio
import providers
provider_modules = [
# providers.twa,
# providers.quantum,
providers.churchless,
providers.closed,
providers.closed4
]
async def _get_module_name(module) -> str:
name = module.__name__
if '.' in name:
@ -22,7 +14,7 @@ async def balance_chat_request(payload: dict) -> dict:
providers_available = []
for provider_module in provider_modules:
for provider_module in providers.MODULES:
if payload['stream'] and not provider_module.STREAMING:
continue
@ -32,7 +24,7 @@ async def balance_chat_request(payload: dict) -> dict:
providers_available.append(provider_module)
if not providers_available:
raise NotImplementedError('This model does not exist.')
raise NotImplementedError(f'The model "{payload["model"]}" is not available. MODEl_UNAVAILABLE')
provider = random.choice(providers_available)
target = provider.chat_completion(**payload)
@ -52,7 +44,7 @@ async def balance_organic_request(request: dict) -> dict:
'Content-Type': 'application/json'
}
for provider_module in provider_modules:
for provider_module in providers.MODULES:
if not provider_module.ORGANIC:
continue

View file

@ -1,10 +1,11 @@
import asyncio
import aiohttp
import proxies
import provider_auth
import load_balancing
async def is_safe(inp) -> bool:
async def is_policy_violated(inp) -> bool:
text = inp
if isinstance(inp, list):
@ -34,16 +35,21 @@ async def is_safe(inp) -> bool:
headers=req.get('headers'),
cookies=req.get('cookies'),
ssl=False,
timeout=aiohttp.ClientTimeout(total=5),
timeout=aiohttp.ClientTimeout(total=2),
) as res:
res.raise_for_status()
json_response = await res.json()
categories = json_response['results'][0]['category_scores']
if json_response['results'][0]['flagged']:
return max(categories, key=categories.get)
return False
return not json_response['results'][0]['flagged']
except Exception as exc:
# await provider_auth.invalidate_key(req.get('provider_auth'))
print('[!] moderation error:', type(exc), exc)
continue
if __name__ == '__main__':
print(asyncio.run(is_safe('I wanna kill myself')))
print(asyncio.run(is_policy_violated('I wanna kill myself')))

21
api/provider_auth.py Normal file
View file

@ -0,0 +1,21 @@
import asyncio
async def invalidate_key(provider_and_key):
if not provider_and_key:
return
provider = provider_and_key.split('>')[0]
provider_file = f'secret/{provider}.txt'
key = provider_and_key.split('>')[1]
with open(provider_file, encoding='utf8') as f_in:
text = f_in.read()
with open(provider_file, 'w', encoding='utf8') as f_out:
f_out.write(text.replace(key, ''))
with open(f'secret/{provider}.invalid.txt', 'a', encoding='utf8') as f:
f.write(key + '\n')
if __name__ == '__main__':
asyncio.run(invalidate_key('closed>sk-...'))

View file

@ -1,36 +1,26 @@
"""Module for transferring requests to ClosedAI API"""
import os
import json
import yaml
import logging
import fastapi
import starlette
from dotenv import load_dotenv
import streaming
import moderation
from db import logs, users
from helpers import tokens, errors, exceptions
from db import users
from helpers import tokens, errors
load_dotenv()
# log to "api.log" file
logging.basicConfig(
filename='api.log',
level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(name)s %(message)s'
)
with open('config/credits.yml', encoding='utf8') as f:
credits_config = yaml.safe_load(f)
async def handle(incoming_request):
"""Transfer a streaming response from the incoming request to the target endpoint"""
path = incoming_request.url.path
path = incoming_request.url.path.replace('v1/v1/', 'v1/')
# METHOD
if incoming_request.method not in ['GET', 'POST', 'PUT', 'DELETE', 'PATCH']:
@ -79,23 +69,27 @@ async def handle(incoming_request):
costs = credits_config['costs']
cost = costs['other']
is_safe = True
policy_violation = False
if 'chat/completions' in path:
for model_name, model_cost in costs['chat-models'].items():
if model_name in payload['model']:
cost = model_cost
is_safe = await moderation.is_safe(payload['messages'])
policy_violation = await moderation.is_policy_violated(payload['messages'])
elif '/moderations' in path:
pass
else:
inp = payload.get('input', payload.get('prompt'))
if inp and not '/moderations' in path:
is_safe = await moderation.is_safe(inp)
if inp:
if len(inp) > 2 and not inp.isnumeric():
policy_violation = await moderation.is_policy_violated(inp)
if not is_safe:
error = await errors.error(400, 'The request contains content which violates this model\'s policies.', 'We currently don\'t support any NSFW models.')
if policy_violation:
error = await errors.error(400, f'The request contains content which violates this model\'s policies for "{policy_violation}".', 'We currently don\'t support any NSFW models.')
return error
role_cost_multiplier = credits_config['bonuses'].get(user['role'], 1)

2
rewards/__main__.py Normal file
View file

@ -0,0 +1,2 @@
import main
main.launch()

View file

@ -7,9 +7,9 @@ async def update_credits(pymongo_client, settings=None):
users = await get_all_users(pymongo_client)
if not settings:
users.update_many({}, {"$inc": {"credits": 2500}})
users.update_many({}, {'$inc': {'credits': 2500}})
else:
for key, value in settings.items():
users.update_many(
{'level': key}, {"$inc": {"credits": int(value)}})
{'level': key}, {'$inc': {'credits': int(value)}})

53
rewards/main.py Normal file
View file

@ -0,0 +1,53 @@
import os
import time
import aiohttp
import pymongo
import asyncio
import autocredits
from settings import roles
from dotenv import load_dotenv
load_dotenv()
async def main():
mongo = pymongo.MongoClient(os.getenv('MONGO_URI'))
await update_roles(mongo)
await autocredits.update_credits(mongo, roles)
async def update_roles(mongo):
async with aiohttp.ClientSession() as session:
try:
async with session.get('http://0.0.0.0:3224/get_roles') as response:
discord_users = await response.json()
except aiohttp.ClientError as e:
print(f'Error: {e}')
return
level_role_names = [f'lvl{lvl}' for lvl in range(10, 110, 10)]
users = await autocredits.get_all_users(mongo)
for user in users.find():
discord = str(user['auth']['discord'])
for user_id, role_names in discord_users.items():
if user_id == discord:
for role in level_role_names:
if role in role_names:
users.update_one(
{'auth.discord': int(discord)},
{'$set': {'level': role}}
)
print(f'Updated {discord} to {role}')
return users
def launch():
asyncio.run(main())
with open('rewards/last_update.txt', 'w') as f:
f.write(str(time.time()))
if __name__ == '__main__':
launch()

View file

@ -1,48 +0,0 @@
import asyncio
from settings import roles
import autocredits
import aiohttp
from dotenv import load_dotenv
import os
import pymongo
load_dotenv()
CONNECTION_STRING = os.getenv("MONGO_URI")
async def main():
pymongo_client = pymongo.MongoClient(CONNECTION_STRING)
await update_roles(pymongo_client)
await autocredits.update_credits(pymongo_client, roles)
async def update_roles(pymongo_client):
async with aiohttp.ClientSession() as session:
try:
async with session.get('http://0.0.0.0:3224/get_roles') as response:
data = await response.json()
except aiohttp.ClientError as e:
print(f"Error: {e}")
return
lvlroles = [f"lvl{lvl}" for lvl in range(10, 110, 10)]
discord_users = data
users = await autocredits.get_all_users(pymongo_client)
for user in users.find():
discord = str(user['auth']['discord'])
for id_, roles in discord_users.items():
if id_ == discord:
for role in lvlroles:
if role in roles:
users.update_one({'auth.discord': int(discord)}, {
'$set': {'level': role}})
print(f"Updated {discord} to {role}")
return users
if __name__ == "__main__":
asyncio.run(main())

View file

@ -23,7 +23,7 @@ MODEL = 'gpt-3.5-turbo'
MESSAGES = [
{
'role': 'user',
'content': 'fuck you',
'content': '1+1=',
}
]