Compare commits

..

3 commits

Author SHA1 Message Date
nsde 008bf56fdf Fixed spelling 2023-08-30 22:19:27 +02:00
nsde 15f816fd1d Uses own AI for moderation 2023-08-30 22:13:23 +02:00
nsde d4237dd65e Small improvements 2023-08-30 20:55:31 +02:00
7 changed files with 83 additions and 37 deletions

View file

@ -41,6 +41,8 @@ async def handle(incoming_request: fastapi.Request):
payload = await incoming_request.json() payload = await incoming_request.json()
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
payload = {} payload = {}
except UnicodeDecodeError:
payload = {}
received_key = incoming_request.headers.get('Authorization') received_key = incoming_request.headers.get('Authorization')
@ -83,29 +85,31 @@ async def handle(incoming_request: fastapi.Request):
if user['credits'] < cost: if user['credits'] < cost:
return await errors.error(429, 'Not enough credits.', 'Wait or earn more credits. Learn more on our website or Discord server.') return await errors.error(429, 'Not enough credits.', 'Wait or earn more credits. Learn more on our website or Discord server.')
payload_with_vars = json.dumps(payload)
replace_dict = { if not 'DISABLE_VARS' in key_tags:
'timestamp': str(int(time.time())), payload_with_vars = json.dumps(payload)
'date': time.strftime('%Y-%m-%d'),
'time': time.strftime('%H:%M:%S'),
'datetime': time.strftime('%Y-%m-%d %H:%M:%S'),
'model': payload.get('model', 'unknown'),
}
if 'ALLOW_INSECURE_VARS' in key_tags: replace_dict = {
replace_dict.update({ 'timestamp': str(int(time.time())),
'my.ip': ip_address, 'date': time.strftime('%Y-%m-%d'),
'my.id': str(user['_id']), 'time': time.strftime('%H:%M:%S'),
'my.role': user.get('role', 'default'), 'datetime': time.strftime('%Y-%m-%d %H:%M:%S'),
'my.credits': str(user['credits']), 'model': payload.get('model', 'unknown'),
'my.discord': user.get('auth', {}).get('discord', ''), }
})
for key, value in replace_dict.items(): if 'ALLOW_INSECURE_VARS' in key_tags:
payload_with_vars = payload_with_vars.replace(f'[[{key}]]', value) replace_dict.update({
'my.ip': ip_address,
'my.id': str(user['_id']),
'my.role': user.get('role', 'default'),
'my.credits': str(user['credits']),
'my.discord': user.get('auth', {}).get('discord', ''),
})
payload = json.loads(payload_with_vars) for key, value in replace_dict.items():
payload_with_vars = payload_with_vars.replace(f'[[{key}]]', value)
payload = json.loads(payload_with_vars)
policy_violation = False policy_violation = False
if '/moderations' not in path: if '/moderations' not in path:

View file

@ -1,4 +1,9 @@
import asyncio import os
import time
from dotenv import load_dotenv
load_dotenv()
async def get_ip(request) -> str: async def get_ip(request) -> str:
"""Get the IP address of the incoming request.""" """Get the IP address of the incoming request."""
@ -17,7 +22,7 @@ async def get_ip(request) -> str:
return detected_ip return detected_ip
def get_ip_sync(request) -> str: def get_ratelimit_key(request) -> str:
"""Get the IP address of the incoming request.""" """Get the IP address of the incoming request."""
xff = None xff = None
@ -32,4 +37,9 @@ def get_ip_sync(request) -> str:
detected_ip = next((i for i in possible_ips if i), None) detected_ip = next((i for i in possible_ips if i), None)
for whitelisted_ip in os.getenv('NO_RATELIMIT_IPS', '').split():
if whitelisted_ip in detected_ip:
custom_key = f'whitelisted-{time.time()}'
return custom_key
return detected_ip return detected_ip

View file

