diff --git a/api/core.py b/api/core.py index b763ed3..2ed74e3 100644 --- a/api/core.py +++ b/api/core.py @@ -32,7 +32,7 @@ async def get_users(discord_id: int, incoming_request: fastapi.Request): return user -def new_user_webhook(user: dict) -> None: +async def new_user_webhook(user: dict) -> None: dhook = Webhook(os.getenv('DISCORD_WEBHOOK__USER_CREATED')) embed = Embed( @@ -40,7 +40,7 @@ def new_user_webhook(user: dict) -> None: color=0x90ee90, ) - embed.add_field(name='ID', value=user['_id'], inline=False) + embed.add_field(name='ID', value=str(user['_id']), inline=False) embed.add_field(name='Discord', value=user['auth']['discord']) embed.add_field(name='Github', value=user['auth']['github']) @@ -60,15 +60,17 @@ async def create_user(incoming_request: fastapi.Request): return fastapi.Response(status_code=400, content='Invalid or no payload received.') user = await users.create(discord_id) - new_user_webhook(user) + await new_user_webhook(user) return user if __name__ == '__main__': - new_user_webhook({ - '_id': 'JUST_A_TEST_IGNORE_ME', - 'auth': { - 'discord': 123, - 'github': 'abc' - } - }) + # new_user_webhook({ + # '_id': 'JUST_A_TEST_IGNORE_ME', + # 'auth': { + # 'discord': 123, + # 'github': 'abc' + # } + # }) + + pass diff --git a/api/db/logs.py b/api/db/logs.py index 9a79286..932370b 100644 --- a/api/db/logs.py +++ b/api/db/logs.py @@ -8,10 +8,24 @@ from helpers import network load_dotenv() -def _get_mongo(collection_name: str): +UA_SIMPLIFY = { + 'Windows NT': 'W', + 'Mozilla/5.0': 'M', + 'Win64; x64': '64', + 'Safari/537.36': 'S', + 'AppleWebKit/537.36 (KHTML, like Gecko)': 'K', +} + +async def _get_mongo(collection_name: str): return AsyncIOMotorClient(os.getenv('MONGO_URI'))['nova-core'][collection_name] +async def replacer(text: str, dict_: dict) -> str: + for k, v in dict_.items(): + text = text.replace(k, v) + return text + async def log_api_request(user: dict, incoming_request, target_url: str): + db = await _get_mongo('logs') payload = {} try: @@ -22,19 +36,23 @@ async def log_api_request(user: dict, incoming_request, target_url: str): last_prompt = None if 'messages' in payload: - last_prompt = payload['messages'][-1]['content'] + last_prompt = payload['messages'][-1]['content'][:50] + + if len(last_prompt) == 50: + last_prompt += '...' model = payload.get('model') ip_address = await network.get_ip(incoming_request) + useragent = await replacer(incoming_request.headers.get('User-Agent'), UA_SIMPLIFY) new_log_item = { 'timestamp': time.time(), 'method': incoming_request.method, 'path': incoming_request.url.path, - 'user_id': user['_id'], + 'user_id': str(user['_id']), 'security': { 'ip': ip_address, - 'useragent': incoming_request.headers.get('User-Agent') + 'useragent': useragent, }, 'details': { 'model': model, @@ -43,21 +61,25 @@ async def log_api_request(user: dict, incoming_request, target_url: str): } } - inserted = await _get_mongo('logs').insert_one(new_log_item) - log_item = await _get_mongo('logs').find_one({'_id': inserted.inserted_id}) + inserted = await db.insert_one(new_log_item) + log_item = await db.find_one({'_id': inserted.inserted_id}) return log_item async def by_id(log_id: str): - return await _get_mongo('logs').find_one({'_id': log_id}) + db = await _get_mongo('logs') + return await db.find_one({'_id': log_id}) async def by_user_id(user_id: str): - return await _get_mongo('logs').find({'user_id': user_id}) + db = await _get_mongo('logs') + return await db.find({'user_id': user_id}) async def delete_by_id(log_id: str): - return await _get_mongo('logs').delete_one({'_id': log_id}) + db = await _get_mongo('logs') + return await db.delete_one({'_id': log_id}) async def delete_by_user_id(user_id: str): - return await _get_mongo('logs').delete_many({'user_id': user_id}) + db = await _get_mongo('logs') + return await db.delete_many({'user_id': user_id}) if __name__ == '__main__': pass diff --git a/api/db/stats.py b/api/db/stats.py index d9d5edc..795b09f 100644 --- a/api/db/stats.py +++ b/api/db/stats.py @@ -8,34 +8,41 @@ from motor.motor_asyncio import AsyncIOMotorClient load_dotenv() -def _get_mongo(collection_name: str): +async def _get_mongo(collection_name: str): return AsyncIOMotorClient(os.getenv('MONGO_URI'))['nova-core'][collection_name] async def add_date(): date = datetime.datetime.now(pytz.timezone('GMT')).strftime('%Y.%m.%d') year, month, day = date.split('.') - await _get_mongo('stats').update_one({}, {'$inc': {f'dates.{year}.{month}.{day}': 1}}, upsert=True) + db = await _get_mongo('stats') + await db.update_one({}, {'$inc': {f'dates.{year}.{month}.{day}': 1}}, upsert=True) async def add_ip_address(ip_address: str): ip_address = ip_address.replace('.', '_') - await _get_mongo('stats').update_one({}, {'$inc': {f'ips.{ip_address}': 1}}, upsert=True) + db = await _get_mongo('stats') + await db.update_one({}, {'$inc': {f'ips.{ip_address}': 1}}, upsert=True) async def add_target(url: str): - await _get_mongo('stats').update_one({}, {'$inc': {f'targets.{url}': 1}}, upsert=True) + db = await _get_mongo('stats') + await db.update_one({}, {'$inc': {f'targets.{url}': 1}}, upsert=True) async def add_tokens(tokens: int, model: str): - await _get_mongo('stats').update_one({}, {'$inc': {f'tokens.{model}': tokens}}, upsert=True) + db = await _get_mongo('stats') + await db.update_one({}, {'$inc': {f'tokens.{model}': tokens}}, upsert=True) async def add_model(model: str): - await _get_mongo('stats').update_one({}, {'$inc': {f'models.{model}': 1}}, upsert=True) + db = await _get_mongo('stats') + await db.update_one({}, {'$inc': {f'models.{model}': 1}}, upsert=True) async def add_path(path: str): path = path.replace('/', '_') - await _get_mongo('stats').update_one({}, {'$inc': {f'paths.{path}': 1}}, upsert=True) + db = await _get_mongo('stats') + await db.update_one({}, {'$inc': {f'paths.{path}': 1}}, upsert=True) async def get_value(obj_filter): - return await _get_mongo('stats').find_one({obj_filter}) + db = await _get_mongo('stats') + return await db.find_one({obj_filter}) if __name__ == '__main__': asyncio.run(add_date()) diff --git a/api/db/users.py b/api/db/users.py index b2c3e31..5003250 100644 --- a/api/db/users.py +++ b/api/db/users.py @@ -12,7 +12,7 @@ load_dotenv() with open('config/credits.yml', encoding='utf8') as f: credits_config = yaml.safe_load(f) -def _get_mongo(collection_name: str): +async def _get_mongo(collection_name: str): return AsyncIOMotorClient(os.getenv('MONGO_URI'))['nova-core'][collection_name] async def create(discord_id: int=0) -> dict: @@ -46,23 +46,28 @@ async def create(discord_id: int=0) -> dict: return user async def by_id(user_id: str): - return await _get_mongo('users').find_one({'_id': user_id}) + db = await _get_mongo('users') + return await db.find_one({'_id': user_id}) async def by_discord_id(discord_id: str): - return await _get_mongo('users').find_one({'auth.discord': discord_id}) + db = await _get_mongo('users') + return await db.find_one({'auth.discord': discord_id}) async def by_api_key(key: str): - return await _get_mongo('users').find_one({'api_key': key}) + db = await _get_mongo('users') + return await db.find_one({'api_key': key}) async def update_by_id(user_id: str, update): - return await _get_mongo('users').update_one({'_id': user_id}, update) + db = await _get_mongo('users') + return await db.update_one({'_id': user_id}, update) async def update_by_filter(obj_filter, update): - return await _get_mongo('users').update_one(obj_filter, update) + db = await _get_mongo('users') + return await db.update_one(obj_filter, update) async def delete(user_id: str): - await _get_mongo('users').delete_one({'_id': user_id}) - + db = await _get_mongo('users') + await db.delete_one({'_id': user_id}) async def demo(): user = await create(69420) diff --git a/api/helpers/chat.py b/api/helpers/chat.py index 7dccb24..a1002fc 100644 --- a/api/helpers/chat.py +++ b/api/helpers/chat.py @@ -17,7 +17,7 @@ async def create_chat_id() -> str: return f'chatcmpl-{chat_id}' -def create_chat_chunk(chat_id: str, model: str, content=None) -> dict: +async def create_chat_chunk(chat_id: str, model: str, content=None) -> dict: content = content or {} delta = {} diff --git a/api/helpers/errors.py b/api/helpers/errors.py index d3846db..de4f3e5 100644 --- a/api/helpers/errors.py +++ b/api/helpers/errors.py @@ -1,7 +1,7 @@ import json import starlette -def error(code: int, message: str, tip: str) -> starlette.responses.Response: +async def error(code: int, message: str, tip: str) -> starlette.responses.Response: info = {'error': { 'code': code, 'message': message, @@ -12,7 +12,7 @@ def error(code: int, message: str, tip: str) -> starlette.responses.Response: return starlette.responses.Response(status_code=code, content=json.dumps(info)) -def yield_error(code: int, message: str, tip: str) -> str: +async def yield_error(code: int, message: str, tip: str) -> str: return json.dumps({ 'code': code, 'message': message, diff --git a/api/helpers/tokens.py b/api/helpers/tokens.py index 3791958..e1aede6 100644 --- a/api/helpers/tokens.py +++ b/api/helpers/tokens.py @@ -1,6 +1,6 @@ import tiktoken -def count_for_messages(messages: list, model: str='gpt-3.5-turbo-0613') -> int: +async def count_for_messages(messages: list, model: str='gpt-3.5-turbo-0613') -> int: """Return the number of tokens used by a list of messages.""" try: diff --git a/api/load_balancing.py b/api/load_balancing.py index c6d30bd..0755449 100644 --- a/api/load_balancing.py +++ b/api/load_balancing.py @@ -11,13 +11,15 @@ provider_modules = [ providers.closed4 ] -def _get_module_name(module) -> str: +async def _get_module_name(module) -> str: name = module.__name__ if '.' in name: return name.split('.')[-1] return name async def balance_chat_request(payload: dict) -> dict: + """Load balance the chat completion request between chat providers.""" + providers_available = [] for provider_module in provider_modules: @@ -34,20 +36,37 @@ async def balance_chat_request(payload: dict) -> dict: provider = random.choice(providers_available) target = provider.chat_completion(**payload) - target['module'] = _get_module_name(provider) + + module_name = await _get_module_name(provider) + target['module'] = module_name return target async def balance_organic_request(request: dict) -> dict: + """Load balnace to non-chat completion request between other "organic" providers which respond in the desired format already.""" + providers_available = [] + if not request.get('headers'): + request['headers'] = { + 'Content-Type': 'application/json' + } + for provider_module in provider_modules: - if provider_module.ORGANIC: - providers_available.append(provider_module) + if not provider_module.ORGANIC: + continue + + if '/moderations' in request['path']: + if not provider_module.MODERATIONS: + continue + + providers_available.append(provider_module) provider = random.choice(providers_available) target = provider.organify(request) - target['module'] = _get_module_name(provider) + + module_name = await _get_module_name(provider) + target['module'] = module_name return target diff --git a/api/moderation.py b/api/moderation.py index a885682..be9add7 100644 --- a/api/moderation.py +++ b/api/moderation.py @@ -1,18 +1,49 @@ -import os import asyncio -import openai as closedai -from typing import Union -from dotenv import load_dotenv +import aiohttp +import proxies +import load_balancing -load_dotenv() +async def is_safe(inp) -> bool: + text = inp -closedai.api_key = os.getenv('LEGIT_CLOSEDAI_KEY') + if isinstance(inp, list): + text = '' + if isinstance(inp[0], dict): + for msg in inp: + text += msg['content'] + '\n' -async def is_safe(text: Union[str, list]) -> bool: - return closedai.Moderation.create( - input=text, - )['results'][0]['flagged'] + else: + text = '\n'.join(inp) + + for _ in range(3): + req = await load_balancing.balance_organic_request( + { + 'path': '/v1/moderations', + 'payload': {'input': text} + } + ) + + async with aiohttp.ClientSession(connector=proxies.default_proxy.connector) as session: + try: + async with session.request( + method=req.get('method', 'POST'), + url=req['url'], + data=req.get('data'), + json=req.get('payload'), + headers=req.get('headers'), + cookies=req.get('cookies'), + ssl=False, + timeout=aiohttp.ClientTimeout(total=5), + ) as res: + res.raise_for_status() + + json_response = await res.json() + + return not json_response['results'][0]['flagged'] + except Exception as exc: + print('[!] moderation error:', type(exc), exc) + continue if __name__ == '__main__': - asyncio.run(is_safe('Hello')) + print(asyncio.run(is_safe('I wanna kill myself'))) diff --git a/api/streaming.py b/api/streaming.py index 3d627ee..aeb109a 100644 --- a/api/streaming.py +++ b/api/streaming.py @@ -49,19 +49,26 @@ async def stream( if is_chat and is_stream: chat_id = await chat.create_chat_id() - yield chat.create_chat_chunk( + chunk = await chat.create_chat_chunk( chat_id=chat_id, model=model, content=chat.CompletionStart ) + yield chunk - yield chat.create_chat_chunk( + chunk = await chat.create_chat_chunk( chat_id=chat_id, model=model, content=None ) - for _ in range(3): + yield chunk + + json_response = { + 'error': 'No JSON response could be received' + } + + for _ in range(5): headers = { 'Content-Type': 'application/json' } @@ -81,18 +88,17 @@ async def stream( webhook = dhooks.Webhook(os.getenv('DISCORD_WEBHOOK__API_ISSUE')) webhook.send(content=f'API Issue: **`{exc}`**\nhttps://i.imgflip.com/7uv122.jpg') - yield errors.yield_error( + error = errors.yield_error( 500, 'Sorry, the API has no working keys anymore.', 'The admins have been messaged automatically.' ) + yield error return for k, v in target_request.get('headers', {}).items(): headers[k] = v - json.dump(target_request, open('api.log.json', 'w'), indent=4) - async with aiohttp.ClientSession(connector=proxies.default_proxy.connector) as session: try: async with session.request( @@ -109,52 +115,38 @@ async def stream( timeout=aiohttp.ClientTimeout(total=float(os.getenv('TRANSFER_TIMEOUT', '120'))), ) as response: - print(5) + + if not is_stream: + json_response = await response.json() try: response.raise_for_status() except Exception as exc: if 'Too Many Requests' in str(exc): - print(429) continue - if user and incoming_request: - await logs.log_api_request( - user=user, - incoming_request=incoming_request, - target_url=target_request['url'] - ) - - if credits_cost and user: - await users.update_by_id(user['_id'], { - '$inc': {'credits': -credits_cost} - }) - - print(6) - if is_stream: try: async for chunk in response.content.iter_any(): send = False chunk = f'{chunk.decode("utf8")}\n\n' chunk = chunk.replace(os.getenv('MAGIC_WORD', 'novaOSScheckKeyword'), payload['model']) - # chunk = chunk.replace(os.getenv('MAGIC_USER_WORD', 'novaOSSuserKeyword'), user['_id']) - print(chunk) + chunk = chunk.replace(os.getenv('MAGIC_USER_WORD', 'novaOSSuserKeyword'), str(user['_id'])) if not chunk.strip(): send = False if is_chat and '{' in chunk: data = json.loads(chunk.split('data: ')[1]) + chunk = chunk.replace(data['id'], chat_id) send = True if target_request['module'] == 'twa' and data.get('text'): - chunk = chat.create_chat_chunk( + chunk = await chat.create_chat_chunk( chat_id=chat_id, model=model, content=['text'] ) - if not data['choices'][0]['delta']: send = False @@ -162,24 +154,18 @@ async def stream( send = False if send: - yield chunk + final_chunk = chunk.strip().replace('data: [DONE]', '') + '\n\n' + yield final_chunk except Exception as exc: if 'Connection closed' in str(exc): - print('connection closed: ', exc) - continue - - if not demo_mode: - ip_address = await network.get_ip(incoming_request) - - await stats.add_date() - await stats.add_ip_address(ip_address) - await stats.add_path(path) - await stats.add_target(target_request['url']) - - if is_chat: - await stats.add_model(model) - await stats.add_tokens(input_tokens, model) + error = errors.yield_error( + 500, + 'Sorry, there was an issue with the connection.', + 'Please first check if the issue on your end. If this error repeats, please don\'t heistate to contact the staff!.' + ) + yield error + return break @@ -187,21 +173,43 @@ async def stream( print('proxy error') continue - print(3) - if is_chat and is_stream: - chat_chunk = chat.create_chat_chunk( + chunk = await chat.create_chat_chunk( chat_id=chat_id, model=model, content=chat.CompletionStop ) - data = json.dumps(chat_chunk) + yield chunk yield 'data: [DONE]\n\n' if not is_stream: - json_response = await response.json() - yield json_response.encode('utf8') + yield json.dumps(json_response) + + # DONE ========================================================= + + if user and incoming_request: + await logs.log_api_request( + user=user, + incoming_request=incoming_request, + target_url=target_request['url'] + ) + + if credits_cost and user: + await users.update_by_id(user['_id'], { + '$inc': {'credits': -credits_cost} + }) + + ip_address = await network.get_ip(incoming_request) + + await stats.add_date() + await stats.add_ip_address(ip_address) + await stats.add_path(path) + await stats.add_target(target_request['url']) + + if is_chat: + await stats.add_model(model) + await stats.add_tokens(input_tokens, model) if __name__ == '__main__': asyncio.run(stream()) diff --git a/api/transfer.py b/api/transfer.py index 58b7fb7..f1fe040 100644 --- a/api/transfer.py +++ b/api/transfer.py @@ -4,11 +4,13 @@ import os import json import yaml import logging +import fastapi import starlette from dotenv import load_dotenv import streaming +import moderation from db import logs, users from helpers import tokens, errors, exceptions @@ -32,7 +34,8 @@ async def handle(incoming_request): # METHOD if incoming_request.method not in ['GET', 'POST', 'PUT', 'DELETE', 'PATCH']: - return errors.error(405, f'Method "{incoming_request.method}" is not allowed.', 'Change the request method to the correct one.') + error = await errors.error(405, f'Method "{incoming_request.method}" is not allowed.', 'Change the request method to the correct one.') + return error # PAYLOAD try: @@ -42,7 +45,7 @@ async def handle(incoming_request): # TOKENS try: - input_tokens = tokens.count_for_messages(payload['messages']) + input_tokens = await tokens.count_for_messages(payload['messages']) except (KeyError, TypeError): input_tokens = 0 @@ -50,7 +53,8 @@ async def handle(incoming_request): received_key = incoming_request.headers.get('Authorization') if not received_key: - return errors.error(401, 'No NovaAI API key given!', 'Add "Authorization: Bearer nv-..." to your request headers.') + error = await errors.error(401, 'No NovaAI API key given!', 'Add "Authorization: Bearer nv-..." to your request headers.') + return error if received_key.startswith('Bearer '): received_key = received_key.split('Bearer ')[1] @@ -59,38 +63,60 @@ async def handle(incoming_request): user = await users.by_api_key(received_key.strip()) if not user: - return errors.error(401, 'Invalid NovaAI API key!', 'Create a new NovaOSS API key.') + error = await errors.error(401, 'Invalid NovaAI API key!', 'Create a new NovaOSS API key.') + return error ban_reason = user['status']['ban_reason'] if ban_reason: - return errors.error(403, f'Your NovaAI account has been banned. Reason: "{ban_reason}".', 'Contact the staff for an appeal.') + error = await errors.error(403, f'Your NovaAI account has been banned. Reason: "{ban_reason}".', 'Contact the staff for an appeal.') + return error if not user['status']['active']: - return errors.error(418, 'Your NovaAI account is not active (paused).', 'Simply re-activate your account using a Discord command or the web panel.') + error = await errors.error(418, 'Your NovaAI account is not active (paused).', 'Simply re-activate your account using a Discord command or the web panel.') + return error # COST costs = credits_config['costs'] cost = costs['other'] + is_safe = True + if 'chat/completions' in path: for model_name, model_cost in costs['chat-models'].items(): if model_name in payload['model']: cost = model_cost + is_safe = await moderation.is_safe(payload['messages']) + + else: + inp = payload.get('input', payload.get('prompt')) + + if inp and not '/moderations' in path: + is_safe = await moderation.is_safe(inp) + + if not is_safe: + error = await errors.error(400, 'The request contains content which violates this model\'s policies.', 'We currently don\'t support any NSFW models.') + return error + + return + role_cost_multiplier = credits_config['bonuses'].get(user['role'], 1) cost = round(cost * role_cost_multiplier) if user['credits'] < cost: - return errors.error(429, 'Not enough credits.', 'Wait or earn more credits. Learn more on our website or Discord server.') + error = await errors.error(429, 'Not enough credits.', 'Wait or earn more credits. Learn more on our website or Discord server.') + return error # READY - payload['user'] = str(user['_id']) + # payload['user'] = str(user['_id']) if 'chat/completions' in path and not payload.get('stream') is True: payload['stream'] = False - return starlette.responses.StreamingResponse( + media_type = 'text/event-stream' if payload.get('stream', False) else 'application/json' + + return fastapi.responses.StreamingResponse( content=streaming.stream( user=user, path=path, @@ -99,5 +125,5 @@ async def handle(incoming_request): input_tokens=input_tokens, incoming_request=incoming_request, ), - media_type='text/event-stream' if payload.get('stream', False) else 'application/json' + media_type=media_type ) diff --git a/credit_management/autocredits.py b/rewardsystem/autocredits.py similarity index 76% rename from credit_management/autocredits.py rename to rewardsystem/autocredits.py index e30fb23..10afb91 100644 --- a/credit_management/autocredits.py +++ b/rewardsystem/autocredits.py @@ -2,11 +2,11 @@ async def get_all_users(client): users = client['nova-core']['users'] return users -async def update_credits(users, settings = None): +async def update_credits(users, settings=None): if not settings: users.update_many({}, {"$inc": {"credits": 250}}) else: for key, value in settings.items(): users.update_many({'role': key}, {"$inc": {"credits": int(value)}}) - print(f"Updated {key} to {value}") \ No newline at end of file + print(f'Updated {key} to {value}') diff --git a/credit_management/main.py b/rewardsystem/main.py similarity index 86% rename from credit_management/main.py rename to rewardsystem/main.py index f6a0148..92435f3 100644 --- a/credit_management/main.py +++ b/rewardsystem/main.py @@ -1,11 +1,13 @@ import asyncio -from settings import roles import autocredits import aiohttp -from dotenv import load_dotenv import os import pymongo +from settings import roles + +from dotenv import load_dotenv + load_dotenv() CONNECTION_STRING = os.getenv("CONNECTION_STRING") @@ -23,10 +25,9 @@ async def update_roles(users): async with session.get('http://localhost:50000/get_roles') as response: data = await response.json() except aiohttp.ClientError as e: - print(f"Error: {e}") - return - - lvlroles = [f"lvl{lvl}" for lvl in range(10, 110, 10)] + [''] + raise ValueError('Could not get roles') from exc + + lvlroles = [f'lvl{lvl}' for lvl in range(10, 110, 10)] + [''] discord_users = data users = await autocredits.get_all_users(pymongo_client) @@ -41,11 +42,12 @@ async def update_roles(users): for role in lvlroles: if role in roles: bulk_updates.append(pymongo.UpdateOne({'auth.discord': int(discord)}, {'$set': {'role': role}})) - print(f"Updated {id_} to {role}") + print(f'Updated {id_} to {role}') break + if bulk_updates: with pymongo_client: users.bulk_write(bulk_updates) if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/credit_management/role_bot.py b/rewardsystem/role_bot.py similarity index 100% rename from credit_management/role_bot.py rename to rewardsystem/role_bot.py diff --git a/credit_management/settings.py b/rewardsystem/settings.py similarity index 100% rename from credit_management/settings.py rename to rewardsystem/settings.py diff --git a/tests/__main__.py b/tests/__main__.py index 7df796e..0a9f134 100644 --- a/tests/__main__.py +++ b/tests/__main__.py @@ -23,7 +23,7 @@ MODEL = 'gpt-3.5-turbo' MESSAGES = [ { 'role': 'user', - 'content': '1+1=', + 'content': 'fuck you', } ] @@ -66,15 +66,10 @@ def test_library(): completion = closedai.ChatCompletion.create( model=MODEL, - messages=MESSAGES, - stream=True + messages=MESSAGES ) - for event in completion: - try: - print(event['choices'][0]['delta']['content']) - except: - print('-') + return completion['choices'][0]['message']['content'] def test_library_moderation(): return closedai.Moderation.create("I wanna kill myself, I wanna kill myself; It's all I hear right now, it's all I hear right now") @@ -83,8 +78,8 @@ def test_all(): """Runs all tests.""" # print(test_server()) - print(test_api()) - # print(test_library()) + # print(test_api()) + print(test_library()) # print(test_library_moderation())