mirror of
https://github.com/NovaOSS/nova-api.git
synced 2024-11-25 20:43:56 +01:00
Alotta changes.
Setup managers and modified other files, cleaning up codebase. Created user.py class for future type usage.
This commit is contained in:
parent
bb1e9de563
commit
f896b18968
4
.vscode/settings.json
vendored
4
.vscode/settings.json
vendored
|
@ -14,5 +14,9 @@
|
||||||
"hide-files.files": [
|
"hide-files.files": [
|
||||||
"tests/__pycache__",
|
"tests/__pycache__",
|
||||||
"api/__pycache__"
|
"api/__pycache__"
|
||||||
|
],
|
||||||
|
"python.analysis.extraPaths": [
|
||||||
|
".",
|
||||||
|
"./api/db"
|
||||||
]
|
]
|
||||||
}
|
}
|
12
api/core.py
12
api/core.py
|
@ -4,7 +4,7 @@ import os
|
||||||
import json
|
import json
|
||||||
import fastapi
|
import fastapi
|
||||||
|
|
||||||
from db import users
|
from users import UserManager
|
||||||
|
|
||||||
from dhooks import Webhook, Embed
|
from dhooks import Webhook, Embed
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
@ -29,10 +29,12 @@ async def check_core_auth(request):
|
||||||
async def get_users(discord_id: int, incoming_request: fastapi.Request):
|
async def get_users(discord_id: int, incoming_request: fastapi.Request):
|
||||||
auth = await check_core_auth(incoming_request)
|
auth = await check_core_auth(incoming_request)
|
||||||
if auth:
|
if auth:
|
||||||
return auth_error
|
return auth
|
||||||
|
|
||||||
# Get user by discord ID
|
# Get user by discord ID
|
||||||
if not await users.by_discord_id(discord_id):
|
manager = UserManager()
|
||||||
|
user = await manager.user_by_discord_id(discord_id)
|
||||||
|
if not user:
|
||||||
return fastapi.Response(status_code=404, content='User not found.')
|
return fastapi.Response(status_code=404, content='User not found.')
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
@ -64,7 +66,9 @@ async def create_user(incoming_request: fastapi.Request):
|
||||||
except (json.decoder.JSONDecodeError, AttributeError):
|
except (json.decoder.JSONDecodeError, AttributeError):
|
||||||
return fastapi.Response(status_code=400, content='Invalid or no payload received.')
|
return fastapi.Response(status_code=400, content='Invalid or no payload received.')
|
||||||
|
|
||||||
user = await users.create(discord_id)
|
# Create the user
|
||||||
|
manager = UserManager()
|
||||||
|
user = await manager.create(discord_id)
|
||||||
await new_user_webhook(user)
|
await new_user_webhook(user)
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
|
@ -8,16 +8,9 @@ from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
## MONGODB Setup
|
|
||||||
|
|
||||||
conn = AsyncIOMotorClient(os.getenv('MONGO_URI'))
|
|
||||||
|
|
||||||
async def _get_collection(collection_name: str):
|
|
||||||
return conn['nova-core'][collection_name]
|
|
||||||
|
|
||||||
## Statistics
|
## Statistics
|
||||||
|
|
||||||
class Stats:
|
class StatsManager:
|
||||||
"""
|
"""
|
||||||
### The manager for all statistics tracking
|
### The manager for all statistics tracking
|
||||||
Stats tracked:
|
Stats tracked:
|
||||||
|
@ -28,37 +21,44 @@ class Stats:
|
||||||
- Models
|
- Models
|
||||||
- URL Paths
|
- URL Paths
|
||||||
"""
|
"""
|
||||||
async def add_date():
|
|
||||||
|
def __init__(self):
|
||||||
|
self.conn = AsyncIOMotorClient(os.getenv('MONGO_URI'))
|
||||||
|
|
||||||
|
async def _get_collection(self, collection_name: str):
|
||||||
|
return self.conn['nova-core'][collection_name]
|
||||||
|
|
||||||
|
async def add_date(self):
|
||||||
date = datetime.datetime.now(pytz.timezone('GMT')).strftime('%Y.%m.%d')
|
date = datetime.datetime.now(pytz.timezone('GMT')).strftime('%Y.%m.%d')
|
||||||
year, month, day = date.split('.')
|
year, month, day = date.split('.')
|
||||||
|
|
||||||
db = await _get_collection('stats')
|
db = await self._get_collection('stats')
|
||||||
await db.update_one({}, {'$inc': {f'dates.{year}.{month}.{day}': 1}}, upsert=True)
|
await db.update_one({}, {'$inc': {f'dates.{year}.{month}.{day}': 1}}, upsert=True)
|
||||||
|
|
||||||
async def add_ip_address(ip_address: str):
|
async def add_ip_address(self, ip_address: str):
|
||||||
ip_address = ip_address.replace('.', '_')
|
ip_address = ip_address.replace('.', '_')
|
||||||
db = await _get_collection('stats')
|
db = await self._get_collection('stats')
|
||||||
await db.update_one({}, {'$inc': {f'ips.{ip_address}': 1}}, upsert=True)
|
await db.update_one({}, {'$inc': {f'ips.{ip_address}': 1}}, upsert=True)
|
||||||
|
|
||||||
async def add_target(url: str):
|
async def add_target(self, url: str):
|
||||||
db = await _get_collection('stats')
|
db = await self._get_collection('stats')
|
||||||
await db.update_one({}, {'$inc': {f'targets.{url}': 1}}, upsert=True)
|
await db.update_one({}, {'$inc': {f'targets.{url}': 1}}, upsert=True)
|
||||||
|
|
||||||
async def add_tokens(tokens: int, model: str):
|
async def add_tokens(self, tokens: int, model: str):
|
||||||
db = await _get_collection('stats')
|
db = await self._get_collection('stats')
|
||||||
await db.update_one({}, {'$inc': {f'tokens.{model}': tokens}}, upsert=True)
|
await db.update_one({}, {'$inc': {f'tokens.{model}': tokens}}, upsert=True)
|
||||||
|
|
||||||
async def add_model(model: str):
|
async def add_model(self, model: str):
|
||||||
db = await _get_collection('stats')
|
db = await self._get_collection('stats')
|
||||||
await db.update_one({}, {'$inc': {f'models.{model}': 1}}, upsert=True)
|
await db.update_one({}, {'$inc': {f'models.{model}': 1}}, upsert=True)
|
||||||
|
|
||||||
async def add_path(path: str):
|
async def add_path(self, path: str):
|
||||||
path = path.replace('/', '_')
|
path = path.replace('/', '_')
|
||||||
db = await _get_collection('stats')
|
db = await self._get_collection('stats')
|
||||||
await db.update_one({}, {'$inc': {f'paths.{path}': 1}}, upsert=True)
|
await db.update_one({}, {'$inc': {f'paths.{path}': 1}}, upsert=True)
|
||||||
|
|
||||||
async def get_value(obj_filter):
|
async def get_value(self, obj_filter):
|
||||||
db = await _get_collection('stats')
|
db = await self._get_collection('stats')
|
||||||
return await db.find_one({obj_filter})
|
return await db.find_one({obj_filter})
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -14,21 +14,31 @@ with open('config/credits.yml', encoding='utf8') as f:
|
||||||
|
|
||||||
## MONGODB Setup
|
## MONGODB Setup
|
||||||
|
|
||||||
conn = AsyncIOMotorClient(os.getenv('MONGO_URI'))
|
class UserManager:
|
||||||
|
"""
|
||||||
|
### Manager of all users in the database.
|
||||||
|
Following methods are available:
|
||||||
|
|
||||||
async def _get_collection(collection_name: str):
|
- `_get_collection(collection_name)`
|
||||||
return conn['nova-core'][collection_name]
|
- `create(discord_id)`
|
||||||
|
- `user_by_id(user_id)`
|
||||||
async def create(discord_id: str='') -> dict:
|
- `user_by_discord_id(discord_id)`
|
||||||
"""Add a user to the mongodb
|
- `user_by_api_key(api_key)`
|
||||||
|
- `update_by_id(user_id, new_obj)`
|
||||||
Args:
|
- `update_by_filter(filter_object, new_obj)`
|
||||||
discord_id (str): Defaults to ''.
|
- `delete(user_id)`
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: The user object
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.conn = AsyncIOMotorClient(os.getenv('MONGO_URI'))
|
||||||
|
|
||||||
|
async def _get_collection(self, collection_name: str):
|
||||||
|
return self.conn['nova-core'][collection_name]
|
||||||
|
|
||||||
|
async def get_all_users(self):
|
||||||
|
return self.conn['nova-core']['users']
|
||||||
|
|
||||||
|
async def create(self, discord_id: str = '') -> dict:
|
||||||
chars = string.ascii_letters + string.digits
|
chars = string.ascii_letters + string.digits
|
||||||
|
|
||||||
infix = os.getenv('KEYGEN_INFIX')
|
infix = os.getenv('KEYGEN_INFIX')
|
||||||
|
@ -52,37 +62,37 @@ async def create(discord_id: str='') -> dict:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
db = await _get_collection('users')
|
db = await self._get_collection('users')
|
||||||
await db.insert_one(new_user)
|
await db.insert_one(new_user)
|
||||||
user = await db.find_one({'api_key': new_api_key})
|
user = await db.find_one({'api_key': new_api_key})
|
||||||
return user
|
return user
|
||||||
|
|
||||||
async def by_id(user_id: str):
|
async def user_by_id(self, user_id: str):
|
||||||
db = await _get_collection('users')
|
db = await self._get_collection('users')
|
||||||
return await db.find_one({'_id': user_id})
|
return await db.find_one({'_id': user_id})
|
||||||
|
|
||||||
async def by_discord_id(discord_id: str):
|
async def user_by_discord_id(self, discord_id: str):
|
||||||
db = await _get_collection('users')
|
db = await self._get_collection('users')
|
||||||
return await db.find_one({'auth.discord': str(int(discord_id))})
|
return await db.find_one({'auth.discord': str(int(discord_id))})
|
||||||
|
|
||||||
async def by_api_key(key: str):
|
async def user_by_api_key(self, key: str):
|
||||||
db = await _get_collection('users')
|
db = await self._get_collection('users')
|
||||||
return await db.find_one({'api_key': key})
|
return await db.find_one({'api_key': key})
|
||||||
|
|
||||||
async def update_by_id(user_id: str, update):
|
async def update_by_id(self, user_id: str, update):
|
||||||
db = await _get_collection('users')
|
db = await self._get_collection('users')
|
||||||
return await db.update_one({'_id': user_id}, update)
|
return await db.update_one({'_id': user_id}, update)
|
||||||
|
|
||||||
async def update_by_filter(obj_filter, update):
|
async def update_by_filter(self, obj_filter, update):
|
||||||
db = await _get_collection('users')
|
db = await self._get_collection('users')
|
||||||
return await db.update_one(obj_filter, update)
|
return await db.update_one(obj_filter, update)
|
||||||
|
|
||||||
async def delete(user_id: str):
|
async def delete(self, user_id: str):
|
||||||
db = await _get_collection('users')
|
db = await self._get_collection('users')
|
||||||
await db.delete_one({'_id': user_id})
|
await db.delete_one({'_id': user_id})
|
||||||
|
|
||||||
async def demo():
|
async def demo():
|
||||||
user = await create(69420)
|
user = await UserManager().create(69420)
|
||||||
print(user)
|
print(user)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -15,8 +15,9 @@ import proxies
|
||||||
import provider_auth
|
import provider_auth
|
||||||
import load_balancing
|
import load_balancing
|
||||||
|
|
||||||
from db import logs, users
|
from db import logs
|
||||||
from db.stats import Stats
|
from db.users import UserManager
|
||||||
|
from db.stats import StatsManager
|
||||||
from helpers import network, chat, errors
|
from helpers import network, chat, errors
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
@ -76,7 +77,8 @@ async def stream(
|
||||||
input_tokens (int, optional): Total tokens calculated with tokenizer. Defaults to 0.
|
input_tokens (int, optional): Total tokens calculated with tokenizer. Defaults to 0.
|
||||||
incoming_request (starlette.requests.Request, optional): Incoming request. Defaults to None.
|
incoming_request (starlette.requests.Request, optional): Incoming request. Defaults to None.
|
||||||
"""
|
"""
|
||||||
|
db = UserManager()
|
||||||
|
stats = StatsManager()
|
||||||
is_chat = False
|
is_chat = False
|
||||||
is_stream = payload.get('stream', False)
|
is_stream = payload.get('stream', False)
|
||||||
|
|
||||||
|
@ -175,16 +177,16 @@ async def stream(
|
||||||
await logs.log_api_request(user=user, incoming_request=incoming_request, target_url=target_request['url'])
|
await logs.log_api_request(user=user, incoming_request=incoming_request, target_url=target_request['url'])
|
||||||
|
|
||||||
if credits_cost and user:
|
if credits_cost and user:
|
||||||
await users.update_by_id(user['_id'], {'$inc': {'credits': -credits_cost}})
|
await db.update_by_id(user['_id'], {'$inc': {'credits': -credits_cost}})
|
||||||
|
|
||||||
ip_address = await network.get_ip(incoming_request)
|
ip_address = await network.get_ip(incoming_request)
|
||||||
await Stats.add_date()
|
await stats.add_date()
|
||||||
await Stats.add_ip_address(ip_address)
|
await stats.add_ip_address(ip_address)
|
||||||
await Stats.add_path(path)
|
await stats.add_path(path)
|
||||||
await Stats.add_target(target_request['url'])
|
await stats.add_target(target_request['url'])
|
||||||
if is_chat:
|
if is_chat:
|
||||||
await Stats.add_model(model)
|
await stats.add_model(model)
|
||||||
await Stats.add_tokens(input_tokens, model)
|
await stats.add_tokens(input_tokens, model)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
asyncio.run(stream())
|
asyncio.run(stream())
|
||||||
|
|
|
@ -9,7 +9,7 @@ from dotenv import load_dotenv
|
||||||
import streaming
|
import streaming
|
||||||
import moderation
|
import moderation
|
||||||
|
|
||||||
from db import users
|
from users import UserManager
|
||||||
from helpers import tokens, errors
|
from helpers import tokens, errors
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
@ -25,7 +25,7 @@ async def handle(incoming_request):
|
||||||
Takes the request from the incoming request to the target endpoint.
|
Takes the request from the incoming request to the target endpoint.
|
||||||
Checks method, token amount, auth and cost along with if request is NSFW.
|
Checks method, token amount, auth and cost along with if request is NSFW.
|
||||||
"""
|
"""
|
||||||
|
users = UserManager()
|
||||||
path = incoming_request.url.path.replace('v1/v1/', 'v1/')
|
path = incoming_request.url.path.replace('v1/v1/', 'v1/')
|
||||||
|
|
||||||
allowed_methods = {'GET', 'POST', 'PUT', 'DELETE', 'PATCH'}
|
allowed_methods = {'GET', 'POST', 'PUT', 'DELETE', 'PATCH'}
|
||||||
|
@ -46,7 +46,7 @@ async def handle(incoming_request):
|
||||||
if not received_key or not received_key.startswith('Bearer '):
|
if not received_key or not received_key.startswith('Bearer '):
|
||||||
return await errors.error(401, 'No NovaAI API key given!', 'Add "Authorization: Bearer nv-..." to your request headers.')
|
return await errors.error(401, 'No NovaAI API key given!', 'Add "Authorization: Bearer nv-..." to your request headers.')
|
||||||
|
|
||||||
user = await users.by_api_key(received_key.split('Bearer ')[1].strip())
|
user = await users.user_by_api_key(received_key.split('Bearer ')[1].strip())
|
||||||
|
|
||||||
if not user or not user['status']['active']:
|
if not user or not user['status']['active']:
|
||||||
return await errors.error(401, 'Invalid or inactive NovaAI API key!', 'Create a new NovaOSS API key or reactivate your account.')
|
return await errors.error(401, 'Invalid or inactive NovaAI API key!', 'Create a new NovaOSS API key or reactivate your account.')
|
||||||
|
|
1
api/types/user.py
Normal file
1
api/types/user.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
## TODO: Create user type out of JSON object.
|
|
@ -1,10 +1,9 @@
|
||||||
async def get_all_users(client):
|
from users import UserManager
|
||||||
users = client['nova-core']['users']
|
|
||||||
return users
|
|
||||||
|
|
||||||
|
|
||||||
async def update_credits(pymongo_client, settings=None):
|
async def update_credits(pymongo_client, settings=None):
|
||||||
users = await get_all_users(pymongo_client)
|
manager = UserManager()
|
||||||
|
users = await manager.get_all_users(pymongo_client)
|
||||||
|
|
||||||
if not settings:
|
if not settings:
|
||||||
users.update_many({}, {'$inc': {'credits': 2500}})
|
users.update_many({}, {'$inc': {'credits': 2500}})
|
||||||
|
|
Loading…
Reference in a new issue