@ -32,7 +32,7 @@ app.include_router(core.router)
limiter = Limiter( limiter = Limiter(
swallow_errors=True, swallow_errors=True,
key_func=network.get_ip_sync, default_limits=[ key_func=network.get_ratelimit_key, default_limits=[
'2/second', '2/second',
'20/minute', '20/minute',
'300/hour' '300/hour'

View file

@ -3,6 +3,7 @@
import time import time
import asyncio import asyncio
import aiohttp import aiohttp
import profanity_check
import proxies import proxies
import provider_auth import provider_auth
@ -10,14 +11,8 @@ import load_balancing
from typing import Union from typing import Union
async def is_policy_violated(inp: Union[str, list]) -> bool: def input_to_text(inp: Union[str, list]) -> str:
""" """Converts the input to a string."""
### Check if a message violates the moderation policy.
You can either pass a list of messages consisting of dicts with "role" and "content", as used in the API parameter,
or just a simple string.
Returns True if the message violates the policy, False otherwise.
"""
text = inp text = inp
@ -30,7 +25,21 @@ async def is_policy_violated(inp: Union[str, list]) -> bool:
else: else:
text = '\n'.join(inp) text = '\n'.join(inp)
for _ in range(5): return text
async def is_policy_violated(inp: Union[str, list]) -> bool:
"""
### Check if a message violates the moderation policy.
You can either pass a list of messages consisting of dicts with "role" and "content", as used in the API parameter,
or just a simple string.
Returns True if the message violates the policy, False otherwise.
"""
text = input_to_text(inp)
return await is_policy_violated__own_model(text)
for _ in range(1):
req = await load_balancing.balance_organic_request( req = await load_balancing.balance_organic_request(
{ {
'path': '/v1/moderations', 'path': '/v1/moderations',
@ -61,11 +70,18 @@ async def is_policy_violated(inp: Union[str, list]) -> bool:
return False return False
except Exception as exc: except Exception as exc:
if '401' in str(exc): if '401' in str(exc):
await provider_auth.invalidate_key(req.get('provider_auth')) await provider_auth.invalidate_key(req.get('provider_auth'))
print('[!] moderation error:', type(exc), exc) print('[!] moderation error:', type(exc), exc)
continue continue
async def is_policy_violated__own_model(inp: Union[str, list]) -> bool:
inp = input_to_text(inp)
if profanity_check.predict([inp])[0]:
return 'own model detected'
return False
if __name__ == '__main__': if __name__ == '__main__':
print(asyncio.run(is_policy_violated('I wanna kill myself'))) print(asyncio.run(is_policy_violated('kill ms')))

View file

@ -177,5 +177,7 @@ async def stream(
model=model, model=model,
) )
print(f'[+] {path} -> {model or ""}')
if __name__ == '__main__': if __name__ == '__main__':
asyncio.run(stream()) asyncio.run(stream())

View file

@ -3,17 +3,20 @@
# Commit to the production branch # Commit to the production branch
# git commit -am "Auto-trigger - Production server started" && git push origin Production # git commit -am "Auto-trigger - Production server started" && git push origin Production
# Kill production server
fuser -k 2333/tcp
# Clear production directory
rm -rf /home/nova-prod/*
# Copy files to production # Copy files to production
cp -r * /home/nova-prod cp -r * /home/nova-prod
# Copy env file to production # Copy .prod.env file to production
cp env/.prod.env /home/nova-prod/.env cp env/.prod.env /home/nova-prod/.env
# Change directory # Change directory
cd /home/nova-prod cd /home/nova-prod
# Kill the production server
fuser -k 2333/tcp
# Start screen # Start screen
screen -S nova-api python run prod && sleep 5 screen -S nova-api python run prod && sleep 5

View file

@ -23,6 +23,17 @@ or
pip install . pip install .
``` ```
***
Profanity checking requires:
```
pip install alt-profanity-check
# doesn't work? try
pip install git+https://github.com/dimitrismistriotis/alt-profanity-check.git
```
## `.env` configuration ## `.env` configuration
Create a `.env` file, make sure not to reveal any of its contents to anyone, and fill in the required values in the format `KEY=VALUE`. Otherwise, the code won't run. Create a `.env` file, make sure not to reveal any of its contents to anyone, and fill in the required values in the format `KEY=VALUE`. Otherwise, the code won't run.